我想使用 Python 中的 Tensorflow 創建一個可以同時處理浮點數和向量作為輸入的函式。我定義了以下函式:
def g(t):
if tf.rank(t) == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t),1)
但是,我想在另一個 tf.function 中呼叫該函式。作為測驗,我做了以下功能:
@tf.function
def Test(t):
return g(t)
呼叫 g(0.5) 給出
Rank=0
Out[218]: <tf.Tensor: shape=(), dtype=float32, numpy=2.7182817>
呼叫 Test(0.5) 給出:
rank=0
rank=higher
Traceback (most recent call last):
Input In [219] in <cell line: 1>
Test(0.5)
File ~\Anaconda3\lib\site-packages\tensorflow\python\util\traceback_utils.py:153 in error_handler
raise e.with_traceback(filtered_tb) from None
File ~\AppData\Local\Temp\__autograph_generated_filegb02ol08.py:12 in tf__Test
retval_ = ag__.converted_call(ag__.ld(gn), (ag__.ld(t),), None, fscope)
File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:37 in tf__gn
ag__.if_stmt(ag__.converted_call(ag__.ld(int), (ag__.converted_call(ag__.ld(tf).rank, (ag__.ld(t),), None, fscope),), None, fscope) == 0, if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:33 in else_body
retval_ = ag__.ld(V0) ag__.ld(labda) * ag__.ld(theta) * ag__.converted_call(ag__.ld(tf).math.reduce_sum, (ag__.ld(c) / ag__.ld(gamma) * (1 - ag__.converted_call(ag__.ld(tf).math.exp, (-ag__.ld(gamma) * ag__.ld(t),), None, fscope)), 1), None, fscope)
ValueError: in user code:
File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 11, in Test *
return gn(t)
File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 7, in gn *
return V0 labda * theta * tf.math.reduce_sum(c / gamma * (1 - tf.math.exp(-gamma * t)),1)
ValueError: Invalid reduction dimension 1 for input with 1 dimensions. for '{{node cond/Sum}} = Sum[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=false](cond/mul_1, cond/Sum/reduction_indices)' with input shapes: [1], [] and with computed input tensors: input[1] = <1>.
為什么 if-else 陳述句的兩個引數都在 tf.function 中被呼叫?以及如何使函式 g 在 tf.function 中作業?
uj5u.com熱心網友回復:
看起來有人在最近的Github Issue中提出了這種行為。在關閉問題之前突出顯示一位 Tensorflow 開發人員的回應:
The cause of this problem is due to the behavior of condition tracing in TensorFlow: the same input is applied to both true and false sides for graph tracing, when the condition is based on a non-static value (i.e. tf.rank(v) == 2).
有兩種可行的解決方案。
使用常數值
如果您使用tf.get_static_value(details here ) 回傳由 回傳的 0-D 張量的常量值tf.rank,它會阻止條件跟蹤,因為它會根據形狀評估張量(將其轉換為 int、float、numpy 陣列等)和型別)。
def g(t):
if tf.get_static_value(tf.rank(t)) == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t), 1)
這將回傳預期結果:
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
直接形狀評估
與其使用 ,不如tf.rank直接評估形狀,這還需要將任何非張量輸入轉換為張量:
def g(t):
if not isinstance(t, tf.Tensor):
t = tf.convert_to_tensor(t)
if t.shape.ndims == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t), 1)
這個實作也產生了預期的結果:
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/517982.html
標籤:Python功能张量流if 语句
