我想使用具有特定類索引的資料集管道。
- 例如:
如果我使用 CIFAR-10 資料集。我可以按如下方式加載資料集:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
加載所有類標簽(10個類)。我可以使用以下代碼創建管道:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test,y_test)).batch(64)
這適用于訓練 Keras 模型。
- 現在我想用幾個樣本創建一個管道(而不是使用所有 10 個類樣本,可能只使用 5 個樣本)。有沒有辦法制作這樣的管道?
uj5u.com熱心網友回復:
您可以使用tf.data.Dataset.filter:
import tensorflow as tf
class_indexes_to_keep = tf.constant([0, 3, 4, 6, 8], dtype=tf.int64)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = y_train.astype(int)
y_test = y_test.astype(int)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).filter(lambda x, y: tf.reduce_any(y == class_indexes_to_keep)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test)).filter(lambda x, y: tf.reduce_any(y == class_indexes_to_keep)).batch(64)
要轉換為分類標簽,您可以嘗試:
import tensorflow as tf
one_hot_encode = tf.keras.utils.to_categorical(tf.range(10, dtype=tf.int64), num_classes=10)
class_indexes_to_keep = tf.gather(one_hot_encode, tf.constant([0, 3, 4, 6, 8], dtype=tf.int64))
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = y_train.astype(int)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).map(lambda x, y: (x, tf.one_hot(y, 10)[0]))
train_dataset = train_dataset.filter(lambda x, y: tf.reduce_any(tf.reduce_all(y == class_indexes_to_keep, axis=-1))).batch(64)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/511628.html
上一篇:從圖表中提取繪圖線
下一篇:在R中繪制影像堆疊的優雅方式
