PyTorch構造資料集(深度學習計算機視覺)
- 構造資料集
- 儲存圖片路徑與標簽的檔案
- 匯入所需的庫
- 自定義datasets類
- 定義load_data函式
- 呼叫函式即可構造
構造資料集
在之前的學習程序中,訓練模型用的資料集是pytorch自帶的MNIST等,對匯入資料集不求甚解,但當上手手勢識別專案的時候,我對資料的匯入一頭霧水,查找到的資料讓人“從入門到放棄”,自閉之后,重新研究了構造pytorch資料集的步驟,弄懂了其中一種相對簡單的方法,
儲存圖片路徑與標簽的檔案
以txt檔案為例,train.txt和test.txt檔案中每行都只有一張圖片的路徑、對應的標簽,test.txt檔案某片段如下,每行最后的數字是圖片的標簽,
test/2/hand1_2_top_seg_4_cropped.png 2
test/2/hand1_2_top_seg_3_cropped_0_9961.png 2
test/2/hand1_2_top_seg_3_cropped_0_7789.png 2
test/2/hand2_2_bot_seg_1_cropped.png 2
test/3/hand2_3_right_seg_1_cropped.png 3
test/3/hand5_3_bot_seg_2_cropped.png 3
可以先將圖片根據標簽儲存在不同檔案夾,在此基礎上,使用python腳本創建這類txt檔案,
(csv等檔案亦可,本文以txt為例)
匯入所需的庫
import torch
import torchvision
from PIL import Image
import sys
自定義datasets類
自定義Dataset的子類——datasets類,
要之后要用Dataloader所定義的datasets類的話,這個類得擁有三個必要的函式:init、getitem、len,
注意:這里image_path在txt檔案中path前加了’hand_gesture_data/’,保證image_path是正確的相對路徑,
class gesture_datasets(torch.utils.data.Dataset):
def __init__(self, txt_path, transform=None):
lines = open(txt_path, 'r')
imgs = []
for line in lines:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
def __getitem__(self, index):
image_path = 'hand_gesture_data/' + self.imgs[index][0]
label = self.imgs[index][1]
img = Image.open(image_path).convert('RGB')
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
定義load_data函式
這里定義了一個函式用于load data,其中將datasets類實體化后,關鍵在于torch.utils.data.DataLoader的使用,
txt_path同樣是相對路徑,使用Dataloader時可以考慮多執行緒(暫不支持Windows),
def load_data_gesture(batch_size, resize=None):
transform = []
if resize:
transform.append(torchvision.transforms.Resize(size=resize))
transform.append(torchvision.transforms.ToTensor())
transform = torchvision.transforms.Compose(transform)
gesture_train = gesture_datasets(txt_path='hand_gesture_data/train/train.txt', transform=transform)
gesture_test = gesture_datasets(txt_path='hand_gesture_data/test/test.txt', transform=transform)
num_workers = 0 if sys.platform.startswith('win') else 4
train_iter = torch.utils.data.DataLoader(gesture_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(gesture_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_iter, test_iter
呼叫函式即可構造
train_iter, test_iter = load_data_gesture(batch_size=60, resize=(224, 224))
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/291737.html
標籤:其他
