我正在訓練一個神經網路來讀取 4 個影像并預測接下來的 4 個影像。由于我有一個巨大的資料集,我撰寫了這個生成器:
def my_gen(array_passed):
for i in range(len(array_passed)-7):
x = arr_train_ids[i:i 4]
x = obtain_data(x) #shape (4,480,480,1)
y = arr_train_ids[i 4:i 8]
y = obtain_data(y) #same shape as x
yield x, y
在哪里obtain_data打開檔案,加載 4 個 480x480 矩陣并重塑為 (4, 480, 480, 1)。然后,按照這里的帖子,我這樣做:
def my_input_fn(array_passed):
dataset = tf.data.Dataset.from_generator(lambda: my_gen(array_passed),
output_types=(tf.float32, tf.float32),
) #using TF2.3, output_signature not valid
dataset = dataset.batch(1)
#dataset = dataset.repeat(10)
return dataset
train_dataset = my_input_fn(arr_train_ids)
val_dataset = my_input_fn(arr_val_ids)
test_dataset = my_input_fn(arr_test_ids)
我不確定使用的必要性repeat...現在在我的代碼中進行了注釋。
模型,稱為rnc,放在這里(為簡潔起見省略,涉及 ConvLSTM 和 Conv3D)。
model = rnc(input_shape=(4, 480, 480, 1))
model.compile(optimizer=tf.keras.optimizers.Adam(lr=3e-4),loss='mae')
history = model.fit(train_dataset, validation_data=val_dataset, epochs=100)
我想訓練 1 批資料。理論上,1批資料應該是4個480x480矩陣,1個通道,(4, 480, 480, 1)作為輸入,(4, 480, 480, 1)作為輸出。
但是我沒有從培訓中獲得任何東西,我很懷疑。我想知道 TF 資料集的準備是否是我做的正確的事情。
uj5u.com熱心網友回復:
如果您檢查了結果input_batch,并output_batch從您的train_dataset,它必須作業。
在我的簡單生殖示例中,
import tensorflow as tf
import tensorflow.keras as keras
x = tf.random.normal((100, 4, 480, 480, 1)) # Say I have 100 data.
y = tf.random.normal((100, 4, 480, 480, 1))
train_ds = tf.data.Dataset.from_tensor_slices((x, y))
train_ds = train_ds.batch(batch_size=32, drop_remainder=True)
# check whether this dataset really produces the things I want
sample_input_batch, sample_output_batch = next(iter(train_ds))
print(sample_input_batch.shape) # (32, 4, 480, 480, 1)
print(sample_output_batch.shape) # (32, 4, 480, 480, 1)
simple_model = keras.Sequential([
keras.layers.InputLayer(input_shape=(4, 480, 480, 1)),
keras.layers.Dense(10, activation='relu'),
keras.layers.Dense(1, activation='relu')
])
simple_model.compile(loss=keras.losses.MeanSquaredError(),
optimizer=keras.optimizers.Adam(),
metrics=['mse'])
# check whether the model really produces the thing I expect
sample_predicted_batch = simple_model(sample_input_batch)
print(sample_predicted_batch.shape) # (32, 4, 480, 480, 1)
simple_model.fit(train_ds)
# 3/3 [==============================] - 0s 24ms/step - loss: 1.1743 - mse: 1.1743
# Then it should work!
此外,如果你有一個巨大的資料集,你真的不需要repeat()使用tf.data.API.
steps_per_epoch此外,如果您使用方法,則必須為某個數字指定引數,repeat()否則意味著永遠運行。你可以steps_per_epoch在這里查看。
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/415556.html
標籤:
