看看下面的代碼示例:
def myFun(my_tensor):
#The following line works
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
#The following line leads to error
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[p]]), tf.constant([1]))
我用一個簡單的案例來描述我面臨的問題 這個函式 (myFun) 被稱為 tf.while_loop 的主體(如果相關) my_tensor 的定義
my_tensor = tf.zeros(5, tf.int32)
如何定義 tf.tensor_scatter_update 的索引引數?我正在使用 tensorflow1.15
uj5u.com熱心網友回復:
您不能使用張量p作為 的引數tf.constant。也許嘗試這樣的事情:
%tensorflow_version 1.x
import tensorflow as tf
def myFun(my_tensor):
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
new_tensor= tf.tensor_scatter_update(my_tensor, [[p]], tf.constant([1]))
with tf.Session() as sess:
p_value = p.eval()
tensor_values = my_tensor.eval()
new_tensor_values = new_tensor.eval()
print('p -->', p_value)
print('my_tensor -->', tensor_values)
print('new_tensor -->', new_tensor_values)
my_tensor = tf.zeros(5, tf.int32)
myFun(my_tensor)
p --> 1
my_tensor --> [1 0 0 0 0]
new_tensor --> [1 1 0 0 0]
您還可以環繞p一個tf.Variable:
def myFun(my_tensor):
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
indices = tf.Variable([[p]])
new_tensor= tf.tensor_scatter_update(my_tensor, indices, tf.constant([1]))
with tf.Session() as sess:
sess.run(indices.initializer)
p_value = p.eval()
tensor_values = my_tensor.eval()
new_tensor_values = new_tensor.eval()
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/372842.html
上一篇:轉換層,沒有為任何變數提供梯度
下一篇:如何在VGG16中更改批量大小?
