我想以下列方式從 TF 資料集中生成批處理樣本:每批將由兩個具有相同“標簽”特征的樣本組成。在 TensorFlow 中實作這一目標的最有效方法是什么?
uj5u.com熱心網友回復:
假設您有某種具有多個標簽的資料,如下所示:
x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
為了實作你的目標。“每個批次將由兩個具有相同‘標簽’特征的樣本組成”,您可以使用該filter函式為每個標簽創建單獨的資料集,然后使用API的concatenate函式將這些資料集合并為一個:tf.data.Dataset
x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
batch_size = 2
dataset0 = dataset.filter(lambda x, y: tf.equal(y, 0)).batch(batch_size)
dataset1 = dataset.filter(lambda x, y: tf.equal(y, 1)).batch(batch_size)
dataset2 = dataset.filter(lambda x, y: tf.equal(y, 2)).batch(batch_size)
dataset3 = dataset.filter(lambda x, y: tf.equal(y, 3)).batch(batch_size)
dataset = dataset0.concatenate(dataset1).concatenate(dataset2).concatenate(dataset3).shuffle(buffer_size=20)
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/415559.html
標籤:
上一篇:如何在tf.function(圖形模式)中展平梯度(張量串列)
下一篇:如何設定WPF三態按鈕的顏色
