#資料獲取,使用FashionMNIST資料集
import random
import sys
sys.path.append('..')
import utils
batch_size = 256
train_data,test_data = utils.load_data_fashion_mnist(batch_size) #多層感知機與前面介紹的多類邏輯回歸非常類似,主要區別就是我們在輸入層和輸出層之間插入了一個到多個隱含層 (全連接網路)
from mxnet import ndarray as nd
num_inputs = 28*28
num_outputs = 10 #10類
num_hidden = 256 #加入一個隱含層,他的隱含單元個數為256
weight_scale = .01 #用一個大概是0.01之間的亂數進行初始化
#第一次我們初始化一個W1,w1是我們的輸入,他需要的是說你的輸入到他的 輸入的維度是你的feature的輸入,輸出就是你的hidden的大小
w1 = nd.random_normal(shape=(num_inputs,num_hidden),scale=weight_scale)
b1 = nd.zeros(num_outputs)
#第二層卷積,你的輸入是前一層的輸出 ,輸出就是output
w2 = nd.random_normal(shape=(num_hidden,num_outputs),scale=weight_scale)
b2 = nd.zeros(num_outputs) #四個引數
params = [w1,b1,w2,b2]
#對每個引數加一個梯度,便于以后求梯度
for param in params: param.attach_grad()
'''對于多層神經網路來說比較重要的概念是有個激活函式,有激活函式是說,如果我們兩層什么都不做,就把第一層丟進第二層的情況下
是一個線性的東西
就是說,如果我們就用線性運算子來構造多層神經網路,那么整個模型仍然只是一個線性函式
這是因為 y_hat= x*w1*w2 = x*w3
這里w3 = w1*w2為了讓我們的模型可以擬合非線性函式,我們需要在層之間插入非線性的激活函式
這里我們使用的是ReLU
計算簡單
relu(x) = max(x,0)
'''
#激活函式
def relu(x): return nd.maximum(x,0)
def SGD(params,lr): for param in params: param[:] = param - lr*param.grad # 對比預測值和標簽值,計算準確率
def accuracy(y_prediction, y): return (y_prediction.argmax(axis=1) == y.astype('float32')).mean().asscalar()
def evaluate_accuracy(data_iterator, net): acc = 0. for data, label in data_iterator: output = net(data) acc += accuracy(output,label) return acc/len(data_iterator) #定義模型
'''
我們的模型就是將層(全連接)和激活函式 (ReLU)串起來
'''
def net(x): x = x.reshape((-1,num_inputs)) #x是一個28*28的圖片,我們需要將它變成矩陣 h1 = relu(nd.dot(x,w1)+b1) output = nd.dot(h1,w2)+b2 return output
#定義Softmax和交叉熵損失函式,直接使用gluon提供的函式
from mxnet import gluon
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss() #訓練
from mxnet import autograd
learning_rate =.5
for epoch in range(5): train_loss = 0. train_acc = 0. for X, y in train_data: with autograd.record(): y_pre = net(X) loss= softmax_cross_entropy(y_pre, y) loss.backward() SGD(params,learning_rate/batch_size) train_loss += nd.mean(loss).asscalar() train_acc += accuracy(y_pre,y) test_acc = evaluate_accuracy(test_data,net) print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % ( epoch+1, train_loss / len(train_data), train_acc / len(train_data), test_acc))
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/85382.html
標籤:其他開發語言
上一篇:python安裝pip失敗
