我創建了一個 y ~ x**2 的資料集

但是,當我訓練神經網路時,它無法擬合二次方程。

這是我的模型。
model2 = tf.keras.models.Sequential(
[tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(1)]
)
loss = tf.keras.losses.mse
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model2.compile(optimizer=optimizer, loss=loss, metrics=tf.metrics.RootMeanSquaredError())
model2.fit(tf.expand_dims(X_train, -1), y_train, epochs=1000, verbose=1)
我對上述模型的思考程序是,我認為每次relu激活都會擬合一條區域線性線,然后慢慢將所有神經元連接起來形成一條二次線。
最后,我通過lambda x:x**2在輸出層上使用激活來管理它,但是,那是因為我知道該函式是 x**2。
所以我的問題是,在不知道真實函式的情況下,如何訓練神經網路以擬合非線性曲線?
uj5u.com熱心網友回復:
你的代碼對我來說很好。
請注意,我使用更大的學習率和提前停止(總共 2000 個 epoch 有 300 個耐心)。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
train_x = np.linspace(0, 80, 160)
train_y = train_x**2
test_x = np.linspace(80, 100, 40)
test_y = test_x**2
model2 = tf.keras.models.Sequential(
[tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(1)]
)
loss = tf.keras.losses.mse
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=300, restore_best_weights=True)
model2.compile(optimizer=optimizer, loss=loss, metrics=tf.metrics.RootMeanSquaredError())
model2.fit(tf.expand_dims(train_x, -1), train_y, epochs=2000, verbose=1, callbacks=[early_stop])
train_pred = model2.predict(train_x)
test_pred = model2.predict(test_x)
plt.scatter(train_x, train_y, c='blue', label='train x')
plt.scatter(test_x, test_y, c='green', label='test x')
plt.scatter(train_x, train_pred, c='red', label='train pred')
plt.scatter(test_x, test_pred, c='orange', label='test pred')
plt.legend()
plt.show()
訓練和測驗結果照片在這里
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/406684.html
標籤:
