import os

# Configure TensorFlow's native (C++) logging BEFORE TensorFlow is imported.
# Setting the variable after the import is too late to silence the messages
# that are emitted while the library loads.
# Level '2' filters out INFO and WARNING messages so only errors are printed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (import intentionally follows the env setup)
from tensorflow import keras  # noqa: E402
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics  # noqa: E402
def preprocess(x, y):
    """Convert one (features, label) pair to the dtypes used for training.

    Args:
        x: Feature tensor; cast to float32 and divided by 255
            (presumably 0-255 pixel intensities, giving values in [0, 1]).
        y: Label tensor; cast to int32.

    Returns:
        A ``(features, labels)`` tuple of the converted tensors.
    """
    features = tf.cast(x, dtype=tf.float32)
    labels = tf.cast(y, dtype=tf.int32)
    return features / 255, labels
# Load the Fashion-MNIST train/test arrays through the Keras dataset helper.
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape)

batchsz = 100  # number of samples per batch in both pipelines

# Training pipeline: dtype conversion, shuffling with a 10000-element buffer,
# then batching.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess)
db = db.shuffle(10000)
db = db.batch(batchsz)

# Evaluation pipeline: same conversion and batching, but no shuffling.
db_tset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_tset = db_tset.map(preprocess)
db_tset = db_tset.batch(batchsz)

# Pull a single training batch and print its shapes as a sanity check.
db_iter = iter(db)
sample = next(db_iter)
print("batch:", sample[0].shape, sample[1].shape)
# Fully connected classifier: four ReLU hidden layers followed by a 10-unit
# output layer with no activation (it emits raw logits).
#
# Parameter counts, assuming the flattened 28*28 = 784 input that the training
# loop feeds in:
#   Dense(256): 784*256 + 256 = 200960
#   Dense(128): 256*128 + 128 =  32896
#   Dense(64):  128*64  + 64  =   8256
#   Dense(32):  64*32   + 32  =   2080
#   Dense(10):  32*10   + 10  =    330
#   total                     = 244522
# At 4 bytes per parameter (assuming float32 weights) that is about 978 KB.
model = Sequential(
    [layers.Dense(units, activation=tf.nn.relu) for units in (256, 128, 64, 32)]
    + [layers.Dense(10)]
)

# Optional: build the network up front and print its structure.
# model.build(input_shape=[None, 28 * 28])
# model.summary()

# Adam optimizer with learning rate 1e-3; it updates the weights in place from
# the gradients passed to apply_gradients().
optimizer = optimizers.Adam(learning_rate=1e-3)
def _evaluate(dataset):
    """Return the classification accuracy of the module-level ``model``.

    Args:
        dataset: Iterable of ``(images, labels)`` batches. Labels must be
            int32 (as produced by ``preprocess``) so they can be compared
            with the int32 predictions.

    Returns:
        Fraction of correctly classified samples as a float, or NaN when
        the dataset yields no samples.
    """
    total_correct = 0
    total_num = 0
    for images, labels in dataset:
        flat = tf.reshape(images, [-1, 28 * 28])  # [b, 28, 28] => [b, 784]
        logits = model(flat)  # [b, 784] => [b, 10]
        # Softmax is monotonic, so the argmax of the logits is the same class
        # as the argmax of the probabilities; no softmax is needed here.
        pred = tf.argmax(logits, axis=1, output_type=tf.int32)  # [b]
        hits = tf.reduce_sum(tf.cast(tf.equal(labels, pred), dtype=tf.int32))
        total_correct += int(hits)
        total_num += int(flat.shape[0])  # accumulate the size of every batch
    if total_num == 0:
        return float("nan")
    return total_correct / total_num


def main(epochs=50, log_every=100):
    """Train the module-level ``model`` on ``db`` and evaluate it on ``db_tset``.

    After every epoch the test accuracy is printed.

    Args:
        epochs: Number of passes over the training dataset (default 50).
        log_every: Print the losses every ``log_every`` training steps
            (default 100).
    """
    for epoch in range(epochs):
        for step, (x, y) in enumerate(db):
            x = tf.reshape(x, [-1, 28 * 28])  # [b, 28, 28] => [b, 784]
            y_onehot = tf.one_hot(y, depth=10)

            with tf.GradientTape() as tape:
                logits = model(x)  # [b, 784] => [b, 10]
                loss_ce = tf.losses.categorical_crossentropy(
                    y_onehot, logits, from_logits=True
                )
                loss_ce = tf.reduce_mean(loss_ce)

            # Pair each gradient with its variable and let the optimizer
            # update the variables in place.
            grads = tape.gradient(loss_ce, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if step % log_every == 0:
                # MSE is reported for reference only; it is never
                # differentiated, so it is computed outside the tape and
                # only when it is actually printed.
                loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
                print(epoch, step, "loss:", float(loss_ce), float(loss_mse))

        # Evaluate on the test set at the end of each epoch.
        acc = _evaluate(db_tset)
        print(epoch, "test acc:", acc)


if __name__ == '__main__':
    main()
# Source: https://www.uj5u.com/qita/532586.html (please credit this link when reposting)
# Tags: other
# Previous article: 聊聊訊息佇列(MQ)那些事
