我有一個從存盤庫加載的 TensorFlow 模型
model = tf.saved_model.load(folder)
我的目標是在 Jax 中復制相同的模型,因此我需要了解加載的變數值(權重和偏差)是否正確。
我可以恢復變數值的一種方法i就是
vars = model.variables
print(vars[i].numpy())
但是,如果我assign將這些值放入 Jax 網路,我不會恢復正確的結果,因此為了除錯,我試圖分析特定層的輸出。為此,我需要確保權重和偏差相同,例如通過預先分配它們。具體來說,如果我這樣做
numpy_vars = [v.numpy() for v in vars] # This is done in eager mode.
with tf.compat.v1.Session(graph = graph) as sess:
tvars = tf.compat.v1.trainable_variables()
tf.compat.v1.variables_initializer(vars).run() #Necessary init. of either tvars/vars
for v, tv in zip(numpy_vars, tvars):
tv.assign(v)
print(tvars[0].eval()) # This returns the value of the variable in graph mode.
print('------------------------------')
print(numpy_vars[0])
它似乎沒有回傳我預期的相同值,盡管它們具有相同的形狀。我想知道這是否可能是因為 中有初始化操作model.graph,但我不太確定。如果我改為更改行
tv.assign(v)
和
sess.run(tv.assign(v))
我收到錯誤
TypeError: Argument `fetch` = <tf.Variable 'UnreadVariable' shape=(11, 256) dtype=float32> has invalid type "_UnreadVariable" must be a string or Tensor. (Can not convert a _UnreadVariable into a Tensor or Operation.)
關于如何分配這些變數的值以便它們在圖形執行期間保持固定的任何建議?
uj5u.com熱心網友回復:
答案似乎是這樣的:
numpy_vars = [v.numpy() for v in vars]
with tf.compat.v1.Session(graph = graph) as sess:
tvars = tf.compat.v1.trainable_variables()
tf.compat.v1.variables_initializer(vars).run()
print(tvars[0].eval())
print('------------------------------')
for v, tv in zip(numpy_vars, tvars):
tf.compat.v1.assign(tv, v).read_value().eval()
print(tvars[0].eval())
print('------------------------------')
print(numpy_vars[0])
行后
tf.compat.v1.assign(tv, v).read_value().eval()
我已經檢查過權重和偏差是否正常作業。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/533771.html
上一篇:Nginx優化與防盜鏈
