pytorch實作影像分類資料處理
- 1. pytorch庫自帶資料
- 資料預處理
- 資料生成
- 資料加載
- 2. 訓練自己的資料
- 生成資料集
- 資料預處理
- 資料加載
1. pytorch庫自帶資料
為了更好的理解,這里以CIFAR10資料集作為訓練和測驗資料集,
我們將使用CIFAR10資料集,它包含十個類別:
[‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’],
CIFAR-10 中的影像尺寸為3x32x32,也就是RGB的3層顏色
通道,每層通道內的尺寸為32x32,

資料預處理
😃CIFAR10資料集的輸出是范圍在[0,1]之間的 PILImage,即對每個類別的概率分布情況,所以我們需要通過ToTensor()把影像灰度范圍從(0-255)變換到(0-1)之間,并通過transform.Normalize()把(0-1)變換到(-1,1)
import torch
import torchvision
import torchvision.transforms as transforms
#定義三個通道的像素值 均值(mean)為0.5,方差(std)為0.5
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
資料生成
torchvision.datasets中包含了以下資料集
- MNIST
- COCO(用于影像標注和目標檢測)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
資料生成函式:
class torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
引數說明:
- root:保存資料集的目錄
- train:True= 訓練集, False = 測驗集
- download:True = 從互聯網上下載資料集,并把資料集放在root目錄下. 如果資料集之前下載過,就不用再重復下載,
- transform:對資料集預處理的函式
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True, transform=transform)
資料加載
資料加載函式:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
引數說明:
- dataset (Dataset):加載資料的資料集,
- batch_size (int, optional):每個batch加載多少個樣本(默認: 1),
- shuffle (bool, optional):設定為True時會在每個epoch重新打亂資料(默認: False).
- sampler (Sampler, optional):定義從資料集中提取樣本的策略,如果指定,則忽略shuffle引數,
- num_workers (int, optional):用多少個子行程加載資料,0表示資料將在主行程中加載(默認: 0)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False, num_workers=2)
2. 訓練自己的資料
由于pytorch庫中的資料集包含的種類比較匱乏,我們在實際的應用中往往還會對其他的事物做影像分類,因此需要自己的資料集影像來訓練,實作影像分類,
生成資料集
要想用自己的資料集進行影像分類或者其他計算機視覺應用,不是之前下載好圖片,進行訓練就行了🤣
首先第一步需要自己的影像資料集進行標注
標注影像需要用到標注工具,這里介紹一種最方便的:labelimg
安裝labelimg,只需要在終端運行
pip install labelimg
之后在終端運行如下代碼,即可開始對影像進行標注
(base) MacBook-Air ~ % labelimg

進行批量標注:
點擊打開檔案按鈕可以打開需要被標注的圖片的檔案夾,
點擊改變存放目錄按鈕可以打開標注檔案存放的檔案夾,
點擊w快捷鍵可以開始標注,標注完后需要保存

最后標注完成的影像,會生成一個標注檔案xml格式,
資料預處理
接下來就是對標注后的影像進行預處理,
首先創建一個檔案夾(這里按照官方的檔案夾名字命名😂)

- Annotations:存放標注xml檔案
- JPEGImages:存放圖片
- ImageSets:存放一個名為Main檔案夾,Main檔案夾用來存放后續生成的train.txt,val.txt,test.txt、trainval.txt(也可以只有train.txt和test.txt,根據個人需求看是否需要驗證集),這些檔案保存的內容為圖片的名字(沒有后綴格式)
- src:存放后續生成的train.txt,val.txt,test.txt、trainval.txt,但這里的的檔案內容是,對應每個圖片的絕對路徑+類別
- label:存放不同影像的標注檔案(感覺這個檔案沒有用😂)
生成Main里的檔案:
import os
import random
random.seed(0)
xmlfilepath='Annotations'
saveBasePath="ImageSets/Main/"
trainval_percent=1
train_percent=1
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
if xml.endswith(".xml"):
total_xml.append(xml)
num=len(total_xml)
list=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)
print("train and val size",tv)
print("traub suze",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
for i in list:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
生成src里的檔案:
import xml.etree.ElementTree as ET
from os import getcwd
sets=['train','val','test','trainval']
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def convert_annotation(image_id, list_file):
in_file = open('Annotations/%s.xml'%(image_id), encoding='utf-8')
tree=ET.parse(in_file)
root = tree.getroot()
for obj in root.iter('object'):
difficult = 0
if obj.find('difficult')!=None:
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult)==1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
wd = getcwd()
for image_set in sets:
image_ids = open('ImageSets/Main/%s.txt'%(image_set), encoding='utf-8').read().strip().split()
list_file = open('src/%s.txt'%(image_set), 'w', encoding='utf-8')
for image_id in image_ids:
list_file.write('JPEGImages/%s.jpg'%(image_id))
#這里寫入的是圖片的絕對路徑
convert_annotation(image_id, list_file)
list_file.write('\n')
list_file.close()
資料加載
from PIL import Image
import torch
import torchvision.transforms as transforms
class MyDataset(torch.utils.data.Dataset): # 創類:MyDataset,繼承torch.utils.data.Dataset
def __init__(self, datatxt, transform=None):
super(MyDataset, self).__init__()
fh = open(datatxt, 'r') # 打開src中的txt檔案,讀取內容
imgs = []
for line in fh: # 按行回圈txt文本中的內容
line = line.rstrip() # 洗掉本行string字串末尾的指定字符
words = line.split() # 通過指定分隔符對字串進行切片,默認為所有的空字符,包括空格、換行、制表符等
imgs.append((words[0], int(words[1]))) # 把txt里的內容讀入imgs串列保存,words[0]是圖片資訊,words[1]是label
self.imgs = imgs
self.transform = transform
def __getitem__(self, index): # 按照索引讀取每個元素的具體內容
fn, label = self.imgs[index] # fn是圖片path
img = Image.open(fn).convert('RGB') # from PIL import Image
if self.transform is not None: # 是否進行transform
img = self.transform(img)
return img, label # return回哪些內容,在訓練時回圈讀取每個batch,就能獲得哪些內容
def __len__(self): # 它回傳的是資料集的長度,必須有
return len(self.imgs)
'''標準化、圖片變換'''
mean = [0.5071, 0.4867, 0.4408]
stdv = [0.2675, 0.2565, 0.2761]
train_transforms = transforms.Compose([
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=stdv)])
train_data = MyDataset(datatxt='train.txt', transform=train_transforms)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/292149.html
標籤:其他
上一篇:Real-Time Super-Resolution System of 4K-Video Based on Deep Learning----代碼實踐
