I am developing a VAE with this dataset. I started from the Keras tutorial code and built my own VAE. However, when I run the fit() function, I get: Invalid reduction dimension 1 for input with 1 dimensions. for '{{node Sum}} = Sum[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=false](Mean, Sum/reduction_indices)' with input shapes: [?], [2] and with computed input tensors: input[1] = <1 2>. What do I have to change?
Code:
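import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

df = pd.read_csv('local path')
data, test_data = train_test_split(df, test_size=0.2)
data.shape  # (227845, 31)

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the latent vector encoding one row."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        # Reparameterization trick: mean + standard deviation * noise
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon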
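The Sampling layer computes z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon drawn from a standard normal distribution (the reparameterization trick). Below is a minimal NumPy sketch of the same computation outside Keras. sample_z is a hypothetical helper used only for illustration. It is not part of the tutorial.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    """Hypothetical helper: reparameterized sample z = mean + std * eps."""
    epsilon = rng.standard_normal(size=z_mean.shape)  # eps ~ N(0, I)
    std = np.exp(0.5 * z_log_var)                     # log variance -> std
    return z_mean + std * epsilon

rng = np.random.default_rng(0)
z_mean = np.zeros((4, 31))     # batch of 4, latent_dim = 31
z_log_var = np.zeros((4, 31))  # log variance 0 means std 1
z = sample_z(z_mean, z_log_var, rng)
print(z.shape)  # (4, 31)
```

The output has the same (batch, latent_dim) shape as z_mean. A very negative z_log_var shrinks the noise, so z collapses to z_mean.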
Encoder:
latent_dim = 31
encoder_inputs = keras.Input(shape=(31,))  # one row has 31 features
x = layers.Dense(100, activation="relu")(encoder_inputs)
x = layers.Dense(100, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Decoder:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(100, activation="relu")(latent_inputs)
x = layers.Dense(100, activation="relu")(x)
decoder_outputs = layers.Dense(31, activation="sigmoid")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
This is where I get the error:
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(data, epochs=30, batch_size=128)
Answer:
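The error comes from tf.reduce_mean and tf.reduce_sum. The Keras tutorial trains on MNIST images of shape (batch, 28, 28, 1). keras.losses.binary_crossentropy averages over the last axis, so there it returns shape (batch, 28, 28), and summing over axes 1 and 2 is valid. Your data has shape (batch, 31), so the loss has shape (batch,). That tensor has only one dimension, and axis=(1, 2) does not exist. In the train_step method of the VAE model, change this line: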
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
to this:
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(
    keras.losses.binary_crossentropy(data, reconstruction), axis=-1), keepdims=True)
or:
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=-1, keepdims=True))
Either version reduces the loss to a single value, so fit() should then run.
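The following NumPy sketch reproduces the shapes involved. It assumes that bce_per_sample (a hypothetical stand-in, not the Keras function) behaves like keras.losses.binary_crossentropy by averaging the element-wise loss over the last axis. The random arrays are placeholders for one batch of 31-feature rows.

```python
import numpy as np

def bce_elementwise(y_true, y_pred, eps=1e-7):
    """Element-wise binary cross-entropy, clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def bce_per_sample(y_true, y_pred):
    """Hypothetical stand-in: average the element-wise loss over the last axis."""
    return bce_elementwise(y_true, y_pred).mean(axis=-1)

rng = np.random.default_rng(0)
data = rng.random((128, 31))            # placeholder batch of 31-feature rows
reconstruction = rng.random((128, 31))  # placeholder decoder output in (0, 1)

loss = bce_per_sample(data, reconstruction)
print(loss.shape)  # (128,): only one dimension left

# The tutorial's reduction over axes 1 and 2 fails on a 1-D array.
reduction_failed = False
try:
    np.sum(loss, axis=(1, 2))
except ValueError as err:  # NumPy's AxisError subclasses ValueError
    reduction_failed = True
    print("fails:", err)

# Both fixes from the answer collapse the loss to a single number.
fix_a = np.mean(np.sum(loss, axis=-1), keepdims=True)
fix_b = np.mean(np.sum(loss, axis=-1, keepdims=True))

# Closer to the tutorial's intent: sum over features per row, mean over batch.
tutorial_style = np.mean(np.sum(bce_elementwise(data, reconstruction), axis=1))
```

In both fixes, axis=-1 runs over the batch dimension, so the loss is summed over the batch. The per-row loss is a mean over the 31 features. The tutorial_style line is one possible alternative that keeps the tutorial's "sum per sample, average per batch" scaling.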
