用于處理資料樣本的代碼可能會變得凌亂且難以維護;理想情況下,我們希望資料集代碼和模型訓練代碼解耦(分離),以獲得更好的可讀性和模塊性,PyTorch提供了兩個data primitives:torch.utils.data.DataLoader 和 torch.utils.data.Dataset,允許你使用預加載的datasets和你自己的data,Dataset 存盤樣本及其對應的標簽,DataLoader 給 Dataset 包裝了一個迭代器,以便訪問樣本,
PyTorch庫提供了一些預加載的資料集(如FashionMNIST),它們是 torch.utils.data.Dataset 的子類,特定的資料對應特定的實作函式,它們可以用來原型化和基準化你的模型,你可以在這里查看它們:Image Datasets, Text Datasets, and Audio Datasets,
加載資料集
這是一個怎樣從TorchVision加載Fashion-MNIST資料集的例子,Fashion-MNIST來自于Zalando的文章,由60000張訓練樣本和10000張測驗樣本組成,每一個樣本包含一個28x28
的灰度圖片和對應的10類中的1個類的標簽,
我們用以下引數加載FashionMNIST Dataset
root是訓練/測驗資料的保存路徑train指定是訓練集還是測驗集download=True如果root中沒有,則從網上下載transform和target_transform指定樣本的變換
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = https://www.cnblogs.com/DeepRS/p/datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor()
)
test_data = https://www.cnblogs.com/DeepRS/p/datasets.FashionMNIST(
root='data',
train=False,
download=True
transform=ToTensor()
)
輸出:
點擊查看代碼
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
迭代和資料集可視化
我們可以像list一樣索引Datasets:training_data[index],使用 matplotlib 可視化一些訓練集的樣本,
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
# torch.squeeze():洗掉維數為1的維度
plt.imshow(img.squeeze(), cmap="gray")
plt.show()

創建自定義資料集
一個自定義的資料集類必須實作三個函式:init,len,getitem,查看下面的實作程序,FashionMNIST圖片保存在 img_dir,它們的標簽分別保存在一個CSV檔案(逗號分隔值檔案) annotations_file 中,
下一節,我們將分解每個函式做了什么的,
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
# 利用pandas讀取csv并轉換為DataFrame
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
init
一旦實體化Datase物件,函式__init__ 就會立即運行:初始化包含圖片的目錄,標簽檔案,以及兩個轉換(下一節有更詳細的介紹)
labels.csv類似這樣:
tshirt1.jpg, 0
tshirt2.jpg, 0
...
anleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
# 這里指定了列名
self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
len
__len__ 函式回傳資料集的樣本數
例如:
def __len__(self):
return len(self.img_labels)
getitem
__getitem__函式加載和回傳資料集中給定索引 idx 的樣本,根據索引,它獲得了硬碟上圖片的位置,利用 read_image 轉換為tensor,在 self.img_labels ,從csv中檢索相應的標簽,并呼叫轉換函式(如果可用),回傳一個包含圖片和對應標簽張量的元組,
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
利用DataLoader為訓練準備你的資料
Dataset只能同時檢索一個樣本的資料特征和標簽,當訓練模型時,通常需要傳遞“minibatches”樣本,每一個epoch重復打亂資料減少過擬合,并使用Python的 multiprocessing 加速資料檢索,
DataLoader 是一個迭代器,
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
通過DataLoader迭代
我們已經將該資料集加載到 DataLoader,根據需要可以對資料集進行迭代,每次迭代回傳一個 train_features 和 train_labels 的batch(分別包含 batch_size=64的特征和標簽),因為我們指定了 shuffle=True, 在我們迭代完所有的batch之后,資料就會被打亂(為了對資料加載順序進行更細致的控制,參閱Samplers)
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

輸出:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 7
延伸閱讀
- torch.utils.data API
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/422990.html
標籤:其他
上一篇:建議:我正在嘗試在我的應用程式中創建跟蹤時間功能,但我應該何時將資料發送到后端
下一篇:阿里云DataWorks介紹
