李沐《動手學深度學習v2》學習筆記(三):Fashion-MNIST 資料集
目錄:
- 李沐《動手學深度學習v2》學習筆記(三):Fashion-MNIST 資料集
- 一、Fashion-MNIST 資料集簡介
- 二、了解資料集
- 三、批量顯示影像
一、Fashion-MNIST 資料集簡介
影像分類資料集介紹:MNIST 資料集(手寫數字資料集)是影像分類中廣泛使用的資料集之一,但作為基準資料集過于簡單,MNIST 資料集在簡單的模型上都可以有比較高的 A c c u r a c y \tt Accuracy Accuracy,不便于模型的驗證,因此我們將使用類似但更復雜的 Fashion-MNIST 資料集
先匯入相關庫:
%matplotlib inline
import torch
import matplotlib.pyplot as plt
import torchvision # PyTorch 計算機視覺庫
from torch.utils import data
from torchvision import transforms # 資料操作庫
通過 t o r c h v i s i o n . t r a n s f o r m s \tt torchvision.transforms torchvision.transforms 中的內置函式將 Fashion-MNIST 資料集下載并讀取到記憶體中,若已下載,則只讀取:
trans = transforms.ToTensor()
將原始的 PILImage 格式或者 numpy.ndarray 格式(或 cv2 讀取的影像也可以)的資料格式化為可被 PyTorch 處理的張量型別,即 32 位浮點數格式;
將 shape 為 ( H , W , C ) (H, W, C) (H,W,C) 的 numpy.ndarray 或 PILImage 轉為 shape 為 ( C , H , W ) (C, H, W) (C,H,W) 的張量,其將每一個數值歸一化到 [ 0 , 1 ] [0,1] [0,1];
其歸一化方法比較簡單,直接除以255即可
return得到一個轉換物件
trans = transforms.ToTensor() # 將格式歸一化為張量
# 如果沒有,則下載訓練集,root:選擇下載/讀取路徑,傳入 trans
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True,
transform=trans, download=True)
# 如果沒有,則下載測驗集,root:選擇下載/讀取路徑,傳入 trans
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False,
transform=trans, download=True)
# 顯示訓練集和測驗集的大小
len(mnist_train), len(mnist_test)

圖片的形狀
(
c
h
a
n
n
e
l
,
h
e
i
g
h
t
,
w
i
d
t
h
)
\tt(channel,\ height,\ width)
(channel, height, width) 為
(
1
,
28
,
28
)
(1,28,28)
(1,28,28):

二、了解資料集
通過 d a t a . D a t a L o a d e r \tt data.DataLoader data.DataLoader 進行小批量劃分,同時顯示影像,進一步了解影像
train_iter = data.DataLoader(mnist_train, batch_size=18, shuffle=True)
image, label = next(iter(train_iter)) # 獲得一個小批量
plt.imshow(image[0].squeeze()) # 顯示批量中的圖片[0]
print('label:', label[0]) # 顯示標簽
plt.imshow(image[1].squeeze()) # 顯示批量中的圖片[1]
print('label:', label[1]) # 顯示標簽

grid = torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False,…) 將若干幅影像拼成一幅大影像,回傳值是一個張量
tensor( b a t c h _ s i z e , c h a n n e l , h e i g h t , w i d t h ) \tt (batch\_size,\ channel,\ height,\ width) (batch_size, channel, height, width),即 ( 批 量 數 , 通 道 數 , 高 , 寬 ) (批量數,\ 通道數,\ 高,\ 寬) (批量數, 通道數, 高, 寬),通道數如果是 1 1 1,會變為 3 3 3
nrow每一行顯示的影像數,default=8
padding子影像與子影像之間的間距,default=2
normalize是否歸一化到 [ 0 , 1 ] [0, 1] [0,1] 之間,default=False
回傳值張量, s h a p e = ( c h a n n e l , h e i g h t + 2 × p a d d i n g , w i d t h × b a t c h _ s i z e + ( b a t c h _ s i z e + 1 ) × p a d d i n g ) \tt shape=(channel,\ height+2×padding,\ width×batch\_size+(batch\_size+1)×padding) shape=(channel, height+2×padding, width×batch_size+(batch_size+1)×padding)
其中, 2 × p a d d i n g \tt2×padding 2×padding 表示上下邊界; ( b a t c h _ s i z e + 1 ) × p a d d i n g \tt (batch\_size+1)×padding (batch_size+1)×padding 表示左右邊界和中間邊界
import numpy as np
grid = torchvision.utils.make_grid(image, nrow=18)
print(grid.shape)
plt.figure(figsize=(10, 10))
plt.imshow(np.transpose(grid, (1, 2, 0)))
print(label)
注意,由于 plt.imshow 輸入的影像的格式為
(
h
e
i
g
h
t
,
w
i
d
t
h
,
c
h
a
n
n
e
l
)
\tt(height,\ width,\ channel)
(height, width, channel),而我們的 grid 的格式為
(
c
h
a
n
n
e
l
,
h
e
i
g
h
t
,
w
i
d
t
h
)
\tt(channel,\ height,\ width)
(channel, height, width),因此需要呼叫 np.transpose 來進行維度的轉換,它的用法與我們在PyTorch基礎與線性代數中講的 torch.transpose() 類似,只不過后者只能進行兩個維度間的互換

三、批量顯示影像
定義一個函式,該函式可以根據 label 的數值回傳該數值所指的名稱:
def get_fashion_mnist_labels(labels):
"""回傳Fashion-MNIST資料集的文本標簽"""
text_labels = [
't-shirt', 'trouser', 'pullover',
'dress', 'coat', 'sandal', 'shirt',
'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
Fashion-MNIST 資料集共有 10 10 10 種標簽👆
批量顯示圖片,加上上一例中的一組,這里再額外增加三組:
for i in range(3):
temp_image, temp_label = next(iter(train_iter))
temp_grid = torchvision.utils.make_grid(temp_image, nrow=18)
grid = torch.cat((grid, temp_grid), dim=1)
label = torch.cat((label, temp_label), dim=0)
plt.figure(figsize=(40, 8))
plt.imshow(np.transpose(grid, (1, 2, 0)))
print(get_fashion_mnist_labels(label))

改變 range(3) 的大小,可以顯示更多的影像,這里不再展開
參考資料:
[1]Pytorch教程(十六):FashionMNIST資料集DataSet DataLoader
[2]np.transpose(npimg, (1, 2, 0))的作用
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/352096.html
標籤:AI
上一篇:人工智能學習筆記 感知器演算法的設計實作 感知器準則實作多類分類
下一篇:Python機器學習之垃圾短信分類(用樸素貝葉斯演算法的伯努利模型和多項式模型分類垃圾短信資料集SMSSpamCollection.txt)
