我試圖預測我在 keras / tensorflow 中的所有測驗批次,然后繪制一個混淆矩陣。當前BATCH_SIZE為:32
我的測驗資料集是使用來自大資料集的以下代碼生成的:
test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
之后model.compile(),model.fit()我用這段代碼得到了預測和正確的標簽:
points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()
此方法僅預測一批 --> 32 個預測。
有沒有辦法預測 keras / tensorflow 中的所有測驗批次?
提前致謝!
uj5u.com熱心網友回復:
model.predict您可以根據檔案將整個資料集傳遞給:
輸入樣本。它可能是:一個 Numpy 陣列(或類似陣列),或一個陣列串列(如果模型有多個輸入)。TensorFlow 張量或張量串列(如果模型有多個輸入)。一個 tf.data 資料集。生成器或 keras.utils.Sequence 實體。對迭代器型別(資料集、生成器、序列)的解包行為的更詳細描述在 Model.fit 的類似迭代器輸入的解包行為部分中給出。
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
或與numpy:
points = np.concatenate(list(test_dataset.map(lambda x, y: x))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/464824.html
下一篇:在張量流圖中執行查找
