在我的 TF 模型中,我的call函式呼叫外部能量函式,該函式依賴于單個引數傳遞兩次的函式(參見下面的簡化版本):
import tensorflow as tf
@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
E3 = 2.0
return E3
@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
E3 = 2.0
return E3
@tf.function # without tf.function this works fine
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 rij[1]**2 rij[2]**2)**0.5
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
# E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
return E3
class SWLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.gamma = tf.Variable(2.51412, dtype=tf.float32)
def call(self, coords_all):
total_conf_energy = energy( coords_all, self.gamma)
return total_conf_energy
# =============================================================================
SWL = SWLayer()
coords2 = tf.constant([[
1.9434, 1.0817, 1.0803,
2.6852, 2.7203, 1.0802,
1.3807, 1.3573, 1.3307]])
with tf.GradientTape() as tape:
tape.watch(coords2)
E = SWL( coords2)
在這里,如果 gamma 只傳遞一次,或者如果我不使用tf.function裝飾器。但是使用tf.function并傳遞相同的變數兩次,我收到以下錯誤:
Traceback (most recent call last):
File "temp_tf.py", line 47, in <module>
E = SWL( coords2)
File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "temp_tf.py", line 34, in call
total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).
in user code:
File "temp_tf.py", line 22, in energy *
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
IndexError: list index out of range
Call arguments received:
? coords_all=tf.Tensor(shape=(1, 9), dtype=float32)
這是預期的行為嗎?
uj5u.com熱心網友回復:
有趣的問題!我認為錯誤源于回溯,這導致 tf.function 多次評估 python 片段energy。看到這個問題。此外,這可能與錯誤有關。
幾個觀察:
1. 從calc_sw3作品中移除 tf.function 裝飾器并與檔案一致:
[...] tf.function 適用于一個函式和它呼叫的所有其他函式。
因此,如果您再次tf.function明確申請calc_sw3,您可能會觸發回溯,但您可能想知道為什么calc_sw3_noerr有效?也就是說,它必須與變數有關gamma。
2. 將輸入簽名添加到energy函式上方的 tf.function 中,同時保留其余代碼的原樣,也可以:
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32), tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 rij[1]**2 rij[2]**2)**0.5
E3 = calc_sw3(gamma, gamma, norm_rij)
return E3
這種方法:
[...] 確保只創建一個 ConcreteFunction,并將 GenericFunction 限制為指定的形狀和型別。當張量具有動態形狀時,這是限制回溯的有效方法。
所以也許假設gamma每次都以不同的形狀呼叫,從而觸發回溯(只是一個假設)。一個錯誤被觸發的事實是后來居然故意或刻意設計成說這里。還有另一個有趣的評論:
tf.functions 只能處理預定義的輸入形狀,如果形狀發生變化,或者如果傳遞了不同的 python 物件,tensorflow 會自動重建函式
最后,為什么我認為這是一個跟蹤問題?因為實際錯誤來自代碼片段的這一部分:
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 rij[1]**2 rij[2]**2)**0.5
您可以通過將其注釋掉并替換norm_rij為某個值然后呼叫calc_sw3. 它會起作用。這意味著此代碼片段可能會執行多次,可能是由于上述原因。這在這里也有很好的記錄:
在第一階段,稱為“跟蹤”,Function 創建一個新的 tf.Graph。Python 代碼運行正常,但所有 TensorFlow 操作(如添加兩個 Tensor)都被推遲:它們被 tf.Graph 捕獲而不運行。
在第二階段,運行包含在第一階段延遲的所有內容的 tf.Graph。這個階段比追蹤階段快得多
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/391938.html
