嘗試使用簡單的架構來識別手寫數字。測驗給出 0.9723 的準確度
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten
from sklearn.model_selection import train_test_split
# data split
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalizing
x_train = x_train / 255
x_test = x_test / 255
y_train_cat = keras.utils.to_categorical(y_train, 10)
y_test_cat = keras.utils.to_categorical(y_test, 10)
# creating model
model = keras.Sequential([
Flatten(input_shape=(28, 28, 1)),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
x_train_split, x_val_split, y_train_split, y_val_split = train_test_split(x_train, y_train_cat, test_size=0.2)
model.fit(
x_train_split,
y_train_split,
batch_size=32,
epochs=6,
validation_data=(x_val_split, y_val_split))
# saving model
model.save('mnist_model.h5')
# test
model.evaluate(x_test, y_test_cat)
但是當我嘗試識別自己的數字(0 到 9)時,其中一些無法正確識別:

轉載請註明出處,本文鏈接:https://www.uj5u.com/net/408864.html
標籤:
上一篇:嘗試堆疊兩個參差不齊的張量時出錯
