tf 函式不會改變物件的屬性
class f:
v = 7
def __call__(self):
self.v = self.v 1
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)
預期列印:7 8 8 9
而是: 7 8 7 8
當我洗掉 @tf.function 裝飾器時,一切都按預期作業。如何使用@tf.function 使我的函式按預期作業
uj5u.com熱心網友回復:
此行為記錄在此處:
副作用,如列印、附加到串列和改變全域變數,可能會在函式內部出現意外行為,有時會執行兩次或全部執行。它們僅在您第一次使用一組輸入呼叫函式時發生。之后,被跟蹤的 tf.Graph 被重新執行,而不執行 Python 代碼。一般的經驗法則是避免在您的邏輯中依賴 Python 副作用,只使用它們來除錯您的跟蹤。否則,諸如 tf.data、tf.print、tf.summary、tf.Variable.assign 和 tf.TensorArray 之類的 TensorFlow API 是確保您的代碼在每次呼叫時都由 TensorFlow 運行時執行的最佳方式。
因此,也許嘗試使用tf.Variable以查看預期的更改:
import tensorflow as tf
class f:
v = tf.Variable(7)
def __call__(self):
self.v.assign_add(1)
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)
7
8
8
9
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/369683.html
下一篇:TF,訪問資料集物件中元素的檔案名:dataset.group_by_windowgroupbyfilename?
