我想將 TensorFlow 管道一分為二,并使用tf.data.Dataset.map().
像這樣:
dataset = tf.data.Dataset.from_tensor_slices(list(range(20)))
dataset = dataset.shuffle(20).batch(10)
dataset_1 = dataset.map(lambda x: x)
dataset_2 = dataset.map(lambda x: x 1)
for d1, d2 in zip(dataset_1, dataset_2):
print(d1.numpy()) # [13 14 12 15 18 2 16 19 6 4]
print(d2.numpy()) # [18 16 6 7 3 15 17 9 2 4]
break
但是,這不是我想要的輸出。我的期望是什么時候d1是[13 14 12 15 18 2 16 19 6 4],d2應該是[14 15 13 16 19 3 17 20 7 5]。我想我知道發生了什么,但不知道如何寫。我不想從一開始就創建兩個管道(因為開銷很大)。你能給我一些建議嗎?
謝謝閱讀。
更新
我決定如下實作它。
# use the same seed for dataset_1 and dataset_2
dataset_1 = dataset.shuffle(20, seed=0).batch(10)
dataset_2 = dataset.shuffle(20, seed=0).batch(10)
dataset_1 = dataset_1.map(lambda x: x)
dataset_2 = dataset_2.map(lambda x: x 1)
uj5u.com熱心網友回復:
tensorflow shuffle 函式的默認行為是每次呼叫 .numpy() 時重新洗牌,為了防止這種情況,您需要設定 reshuffle_each_itertaion=False ( https://www.tensorflow.org/api_docs/python/tf/data/Dataset #洗牌)。
dataset = tf.data.Dataset.from_tensor_slices(list(range(20)))
dataset = dataset.shuffle(20, reshuffle_each_iteration=False).batch(10)
dataset_1 = dataset.map(lambda x: x)
dataset_2 = dataset.map(lambda x: x 1)
for d1, d2 in zip(dataset_1, dataset_2):
print(d1.numpy()) # [10 13 3 19 12 16 7 11 2 8]
print(d2.numpy()) # [11 14 4 20 13 17 8 12 3 9]
break
但這樣做的后果是,如果您嘗試第二次呼叫 d1.numpy() 或 d2.numpy() 值將保持不變。
uj5u.com熱心網友回復:
兩個動作的簡單堆疊怎么樣
dataset = tf.data.Dataset.from_tensor_slices(list(range(20)))
dataset = dataset.shuffle(20)
def func1(x):
return x
def func2(x):
return x 1
dataset = dataset.map(lambda sample: tf.stack([func1(sample), func2(sample)], axis=0))
list(dataset.as_numpy_iterator())
# [array([ 9, 10], dtype=int32),
# array([16, 17], dtype=int32),
# array([10, 11], dtype=int32),
# array([1, 2], dtype=int32),
# array([11, 12], dtype=int32),
# array([6, 7], dtype=int32),
# array([18, 19], dtype=int32),
# array([3, 4], dtype=int32),
# array([8, 9], dtype=int32),
# array([15, 16], dtype=int32),
# array([4, 5], dtype=int32),
# array([14, 15], dtype=int32),
# array([0, 1], dtype=int32),
# array([12, 13], dtype=int32),
# array([17, 18], dtype=int32),
# array([2, 3], dtype=int32),
# array([5, 6], dtype=int32),
# array([13, 14], dtype=int32),
# array([7, 8], dtype=int32),
# array([19, 20], dtype=int32)]
之后,您可以
根據需要取消
dataset = dataset.unbatch()
批處理和批處理
dataset = dataset.batch(10)。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/382726.html
上一篇:從.txt中洗掉引號
