我的代碼:
def entropy(x):
return tf.convert_to_tensor(skimage.measure_shannon_entropy(np.array(x)))
def calc_entropy(x, fn):
i = tf.constant(0)
while_condition = lambda i: tf.less(i, fn)
#loop
r = tf.while_loop(while_condition, entropy, x[0, :, :, i])
return r
a = tf.constant([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]) # shape=(1, 2, 2, 2)
output = calc_entropy(a, 2)
輸出:[1.23, 0.12]
但是我的代碼顯示了這個錯誤:ValueError:具有多個元素的陣列的真值是不明確的。使用 a.any() 或 a.all()
uj5u.com熱心網友回復:
嘗試這樣的事情:
import tensorflow as tf
from skimage.measure.entropy import shannon_entropy
import numpy as np
def entropy(i, v, x):
v = tf.tensor_scatter_nd_update(v, [[i]], [tf.convert_to_tensor(shannon_entropy(np.array(x[0, :, :, i])))])
return tf.add(i, 1), v, x
def calc_entropy(x, fn):
i = tf.constant(0)
v = tf.zeros((fn,), dtype=tf.float64)
while_condition = lambda i, v, x: tf.less(i, fn)
_, v, _ = tf.while_loop(while_condition, entropy, loop_vars=(i, v, x))
return v
a = tf.constant([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
output = calc_entropy(a, 2)
print(output)
tf.Tensor([2. 2.], shape=(2,), dtype=float64)
但是,我不知道您對輸出的期望如何[1.23, 0.12]。手動檢查計算:
a = np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
print(shannon_entropy(a[0, :, :, 0]))
print(shannon_entropy(a[0, :, :, 1]))
2.0
2.0
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/432299.html
