我對 tensorflow/keras 很陌生,我找不到解決這個問題的方法。我有一個大約 4000 個 20 維向量的訓練資料集,每個向量描述一個檔案。我在以后的狀態中也有那些相同的檔案向量。我想從初始狀態預測檔案向量在最后的狀態。我使用余弦相似度將狀態 0 的檔案向量與其最終狀態進行了比較,得到了大約 0.5。目標是用一個簡單的模型來改進它。目前我正在做:
model = Sequential()
model.add(Dense(20, activation='relu', input_dim=20))
model.compile(optimizer='adam', loss='cosine_similarity', metrics [tf.keras.metrics.CosineSimilarity(axis=1)])
model.summary()
history = model.fit(input_train, y_train,
epochs=30,
batch_size=16,
validation_data=(input_test,y_test),
callbacks=[tbCallBack]
)
在 30 個 epoch 之后,這給了我 0.66 的驗證余弦相似度,所以我猜測這實際上確實提高了我的初始余弦相似度并至少產生了某種附加值。
然后我想看看這些預測是否有意義:
lol = np.asarray([0.0125064 , 0.01250269, 0.01250133, 0.01250481, 0.01250508,
0.0125009 , 0.0125009 , 0.01250437, 0.01250131, 0.01250181,
0.01250403, 0.0125038 , 0.01250372, 0.01250246, 0.01250183,
0.01250226, 0.01250294, 0.76244247, 0.01250485, 0.01250205])
model.predict([lol])
#model.predict(lol)
兩個預測版本都給我以下錯誤:
WARNING:tensorflow:Model was constructed with shape (None, 20) for input KerasTensor(type_spec=TensorSpec(shape=(None, 20), dtype=tf.float32, name='dense_69_input'), name='dense_69_input', description="created by layer 'dense_69_input'"), but it was called on an input with incompatible shape (None,).
有人知道如何解決這個問題嗎?另外,如果有人熟悉這種目標,這是正確的方法嗎?有什么我可以做的不同嗎?
很感謝任何形式的幫助!
uj5u.com熱心網友回復:
嘗試np.expand_dims將批量維度添加到您的陣列中:
import tensorflow as tf
import numpy as np
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(20, activation='relu', input_dim=20))
model.compile(optimizer='adam', loss='cosine_similarity', metrics= [tf.keras.metrics.CosineSimilarity(axis=1)])
model.summary()
input_train = tf.random.normal((5, 20))
y_train = tf.random.normal((5, 20))
history = model.fit(input_train, y_train,
epochs=1,
batch_size=2)
lol = np.asarray([0.0125064 , 0.01250269, 0.01250133, 0.01250481, 0.01250508,
0.0125009 , 0.0125009 , 0.01250437, 0.01250131, 0.01250181,
0.01250403, 0.0125038 , 0.01250372, 0.01250246, 0.01250183,
0.01250226, 0.01250294, 0.76244247, 0.01250485, 0.01250205])
lol = np.expand_dims(lol, axis=0)
model.predict(lol)
array([[0.0727988 , 0. , 0.3008919 , 0.00460427, 0. ,
0.01472487, 0.31665963, 0.11831823, 0. , 0.05261957,
0. , 0. , 0. , 0. , 0.13595472,
0.07765757, 0.09340346, 0. , 0. , 0. ]],
dtype=float32)
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/372846.html
上一篇:使用argmax誤讀預測
