1. 回圈神經網路
回圈神經網路(recurrent neural network,RNN)是一種序列模型,CNN是專門用來處理網格化資料(例如影像資料)的神經網路,而RNN是專門用來處理序列資料的神經網路,序列資料指的是跟序列相關的資料,比如一段語音、一首歌曲、一段文字、一段錄像等,
如圖,X1經過W1只影響輸出結果Y1,與其他事件相互獨立,

對于以下神經網路:

隱藏層ht接收的是上時刻的隱藏層 ht?1還是上時刻的輸出層yt?1, 可以分成了兩種 RNN:
Elman network 接收上時刻的隱藏層ht?1

Jordan network 接收上時刻的輸出層yt?1

其中,xt代表輸入,ht代表隱藏層,yt代表輸出;W、U、b代表引數矩陣和向量;σh和σy代表activation functions,
RNN中,由于梯度消失,每一次迭代都會使前一個輸入的隱藏層的“影響力”減弱一些,

2. 長短時記憶網路
長短時記憶網路(Long Short-Term Memory,LSTM)是一種含有LSTM區塊(blocks)或其他的一種類神經網路,因為它可以記憶不定時間長度的數值,區塊中有一個gate能夠決定input是否重要到能被記住及能不能被輸出output,下圖的最下方是輸入,最上方是輸出,

下圖為一個block的結構,


不同于RNN,LSTM可以用于處理長距離的依賴,
例:使用LSTM修改MNIST案例,
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 下載訓練集
train_dataset = datasets.MNIST(root='./',
train=True,
transform=transforms.ToTensor(),
download=True)
# 下載測驗集
test_dataset = datasets.MNIST(root='./',
train=False,
transform=transforms.ToTensor(),
download=True)
# 批次大小
batch_size = 64
# 裝載訓練集
train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 裝載測驗集
test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True)
for i, data in enumerate(train_loader):
# 獲得資料和對應的標簽
inputs, labels = data
print(inputs.shape)
print(labels.shape)
break
# 定義網路結構
class LSTM(nn.Module):
def __init__(self):
super(LSTM, self).__init__()
# 定義LSTM隱藏層結構
# 引數1:input_size 輸入特征的大小,一行資料為28 引數2:hidden_size LSTM模塊的數量64,相當于隱藏層中64個模塊 引數3:num_layers,LSTM隱藏層的層數1
# 引數4:batch_first,LSTM默認input(seq_len序列長度, batch批次大小, feature特征數量)、output(batch, seq_len,
# feature),均為3維資料,True按照這個格式
self.lstm = torch.nn.LSTM(
input_size=28,
hidden_size=64,
num_layers=1,
batch_first=True
)
# 全連接層 輸入64(隱藏層64個模塊),輸出10個分類
self.out = torch.nn.Linear(64, 10)
# 轉換為概率
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, x):
# 轉換為3維(batch(64), seq_len, feature)
x = x.view(-1, 28, 28)
# output:[batch(64), seq_len(28), hidden_size(LSTM模塊數量64)],包含每個序列的輸出結果
# 雖然LSTM的batch_first為True,但是h_n,c_n的第0個維度還是num_layers
# h_n:block的輸出信號,[num_layers, batch, hidden_size]只包含最后一個序列的輸出結果
# c_n:中間cell的輸出信號,[num_layers, batch, hidden_size]只包含最后一個序列的輸出結果
output, (h_n, c_n) = self.lstm(x)
# batch的最后一層
output_in_last_timestep = h_n[-1, :, :]
x = self.out(output_in_last_timestep)
x = self.softmax(x)
return x
LR = 0.0003
# 定義模型
model = LSTM()
# 定義代價函式
entropy_loss = nn.CrossEntropyLoss()
# 定義優化器
optimizer = optim.Adam(model.parameters(), LR)
def train():
model.train()
for i, data in enumerate(train_loader):
# 獲得資料和對應的標簽
inputs, labels = data
# 獲得模型預測結果,(64,10)
out = model(inputs)
# 交叉熵代價函式out(batch,C),labels(batch)
loss = entropy_loss(out, labels)
# 梯度清0
optimizer.zero_grad()
# 計算梯度
loss.backward()
# 修改權值
optimizer.step()
def test():
model.eval()
correct = 0
for i, data in enumerate(test_loader):
# 獲得資料和對應的標簽
inputs, labels = data
# 獲得模型預測結果
out = model(inputs)
# 獲得最大值,以及最大值所在的位置
_, predicted = torch.max(out, 1)
# 預測正確的數量
correct += (predicted == labels).sum()
print("Test acc: {0}".format(correct.item() / len(test_dataset)))
correct = 0
for i, data in enumerate(train_loader):
# 獲得資料和對應的標簽
inputs, labels = data
# 獲得模型預測結果
out = model(inputs)
# 獲得最大值,以及最大值所在的位置
_, predicted = torch.max(out, 1)
# 預測正確的數量
correct += (predicted == labels).sum()
print("Train acc: {0}".format(correct.item() / len(train_dataset)))
for epoch in range(0, 10):
print('epoch:', epoch)
train()
test()
輸出:
torch.Size([64, 1, 28, 28])
torch.Size([64])
epoch: 0
Test acc: 0.789
Train acc: 0.7879666666666667
epoch: 1
Test acc: 0.8437
Train acc: 0.8388666666666666
epoch: 2
Test acc: 0.8442
Train acc: 0.8415833333333333
epoch: 3
Test acc: 0.9269
Train acc: 0.9267833333333333
epoch: 4
Test acc: 0.9359
Train acc: 0.9359666666666666
epoch: 5
Test acc: 0.9277
Train acc: 0.9306666666666666
epoch: 6
Test acc: 0.9437
Train acc: 0.94755
epoch: 7
Test acc: 0.9505
Train acc: 0.9513333333333334
epoch: 8
Test acc: 0.9469
Train acc: 0.94855
epoch: 9
Test acc: 0.947
Train acc: 0.95125
3. 門控回圈單元
門控回圈單元(Gated Recurrent Unit,GRU)效果跟LSTM差不多,但是用到的引數更少,將忘記門和輸入門合成了一個單一的更新門,


轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/294321.html
標籤:其他
