使用TensorFlow API:tf.keras 搭建神經網路
搭建神經網路六步法:
1.匯入第三方庫:import
2.匯入并理解資料,劃分訓練集與測驗集:train test
3.在Sequential()中搭建網路結構,逐層描述每層網路,相當于前向傳播,:model=tf.keras.models.Sequential
4.在compile中配置訓練方法,即選擇哪種優化器,選擇哪個損失函式,選擇哪種評測指標,model.compile
5.在fit中進行訓練,告知訓練集和測驗集的輸入特征和標簽,每個betch是多少,要迭代多少次資料集:model.fit
6.用model.summary列印出網路的結構和引數,
函式用法介紹
1.model=tf.keras.models.Sequential
Sequential 函式是一個容器,容器里封裝了神經網路的網路結構,描述了在Sequential函式的輸入引數從輸入層到輸出層的網路結構,
如:
拉直層:tf.keras.layers.Flatten()
拉直層可以變換張量的尺寸,把輸入特征拉直為一維陣列,是不含計算引數的層,
全連接層:tf.keras.layers.Dense( 神經元個數,activation=”激活函式”, kernel_regularizer=”正則化方式”)
其中:
activation(字串給出)可選 relu、softmax、sigmoid、tanh 等,kernel_regularizer 可選 tf.keras.regularizers.l1()、
tf.keras.regularizers.l2()
卷積層:tf.keras.layers.Conv2D( filter = 卷積核個數, kernel_size = 卷積核尺寸,
strides = 卷積步長,padding = “valid” or “same”)
LSTM 層:tf.keras.layers.LSTM(),
2.Model.compile
Compile 用于配置神經網路的訓練方法,告知訓練時使用的優化器、損失函式和準確率評測標準,
Model.compile( optimizer = 優化器, loss = 損失函式, metrics = [“準確率”])
(1)optimizer 可以是字串形式給出的優化器名字,也可以是函式形式,使用函式形式可以設定學習率、動量和超引數,
可選擇有:
‘sgd’or tf.optimizers.SGD( lr=學習率,decay=學習率衰減率,momentum=動量引數)
‘adagrad’or tf.keras.optimizers.Adagrad(lr=學習率,decay=學習率衰減率)
‘adadelta’or tf.keras.optimizers.Adadelta(lr=學習率, decay=學習率衰減率)
‘adam’or tf.keras.optimizers.Adam (lr=學習率, decay=學習率衰減率)
(2) Loss 可以是字串形式給出的損失函式的名字,也可以是函式形式,
可選項包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
損失函式常需要經過 softmax 等函式將輸出轉化為概率分布的形式,from_logits 則用來標注該損失函式是否需要轉換為概率的形式,取 False 時表示轉化為概率分布,取 True 時表示沒有轉化為概率分布,直接輸出,
(3)Metrics 標注網路評測指標,
可選項包括:
‘accuracy’:y_和 y 都是數值,如 y_=[1] y=[1],
‘categorical_accuracy’:y_和 y 都是以獨熱碼和概率分布表示,
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048],
‘sparse_ categorical_accuracy’:y_是以數值形式給出,y 是以獨熱碼形式
給出, 如 y_=[1],y=[0.256, 0.695, 0.048],
3.model.fit()
fit 函式用于執行訓練程序,
——model.fit(訓練集的輸入特征, 訓練集的標簽,batch_size, epochs, validation_data = (測驗集的輸入特征,測驗集的標簽), validataion_split = 從測驗集劃分多少比例給訓練集, validation_freq = 測驗的 epoch 間隔次數)
4.model.summary()
summary 函式用于列印網路結構和引數統計.
上圖是 model.summary()對鳶尾花分類網路的網路結構和引數統計,對于輸入為 4 輸出為 3 的全連接網路,共有 15 個引數,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/226630.html
標籤:AI
