我正在嘗試使用 tensorflow 的tf.map_fn來映射一個參差不齊的張量,但是我遇到了一個無法修復的錯誤。這是一些演示錯誤的最小代碼:
import tensorflow as tf
X = tf.ragged.constant([[0,1,2], [0,1]])
def outer_product(x):
return x[...,None]*x[None,...]
tf.map_fn(outer_product, X)
我想要的輸出是:
tf.ragged.constant([
[[0, 0, 0],
[0, 1, 2],
[0, 2, 4]],
[[0, 0],
[0, 1]]
])
我得到的錯誤是:
“InvalidArgumentError:所有 flat_values 必須具有兼容的形狀。索引 0 處的形狀:[3]。索引 1 處的形狀:[2]。如果您使用 tf.map_fn,那么您可能需要指定具有適當 ragged_rank 的顯式 fn_output_signature,以及/ 或將輸出張量轉換為 RaggedTensors。[Op:RaggedTensorFromVariant]"
我意識到我需要指定 fn_output_signature 但盡管進行了實驗,但我無法弄清楚它應該是什么。
編輯:我稍微整理了 AloneTogether 的優秀答案。他的回答使用該tf.ragged.stack函式將張量轉換tf.map_fn為由于某種原因需要的參差不齊的張量
X = tf.ragged.constant([
[0,1,2],
[0,1]
])
def outer_product(x):
t = x[...,None] * x[None,...]
return tf.ragged.stack(t)
y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
dtype=tf.type_spec_from_value(X).dtype,
ragged_rank=2))
print(y.shape) # == (2, 1, None , None)
y = tf.squeeze(y, 1)
tf.print(y.shape) # == (2, None , None)
uj5u.com熱心網友回復:
參差不齊的張量有時真的很棘手。這是您可以嘗試的一種選擇:
import tensorflow as tf
X = tf.ragged.constant([
[0,1,2],
[0,1]
])
def outer_product(x):
t = x[...,None] * x[None,...]
return tf.ragged.stack(t)
y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
dtype=tf.type_spec_from_value(X).dtype,
ragged_rank=2,
row_splits_dtype=tf.type_spec_from_value(X).row_splits_dtype))
tf.print(y)
y = tf.concat([y[0, :], y[1, :]], axis=0) # Remove additional dimension from Ragged Tensor
tf.print(y)
[
[
[
[0, 0, 0],
[0, 1, 2],
[0, 2, 4]
]
],
[
[
[0, 0],
[0, 1]
]
]
]
并在洗掉附加維度后tf.concat:
[
[
[0, 0, 0],
[0, 1, 2],
[0, 2, 4]
],
[
[0, 0],
[0, 1]
]
]
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/382721.html
