我們在《torch.utils.data.DataLoader與迭代器轉換》中介紹了如何使用Pytorch內置的資料集進行論文實作,如torchvision.datasets,下面是加載內置訓練資料集的常見操作:
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
[ToTensor(),
Normalize((0.1307,), (0.3081,))
]
)
train_data = https://www.cnblogs.com/orion-orion/p/FashionMNIST(
root=RAW_DATA_PATH,
download=True,
train=True,
transform=transform
)
這里的train_data做為dataset物件,它擁有許多熟悉,我們可以通過以下方法獲取樣本資料的分類類別集合、樣本的特征維度、樣本的標簽集合等資訊,
classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets
print(classes)
print(num_features)
print(train_labels)
輸出如下:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])
但是,我們常常會在訓練集的基礎上拆分出驗證集(或者只用部分資料來進行訓練),我們想到的第一個方法是使用torch.utils.data.random_split對dataset進行劃分,下面我們假設劃分10000個樣本做為訓練集,其余樣本做為驗證集:
from torch.utils.data import random_split
k = 10000
train_data, valid_data = https://www.cnblogs.com/orion-orion/p/random_split(train_data, [k, len(train_data)-k])
注意我們如果列印train_data和valid_data的型別,可以看到顯示:
<class 'torch.utils.data.dataset.Subset'>
已經不再是torchvision.datasets.mnist.FashionMNIST物件,而是一個所謂的Subset物件!此時Subset物件雖然仍然還存有data屬性,但是內置的target和classes屬性已經不復存在,比如如果我們強行訪問valid_data的target屬性:
valid_target = valid_data.target
就會報如下錯誤:
'Subset' object has no attribute 'target'
但如果我們在后續的代碼中常常會將拆分后的資料集也默認為dataset物件,那么該如何做到代碼的一致性呢?
這里有一個trick,那就是以繼承SubSet類的方式的方式定義一個新的CustomSubSet類,使新類在保持SubSet類的基本屬性的基礎上,擁有和原本資料集類相似的屬性,如targets和classes等:
from torch.utils.data import Subset
class CustomSubset(Subset):
'''A custom subset class'''
def __init__(self, dataset, indices):
super().__init__(dataset, indices)
self.targets = dataset.targets # 保留targets屬性
self.classes = dataset.classes # 保留classes屬性
def __getitem__(self, idx): #同時支持索引訪問操作
x, y = self.dataset[self.indices[idx]]
return x, y
def __len__(self): # 同時支持取長度操作
return len(self.indices)
然后就引出了第二種劃分方法,即通過初始化CustomSubset物件的方式直接對資料集進行劃分(這里為了簡化省略了shuffle的步驟):
import numpy as np
from copy import deepcopy
origin_data = https://www.cnblogs.com/orion-orion/p/deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)
注意,CustomSubset類的初始化方法的第二個引數indices為樣本索引,我們可以通過np.arange()的方法來創建,
然后,我們再訪問valid_data對應的classes和targes屬性:
print(valid_data.classes)
print(valid_data.targets)
此時,我們發現可以成功訪問這些屬性了:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0, ..., 3, 0, 5])
當然,CustomSubset的作用并不只是添加資料集的屬性,我們還可以自定義一些資料預處理操作,我們將類的結構修改如下:
class CustomSubset(Subset):
'''A custom subset class with customizable data transformation'''
def __init__(self, dataset, indices, subset_transform=None):
super().__init__(dataset, indices)
self.targets = dataset.targets
self.classes = dataset.classes
self.subset_transform = subset_transform
def __getitem__(self, idx):
x, y = self.dataset[self.indices[idx]]
if self.subset_transform:
x = self.subset_transform(x)
return x, y
def __len__(self):
return len(self.indices)
我們可以在使用樣本前設定好資料預處理算子:
from torchvision import transforms
valid_data.subset_transform = transforms.Compose(\
[transforms.RandomRotation((180,180))])
這樣,我們再像下列這樣用索引訪問取出資料集樣本時,就會自動呼叫算子完成預處理操作:
print(valid_data[0])
列印結果縮略如下:
(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)
參考
- [1] https://pytorch.org/docs/stable/data.html?highlight=random_split#torch.utils.data.random_split
- [2] https://pytorch.org/docs/stable/data.html?highlight=subset#torch.utils.data.Subset
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/426405.html
標籤:其他
上一篇:分享你的見解與經驗|RocketMQ Summit 2022 議題征集中!
下一篇:資料結構與演算法學習之復雜度分析
