我試圖適應tf.data.Dataset如下:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
INPUT_NEURONS = 10
OUTPUT_NEURONS = 1
features = tf.random.normal((1000, INPUT_NEURONS))
labels = tf.random.normal((1000, OUTPUT_NEURONS))
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
def build_model():
model = keras.Sequential(
[
layers.Dense(3, input_shape=[INPUT_NEURONS]),
layers.Dense(OUTPUT_NEURONS),
]
)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae', 'mse'])
return model
model = build_model()
model.fit(dataset, epochs=2, verbose=2)
但是,我收到以下錯誤:
ValueError: Input 0 of layer sequential is incompatible with the layer: expected axis -1 of input shape to have value 10 but received input with shape (10, 1)
model.summary() 不過看起來不錯:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 3) 33
_________________________________________________________________
dense_1 (Dense) (None, 1) 4
=================================================================
Total params: 37
Trainable params: 37
Non-trainable params: 0
_________________________________________________________________
是Keras模型fit()實際上適合tf.data.Dataset?如果是這樣,我在這里做錯了什么?
uj5u.com熱心網友回復:
據我所知,使用批次進行訓練是可選的,在模型開發程序中使用或不使用的超引數
不完全是,不是可選的。TF-Keras 設計用于批處理。摘要中的第一個維度始終對應于batch_size,并None表示batch_size模型接受任何維度。
大多數時候,您希望您的模型接受任何批量大小。好吧,如果您使用statefulLSTM,那么您想要定義 static batch_size。
將資料放入其中后,tf.data.Dataset它們將沒有專門的批處理維度:
dataset.element_spec
>> (TensorSpec(shape=(10,), dtype=tf.float32, name=None),
TensorSpec(shape=(1,), dtype=tf.float32, name=None))
并且在使用時tf.data,batch_sizeinModel.fit()被忽略,因此應手動進行批處理。更具體地說,您可能不知道 atf.data.Dataset每次包含多少個元素。
在這種情況下,在創建資料集后進行批處理是沒有意義的(我會解釋):
dataset.batch(3).element_spec
>> (TensorSpec(shape=(None, 10), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None))
tf.data通常用于中大型資料集,因此batching在創建后將允許矢量化轉換。考慮以下場景:
您有 5M 行信號資料要應用 fft。如果你在 fft 行程之前不進行批處理,它會一個一個地應用。
您有 (100K) 影像資料集。您想應用一些轉換或一些操作。批處理資料集將允許更快的矢量化轉換并節省大量時間。
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/408850.html
標籤:
上一篇:無法創建物體框架代碼優先遷移
