本文基于AI研習社2020年10月30日發布的“影像場景分類挑戰賽”完成,主要是從一個小白如何去將一個影像分類的任務用代碼跑起來的角度寫的,也是自己的一個學習程序,
影像場景分類挑戰賽:https://god.yanxishe.com/97?from=god_home_list
拿到比賽題目時,首先做以下幾點觀察:
- 首先分析一下這個任務:就是一個將一些風景圖片正確分類的簡單任務,
- 其次看一下官方給的資料集和標簽:
- 資料都是世界各地的風景圖片,共有6類,buildings、street、forest、sea、mountain、glacier,訓練集有13627張圖片,測驗集有3407張圖片,圖片為RGB圖片,格式為jpg,
- 標簽檔案為csv檔案,內容為filename( '0.jpg' )和label( 'forest' ),兩列都是String型別的,
- 然后看一下提交結果檔案的格式:結果檔案為csv檔案,內容不需要Title,第一列為圖片序號(0),Int型別,第二列為類別名稱(‘street’),String型別,
分析完基本比賽內容和條件之后,就可以開始用代碼實作了,本文使用colab平臺實作, 實作步驟如下:
- 成功使用Google云盤和Colab(Colab是一個 Jupyter 筆記本環境,已經默認安裝好 pytorch,不需要進行任何設定就可以使用,并且完全在云端運行,使用方法可以參考 :https://www.cnblogs.com/lfri/p/10471852.html ,國內目前無法訪問 colab,可以安裝一些軟體實作訪問,比如Ghelper: http://googlehelper.net/ )
- 將官方提供的資料集壓縮包上傳到Google云盤中(注意是上傳壓縮包,不要解壓以后上傳,解壓之后上傳很慢)
- 代碼實作(前序作業):
- 掛載Google Drive (在Colab中將Google云盤載入進來)
- 解壓檔案(解壓資料集壓縮包檔案到當前運行環境)
- 查看是否正在使用GPU
- 代碼實作(正式作業)
- 導包(匯入所有要用的包,在寫代碼程序中需要一個補充一個即可)
- 讀取標簽檔案(讀取訓練集的帶標簽檔案,此處為CSV格式檔案)
- 定義讀取資料集的類(包括訓練集和測驗集)
- 預處理(對資料集進行預處理)
- 呼叫讀取資料集的類(包括訓練集和測驗集)
- 定義模型
- 訓練(呼叫模型進行訓練)
- 測驗(使用訓練好的模型進行測驗,得到csv格式的結果檔案)
掛載Google Drive (在Colab中將Google云盤載入進來)
from google.colab import drive drive.mount('/content/drive')
解壓檔案(解壓資料集壓縮包檔案到當前運行環境)
!cp -r /content/drive/My\ Drive/Scene/Image_Classification.zip ./ #將google云盤中的資料集壓縮檔案拷貝到當前運行環境 !unzip Image_Classification.zip #將資料集壓縮檔案解壓,在當前運行環境得到'train'檔案夾、'test'檔案夾和'train.csv'檔案
查看是否正在使用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "CPU") print(device)
導包(匯入所有要用的包,在寫代碼程序中需要一個補充一個即可)
import torch import pandas as pd from PIL import Image from torch.utils.data import random_split
讀取標簽檔案(讀取訓練集的帶標簽檔案,此處為CSV格式檔案)
def readLabelFile(): label_file = pd.read_csv("train.csv") return label_file['filename'],label_file['label'] filename,filelabel = readLabelFile() map = ['buildings', 'street', 'forest', 'sea', 'mountain', 'glacier'] #將label中的字串轉換為數字 for i in range(len(map)): filelabel[filelabel==map[i]] = i #將物件轉換為串列 filename = filename.values filelabel = filelabel.values
定義讀取資料集的類(包括訓練集和測驗集)
class TrainDataset(torch.utils.data.Dataset): def __init__(self, root, img_list, label_list, transform = None): self.root = root self.img_list = img_list self.label_list = label_list self.transform = transform def __getitem__(self, index): img = Image.open(self.root + self.img_list[index]).convert('RGB') label = self.label_list[index] if self.transform: img = self.transform(img) return img,label def __len__(self): return len(self.img_list) class TestDataset(torch.utils.data.Dataset): def __init__(self, root, img_list, transform = None): self.root = root self.img_list = img_list self.transform = transform def __getitem__(self, index): img = Image.open(self.root + self.img_list[index]).convert('RGB') if self.transform: img = self.transform(img) return img,index def __len__(self): return len(self.img_list)
預處理(對資料集進行預處理)
transform = { 'train': transforms.Compose([ transforms.Resize((224, 224),interpolation=2), transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]), 'val': transforms.Compose([ ]) }
呼叫讀取資料集的類(包括訓練集和測驗集)
train_dataset = TrainDataset('./train/', filename, filelabel, transform['train']) tra_dataset, val_dataset = random_split(train_dataset, [10000, 3627]) test_dataset = TestDataset('./test/', )
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/203949.html
標籤:其他
下一篇:尋高人
