我想使用自定義資料加載器將 numpy 檔案傳輸到資料加載器。當我設定 transorm 時,我得到錯誤 TypeError: pic should be PIL Image or ndarray。得到 <class 'torch.Tensor'>
import os
import torch
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import transforms
class CustomTensorDataset(Dataset):
"""
TensorDataset with support for transforms
"""
def __init__(self, tensors, transform=None):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
self.transform = transform
def __getitem__(self, index):
x = self.tensors[0][index]
if self.transform:
x = self.transform(x)
y = self.tensors[1][index]
return x, y
def __len__(self):
return self.tensors[0].size(0)
te_data = torch.FloatTensor(np.ones([100, 3, 32, 32]))
te_targets = torch.FloatTensor(np.ones([100]))
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=transform)
# testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=None) # --> no error
for item in testset_custom:
print(item)
uj5u.com熱心網友回復:
您對資料集的輸入資料需要是 PIL 影像或 numpy 陣列。但是,您的te_data和te_targets是torch.tensor。要解決這個問題,請不要將它們轉換為torch.tensor和之前提供給 Dataset 并保持它們的維度。資料集本身會改變其維度:
te_data = np.ones([100, 32, 32, 3])
te_targets = np.ones([100])
并且assert只要輸入是 numpy 陣列,還需要更改條件:
assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors)
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/439814.html
