我想獲得最大值的索引。
例如:
[
[
[0.1 0.3 0.6],
[0.0 0.4 0.1]
],
[
[0.9 0.2 0.6],
[0.8 0.1 0.5]
]
]
我想得到[[0,0,2], [0,1,1], [1,0,0], [1,1,0]]。如何在 Tensorflow 中以最簡單的方式做到這一點?
uj5u.com熱心網友回復:
可以利用TF在最后一維的廣播
a = tf.constant([[[0.1, 0.3, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
b = tf.reduce_max(a, -1, keepdims=True)
tf.where(a == b)
輸出
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]], dtype=int64)>
如果每行有多個最大值并且您只想保留第一個索引,您可以得出結果中每行對應的段,然后執行 asegment_min以獲取每個段中的第一個索引。
a = tf.constant([[[0.1, 0.6, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
b = tf.reduce_max(a, -1, keepdims=True)
c = tf.cast(tf.where(a == b), tf.int32)
d = tf.reduce_sum(tf.math.cumprod(a.shape[:-1], reverse=True, exclusive=True) * c[:,:-1], axis=1)
tf.math.segment_min(c,d)
輸出
<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[0, 0, 1],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]])>
uj5u.com熱心網友回復:
#argmax will give the index but not in the format you want
max_index = tf.reshape(tf.math.argmax(a, -1),(-1, 1))
max_index
<tf.Tensor: shape=(4, 1), dtype=int64, numpy=
array([[2],
[1],
[0],
[0]])>
#Format output
idx_axis =tf.reshape(tf.Variable(np.indices((a.shape[0],a.shape[1])).transpose(1,2,0)), (-1,a.shape[1]))
idx_axis
<tf.Tensor: shape=(4, 2), dtype=int64, numpy=
array([[0, 0],
[0, 1],
[1, 0],
[1, 1]])>
tf.concat([idx_axis,max_index], axis=1)
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]])>
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/520897.html
標籤:Python张量流
上一篇:避免在ML編輯-編譯-運行回圈中重新加載權重/資料集
下一篇:Javascript多輸入名稱值
