我想在 tf.function 中的張量上使用 for 回圈,如下所示:
@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
print(i)
我定義:
S = tf.random.uniform([2,2],0,1)
然后
test(S)
給
Tensor("while/Placeholder:0", shape=(), dtype=int32)
盡管
for i in range(tf.shape(S)[0]):
print(i)
回傳
0
1
為什么我不能遍歷 tf.function 中張量的長度?
uj5u.com熱心網友回復:
使用tf.print:
import tensorflow as tf
@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
tf.print(i)
S = tf.random.uniform([2,2],0,1)
test(S)
# 0
# 1
tf.function 請在此處檢查使用 python 操作的副作用。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/520893.html
