我正在嘗試使用 Python 在 TensorFlow 中構建一個 CNN。我已將影像加載到資料集中,如下所示:
dataset = tf.keras.preprocessing.image_dataset_from_directory(
"train_data", shuffle=True, image_size=(578, 260),
batch_size=BATCH_SIZE)
但是,如果我想在這個資料集上使用 train_test_split 或 fit_resample ,我需要將它分成資料和標簽。我是 TensorFlow 的新手,不知道該怎么做。真的很感激任何幫助。
uj5u.com熱心網友回復:
您可以使用該subset引數將資料分隔為training和validation。
import tensorflow as tf
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
image_size=(256, 256),
seed=1,
batch_size=32)
val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=1,
image_size=(256, 256),
batch_size=32)
for x, y in train_ds.take(1):
print('Image --> ', x.shape, 'Label --> ', y.shape)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Image --> (32, 256, 256, 3) Label --> (32,)
至于你的標簽,根據檔案:
“推斷”(標簽從目錄結構生成)、無(無標簽)或與目錄中找到的影像檔案數量相同大小的整數標簽串列/元組。標簽應根據影像檔案路徑的字母數字順序進行排序(通過 Python 中的 os.walk(directory) 獲得)。
所以只需嘗試迭代train_ds,看看它們是否在那里。您還可以使用引數label_mode來參考您擁有的標簽型別并class_names明確列出您的類。
如果你的類inbalanced,你可以使用class_weights的引數model.fit(*)。有關更多資訊,請查看此帖子。
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/350056.html
