我有 tensorflow 張量,其概率如下:
>>> valid_4_preds
array([[0.9817431 , 0.01259811, 0.50729334, 0.00053732, 0.6966804 ,
0.00488825],
[0.9851129 , 0.01246135, 0.38177294, 0.00378728, 0.8398497 ,
0.68413687],
[0.00061161, 0.00005008, 0.00017785, 0.0000152 , 0.00017121,
0.00002404],
[0.9991425 , 0.23962161, 0.98579687, 0.01727398, 0.9354003 ,
0.3325037 ]], dtype=float32)
我現在需要將具有不同閾值的上述概率映射到classes(或帶有文本的張量)并獲取它們。
>>> # printing classes
>>> classes
<tf.Tensor: shape=(6,), dtype=string, numpy=
array([b'class_1', b'class_2', b'class_3', b'class_4', b'class_5',
b'class_6'], dtype=object)>
>>> # converting to bools
>>> true_falses = tf.math.greater_equal(valid_4_preds, tf.constant([0.5, 0.40, 0.20, 0.80, 0.5, 0.4]))
>>> true_falses
<tf.Tensor: shape=(4, 6), dtype=bool, numpy=
array([[ True, False, True, False, True, False],
[ True, False, True, False, True, True],
[False, False, False, False, False, False],
[ True, False, True, False, True, False]])>
現在,我試圖在true_falses有Trues 的索引處獲取文本(這是我的預期輸出),如下所示:
>>> <some-tensorflow-operations>
<tf.Tensor: shape=(4, 6), dtype=bool, numpy=
array([['class_1', 'class_3', 'class_5'],
['class_1', 'class_3', 'class_5', 'class_6'],
[],
['class_1', 'class_3', 'class_5']])>
這是我嘗試過的:
tf.boolean_mask似乎解決了這個目的,但mask它需要嚴格地是一維陣列。tf.where可用于獲取索引,其輸出可在整形為單維后傳遞給以tf.gather獲取相應的類,如下所示:
>>> tf.gather(classes, tf.reshape(tf.where(true_falses[0] == True), shape=(-1,)))
<tf.Tensor: shape=(3,), dtype=string, numpy=array([b'class_1', b'class_3', b'class_5'], dtype=object)>
但是,我一直無法弄清楚如何在二維陣列上執行此操作。
這個邏輯將進入一個通過 tensorflow-serving 服務的簽名,所以操作只需要嚴格地是 tensorflow。如何在 2D 張量或陣列上執行此操作?更有效和更快的操作將不勝感激。
uj5u.com熱心網友回復:
tf.ragged.boolean_mask?
import tensorflow as tf
classes = tf.constant([b'class_1', b'class_2', b'class_3', b'class_4', b'class_5', b'class_6'])
true_falses = tf.constant([
[ True, False, True, False, True, False],
[ True, False, True, False, True, True],
[False, False, False, False, False, False],
[ True, False, True, False, True, False]]
)
tf.ragged.boolean_mask(
data=tf.tile(tf.expand_dims(classes, 0), [tf.shape(true_falses)[0], 1]),
mask=true_falses
)
# <tf.RaggedTensor [[b'class_1', b'class_3', b'class_5'], [b'class_1', b'class_3', b'class_5', b'class_6'], [], [b'class_1', b'class_3', b'class_5']]>
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/402825.html
標籤:
