主頁 > 後端開發 > 機器學習——果蔬分類

機器學習——果蔬分類

2022-12-24 06:27:34 後端開發

一、選題的背景

為了實作對水果和蔬菜的分類識別,收集了香蕉、蘋果、梨、葡萄、橙子、獼猴桃、西瓜、石榴、菠蘿、芒果、黃瓜、胡蘿卜、辣椒、洋蔥、馬鈴薯、檸檬、番茄、蘿卜、甜菜根、卷心菜、生菜、菠菜、大豆、花椰菜、甜椒、辣椒、蘿卜、玉米、甜玉米、紅薯、辣椒粉、生姜、大蒜、豌豆、茄子共36種果蔬的影像,該專案使用resnet18網路進行分類,

二、機器學習案例設計方案

1.本選題采用的機器學習案例(訓練集與測驗集)的來源描述

資料集來自百度AI studio平臺(https://aistudio.baidu.com/aistudio/datasetdetail/119023/0),共包含36種果蔬,每一個類別包括100張訓練影像,10張測驗影像和10張驗證影像,

2 采用的機器學習框架描述

本次使用的網路框架,主要用到了二維卷積、激活函式、最大池化、Dropout和全連接,下面將對搭建的網路模型進行解釋,

首先是一個二維卷積層,輸入通道數為3,輸出通道數為100,卷積核大小是3*3,填充大小是1*1,輸入通道數為3是因為這個是第一層卷積,輸入的是RGB影像,具有三個通道,輸出通道數量可以根據實際情況自定,填充是因為希望在卷積后,不要改變影像的尺寸,

在卷積層之后是一個RELU激活函式,如果不用激活函式,在這種情況下每一層輸出都是上層輸入的線性函式,容易驗證,無論神經網路有多少層,輸出都是輸入的線性組合,與沒有隱藏層效果相當,因此引入非線性函式作為激活函式,這樣深層神經網路就有意義了(不再是輸入的線性組合,可以逼近任意函式),最早的想法是sigmoid函式或者tanh函式,輸出有界,很容易充當下一層輸入,

引入RELU激活函式有以下三個原因:

第一,采用sigmoid等函式,算激活函式時(指數運算),計算量大,反向傳播求誤差梯度時,求導涉及除法,計算量相對大,而采用Relu激活函式,整個程序的計算量節省很多,

第二,對于深層網路,sigmoid函式反向傳播時,很容易就會出現 梯度消失 的情況(在sigmoid接近飽和區時,變換太緩慢,導數趨于0,這種情況會造成資訊丟失),從而無法完成深層網路的訓練,

第三,ReLu會使一部分神經元的輸出為0,這樣就造成了 網路的稀疏性,并且減少了引數的相互依存關系,緩解了過擬合問題的發生,

然后再跟一個二維卷積層,輸入通道數應該和上一層卷積的輸出通道數相同,所以設為100, 輸出通道數同樣根據實際情況設定,此處設為150,其他引數與第一層卷積相同,

后續每一個卷積層和全連接層后面都會跟一個RELU激活函式,所以后面不再敘述RELU激活函式層,

再之后添加一個2*2的最大池化層,該層用來縮減模型的大小,提高計算速度,同時提高所提取特征的魯棒性,

再經過三次卷積后,使用Flatten將二維Tensor拉平,變為一維Tensor,然后使用全連接層,通過多個全連接層后,使用dropout層隨機洗掉一些結點,該方法可以有效的避免網路過擬合,在最后一個全連接層的輸出對應需要分類的個數,

3.涉及到的技術難點與解決思路

下載的資料集沒有劃分訓練集、測驗集和驗證集,需要自己寫代碼完成劃分,在剛開始寫代碼的時候對于檔案路徑沒有搞清楚,沒有弄懂os.path.join方法如何使用,導致總是讀取不到影像,并且代碼還沒有報錯誤正常運行結束,但是查看劃分后的檔案夾里沒有資料,通過debug發現檔案的路徑出現問題,具體是windows下的/和\混用,導致不能正確的對路徑進行處理,在排除問題后統一使用\\,最終問題得到解決,

三、機器學習的實作步驟

(1)劃分資料集并進行縮放

 1 import os
 2 import glob
 3 import random
 4 import shutil
 5 from PIL import Image
 6 #對所有圖片進行RGB轉化,并且統一調整到一致大小,但不讓圖片發生變形或扭曲,劃分了訓練集和測驗集
 7 
 8 if __name__ == '__main__':
 9     test_split_ratio = 0.05 #百分之五的比例作為測驗集
10     desired_size = 128 # 圖片縮放后的統一大小
11     raw_path = './raw'
12 
13     #把多少個類別算出來,包括目錄也包括檔案
14     dirs = glob.glob(os.path.join(raw_path, '*'))
15     #進行過濾,只保留目錄,一共36個類別
16     dirs = [d for d in dirs if os.path.isdir(d)]
17 
18     print(f'Totally {len(dirs)} classes: {dirs}')
19 
20     for path in dirs:
21         # 對每個類別單獨處理
22 
23         #只保留類別名稱
24         path = path.split('/')[-1]
25         print(path)
26         #創建檔案夾
27         os.makedirs(f'train/{path}', exist_ok=True)
28         os.makedirs(f'test/{path}', exist_ok=True)
29 
30         #原始檔案夾當前類別的圖片進行匹配
31         files = glob.glob(os.path.join( path, '*.jpg'))
32         # print(raw_path, path)
33 
34         files += glob.glob(os.path.join( path, '*.JPG'))
35         files += glob.glob(os.path.join( path, '*.png'))
36 
37         random.shuffle(files)#原地shuffle,因為要取出來驗證集
38 
39         boundary = int(len(files)*test_split_ratio) # 訓練集和測驗集的邊界
40         
41         for i, file in enumerate(files):
42             img = Image.open(file).convert('RGB')
43 
44             old_size = img.size  
45 
46             ratio = float(desired_size)/max(old_size)
47 
48             new_size = tuple([int(x*ratio) for x in old_size])#等比例縮放
49 
50             im = img.resize(new_size, Image.ANTIALIAS)#后面的方法不會造成模糊
51 
52             new_im = Image.new("RGB", (desired_size, desired_size))
53 
54             #new_im在某個尺寸上更大,我們將舊圖片貼到上面
55             new_im.paste(im, ((desired_size-new_size[0])//2,
56                                 (desired_size-new_size[1])//2))
57 
58             assert new_im.mode == 'RGB'
59             
60             if i <= boundary:
61                 new_im.save(os.path.join(f'test/{path}', file.split('\\')[-1].split('.')[0]+'.jpg'))
62             else:
63                 new_im.save(os.path.join(f'train/{path}', file.split('\\')[-1].split('.')[0]+'.jpg'))
64 
65     test_files = glob.glob(os.path.join('test', '*', '*.jpg'))
66     train_files = glob.glob(os.path.join('train', '*', '*.jpg'))
67 
68     print(f'Totally {len(train_files)} files for training')
69     print(f'Totally {len(test_files)} files for test')

 

(2)影像預處理

包括隨即旋轉、隨機翻轉、裁剪等,并進行歸一化,

 1 #影像預處理
 2 train_dir = './train'
 3 val_dir = './test'
 4 test_dir = './test'
 5 classes0 = os.listdir(train_dir)
 6 classes=sorted(classes0)
 7 print(classes)
 8 train_transform=transforms.Compose([
 9         transforms.RandomRotation(10),      # 旋轉+/-10度
10         transforms.RandomHorizontalFlip(),  # 反轉50%的影像
11         transforms.Resize(40),              # 調整最短邊的大小
12         transforms.CenterCrop(40),          # 作物最長邊
13         transforms.ToTensor(),
14         transforms.Normalize([0.485, 0.456, 0.406],
15                              [0.229, 0.224, 0.225])
16 ])

1 #顯示影像
2 def show_image(img,label):
3     print('Label: ', trainset.classes[label], "("+str(label)+")")
4     plt.imshow(img.permute(1,2,0))
5     plt.show()
6 
7 show_image(*trainset[10])
8 show_image(*trainset[20])

 

 

 

(3)讀取資料

1 batch_size = 64
2 train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
3 val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
4 test_loader = DataLoader(test_ds, batch_size*2, num_workers=4, pin_memory=True)

(4)構建CNN模型

#構建CNN模型

 1 #構建CNN模型
 2 class CnnModel(ImageClassificationBase):
 3     def __init__(self):
 4         super().__init__()
 5         #cnn提取特征
 6         self.network = nn.Sequential(
 7             nn.Conv2d(3, 100, kernel_size=3, padding=1),#Conv2D層
 8             nn.ReLU(),
 9             nn.Conv2d(100, 150, kernel_size=3, stride=1, padding=1),
10             nn.ReLU(),
11             nn.MaxPool2d(2, 2), #池化層
12 
13             nn.Conv2d(150, 200, kernel_size=3, stride=1, padding=1),
14             nn.ReLU(),
15             nn.Conv2d(200, 200, kernel_size=3, stride=1, padding=1),
16             nn.ReLU(),
17             nn.MaxPool2d(2, 2), 
18 
19             nn.Conv2d(200, 250, kernel_size=3, stride=1, padding=1),
20             nn.ReLU(),
21             nn.Conv2d(250, 250, kernel_size=3, stride=1, padding=1),
22             nn.ReLU(),
23             nn.MaxPool2d(2, 2), 
24 
25             #全連接
26             nn.Flatten(), 
27             nn.Linear(6250, 256),  
28             nn.ReLU(),            
29             nn.Linear(256, 128),  
30             nn.ReLU(),            
31             nn.Linear(128, 64),           
32             nn.ReLU(),
33             nn.Linear(64, 32),
34             nn.ReLU(),
35             nn.Dropout(0.25),
36             nn.Linear(32, len(classes)))
37         
38     def forward(self, xb):
39         return self.network(xb)

(5)訓練網路

#訓練網路

 1 #訓練網路
 2 @torch.no_grad()
 3 def evaluate(model, val_loader):
 4     model.eval()
 5     outputs = [model.validation_step(batch) for batch in val_loader]
 6     return model.validation_epoch_end(outputs)
 7 
 8 def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
 9     history = []
10     optimizer = opt_func(model.parameters(), lr)
11     for epoch in range(epochs):
12         # 訓練階段
13         model.train()
14         train_losses = []
15         for batch in tqdm(train_loader,disable=True):
16             loss = model.training_step(batch)
17             train_losses.append(loss)
18             loss.backward()
19             optimizer.step()
20             optimizer.zero_grad()
21         # 驗證階段
22         result = evaluate(model, val_loader)
23         result['train_loss'] = torch.stack(train_losses).mean().item()
24         model.epoch_end(epoch, result)
25         history.append(result)
26     return history
27 
28 model = to_device(CnnModel(), device)
29 
30 history=[evaluate(model, val_loader)]
31 
32 num_epochs = 100
33 opt_func = torch.optim.Adam
34 lr = 0.001
35 
36 history+= fit(num_epochs, lr, model, train_dl, val_dl, opt_func)

 

 

(6)繪制損失函式和準確率圖

 1 def plot_accuracies(history):
 2     accuracies = [x['val_acc'] for x in history]
 3     plt.plot(accuracies, '-x')
 4     plt.xlabel('epoch')
 5     plt.ylabel('accuracy')
 6     plt.title('Accuracy vs. No. of epochs')
 7     plt.show()
 8     
 9 def plot_losses(history):
10     train_losses = [x.get('train_loss') for x in history]
11     val_losses = [x['val_loss'] for x in history]
12     plt.plot(train_losses, '-bx')
13     plt.plot(val_losses, '-rx')
14     plt.xlabel('epoch')
15     plt.ylabel('loss')
16     plt.legend(['Training', 'Validation'])
17     plt.title('Loss vs. No. of epochs')
18     plt.show()
19 
20 plot_accuracies(history)
21 plot_losses(history)
22 
23 evaluate(model, test_loader)

(7)預測

 1 #預測分類
 2     y_true=[]
 3     y_pred=[]
 4     with torch.no_grad():
 5         for test_data in test_loader:
 6             test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
 7             pred = model(test_images).argmax(dim=1)
 8             for i in range(len(pred)):
 9                 y_true.append(test_labels[i].item())
10                 y_pred.append(pred[i].item())
11 
12     from sklearn.metrics import classification_report
13     print(classification_report(y_true,y_pred,target_names=classes,digits=4))

(8)讀取圖片測驗

 1 import numpy as np
 2 from PIL import Image
 3 import matplotlib.pyplot as plt
 4 import torchvision.transforms as transforms
 5 
 6 def predict(img_path):
 7     img = Image.open(img_path)
 8     plt.imshow(img)
 9     plt.show()
10     img = img.resize((32,32))
11     img = transforms.ToTensor()(img)
12     img = img.unsqueeze(0)
13     img = img.to(device)
14     pred = model(img).argmax(dim=1)
15     print('預測結果為:',classes[pred.item()])
16     return classes[pred.item()]
17 
18 predict('./raw/apple/Image_1.jpg')

四、總結

在本次課程設計中,使用深度學習的方法實作了果蔬的36分類,相對來說分類數量是比較多的,在訓練了100個epoch以后,分類的準確率可以達到74.3%,通過對果蔬的分類,我明白了當訓練集的影像數量較少時,可以采用資料增強對原始影像進行處理,獲得更多的資料來增強網路的泛化能力,避免網路過擬合,資料增強的方法一般有隨機翻轉、隨即旋轉、隨即裁剪、明暗變化、高斯噪聲、椒鹽噪聲等,除此之外,對整個深度學習中影像分類的流程也有了一定的了解,從收集資料、對資料進行預處理、自己構建網路模型、訓練網路到最后的預測結果,加深了對影像分類程序的理解,希望在以后的學習中,可以學習更多深度學習的方法和應用,

五、全部代碼

  1 import os
  2 import glob
  3 import random
  4 import shutil
  5 from PIL import Image
  6 #對所有圖片進行RGB轉化,并且統一調整到一致大小,但不讓圖片發生變形或扭曲,劃分了訓練集和測驗集
  7 
  8 if __name__ == '__main__':
  9     test_split_ratio = 0.05 #百分之五的比例作為測驗集
 10     desired_size = 128 # 圖片縮放后的統一大小
 11     raw_path = './raw'
 12 
 13     #把多少個類別算出來,包括目錄也包括檔案
 14     dirs = glob.glob(os.path.join(raw_path, '*'))
 15     #進行過濾,只保留目錄,一共36個類別
 16     dirs = [d for d in dirs if os.path.isdir(d)]
 17 
 18     print(f'Totally {len(dirs)} classes: {dirs}')
 19 
 20     for path in dirs:
 21         # 對每個類別單獨處理
 22 
 23         #只保留類別名稱
 24         path = path.split('/')[-1]
 25         print(path)
 26         #創建檔案夾
 27         os.makedirs(f'train/{path}', exist_ok=True)
 28         os.makedirs(f'test/{path}', exist_ok=True)
 29 
 30         #原始檔案夾當前類別的圖片進行匹配
 31         files = glob.glob(os.path.join(raw_path, path, '*.jpg'))
 32         # print(raw_path, path)
 33 
 34         files += glob.glob(os.path.join(raw_path, path, '*.JPG'))
 35         files += glob.glob(os.path.join(raw_path, path, '*.png'))
 36 
 37         random.shuffle(files)#原地shuffle,因為要取出來驗證集
 38 
 39         boundary = int(len(files)*test_split_ratio) # 訓練集和測驗集的邊界
 40         
 41         for i, file in enumerate(files):
 42             img = Image.open(file).convert('RGB')
 43 
 44             old_size = img.size  
 45 
 46             ratio = float(desired_size)/max(old_size)
 47 
 48             new_size = tuple([int(x*ratio) for x in old_size])#等比例縮放
 49 
 50             im = img.resize(new_size, Image.ANTIALIAS)#后面的方法不會造成模糊
 51 
 52             new_im = Image.new("RGB", (desired_size, desired_size))
 53 
 54             #new_im在某個尺寸上更大,我們將舊圖片貼到上面
 55             new_im.paste(im, ((desired_size-new_size[0])//2,
 56                                 (desired_size-new_size[1])//2))
 57 
 58             assert new_im.mode == 'RGB'
 59             
 60             if i <= boundary:
 61                 new_im.save(os.path.join(f'test/{path}', file.split('/')[-1].split('.')[0]+'.jpg'))
 62             else:
 63                 new_im.save(os.path.join(f'train/{path}', file.split('/')[-1].split('.')[0]+'.jpg'))
 64 
 65     test_files = glob.glob(os.path.join('test', '*', '*.jpg'))
 66     train_files = glob.glob(os.path.join('train', '*', '*.jpg'))
 67 
 68 
 69     print(f'Totally {len(train_files)} files for training')
 70     print(f'Totally {len(test_files)} files for test')
 71 
 72 
 73 import os
 74 import random
 75 import numpy as np
 76 import pandas as pd
 77 import torch
 78 import torch.nn as nn
 79 import torch.nn.functional as F
 80 from tqdm.notebook import tqdm
 81 from torchvision import datasets, transforms, models 
 82 from torchvision.datasets import ImageFolder
 83 from torchvision.transforms import ToTensor
 84 from torchvision.utils import make_grid
 85 from torch.utils.data import random_split
 86 from torch.utils.data.dataloader import DataLoader
 87 import matplotlib.pyplot as plt
 88 
 89 if __name__ == '__main__':
 90     # 使用第2個GPU
 91     os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 92 
 93     #影像預處理
 94     train_dir = './train'
 95     val_dir = './test'
 96     test_dir = './test'
 97     classes0 = os.listdir(train_dir)
 98     classes=sorted(classes0)
 99     # print(classes)
100     train_transform=transforms.Compose([
101             transforms.RandomRotation(10),      # 旋轉+/-10度
102             transforms.RandomHorizontalFlip(),  # 反轉50%的影像
103             transforms.Resize(40),              # 調整最短邊的大小
104             transforms.CenterCrop(40),          # 作物最長邊
105             transforms.ToTensor(),
106             transforms.Normalize([0.485, 0.456, 0.406],
107                                 [0.229, 0.224, 0.225])
108     ])
109 
110     trainset = ImageFolder(train_dir, transform=train_transform)
111     valset = ImageFolder(val_dir, transform=train_transform)
112     testset = ImageFolder(test_dir, transform=train_transform)
113     # print(len(trainset))
114 
115     #查看資料集的一個影像形狀
116     img, label = trainset[10]
117     # print(img.shape)
118 
119     #顯示影像
120     def show_image(img,label):
121         print('Label: ', trainset.classes[label], "("+str(label)+")")
122         plt.imshow(img.permute(1,2,0))
123         plt.show()
124 
125     # show_image(*trainset[10])
126     # show_image(*trainset[20])
127 
128     torch.manual_seed(10)
129     train_size = len(trainset)
130     val_size = len(valset)
131     test_size = len(testset)
132 
133     train_ds=trainset
134     val_ds=valset
135     test_ds=testset
136     len(train_ds), len(val_ds), len(test_ds)   
137 
138     #讀取資料
139     batch_size = 64
140     train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
141     val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
142     test_loader = DataLoader(test_ds, batch_size*2, num_workers=4, pin_memory=True)
143 
144 
145     if __name__ == '__main__':
146         for images, labels in train_loader:
147             fig, ax = plt.subplots(figsize=(18,10))
148             ax.set_xticks([])
149             ax.set_yticks([])
150             ax.imshow(make_grid(images,nrow=16).permute(1,2,0))
151             break
152 
153 
154 
155     torch.cuda.is_available()
156 
157 
158     #選擇GPU或CPU
159     def get_default_device():
160         if torch.cuda.is_available():
161             return torch.device('cuda')
162         else:
163             return torch.device('cpu')
164 
165     #移動到所選的設備   
166     def to_device(data, device):
167         if isinstance(data, (list,tuple)):
168             return [to_device(x, device) for x in data]
169         return data.to(device, non_blocking=True)
170 
171     class DeviceDataLoader():
172         #包裝資料加載器以將資料移動到設備
173         def __init__(self, dl, device):
174             self.dl = dl
175             self.device = device
176             
177         def __iter__(self):
178             #將資料移動到設備后生成一批資料
179             for b in self.dl: 
180                 yield to_device(b, self.device)
181 
182         def __len__(self):
183             #分批次
184             return len(self.dl)
185 
186     device = get_default_device()
187 
188 
189     train_loader = DeviceDataLoader(train_loader, device)
190     val_loader = DeviceDataLoader(val_loader, device)
191     test_loader = DeviceDataLoader(test_loader, device)
192 
193     input_size = 3*40*40
194     output_size = 3
195 
196 
197 
198     def accuracy(outputs, labels):
199         _, preds = torch.max(outputs, dim=1)
200         return torch.tensor(torch.sum(preds == labels).item() / len(preds))
201 
202     #影像分類
203     class ImageClassificationBase(nn.Module):
204         def training_step(self, batch):
205             images, labels = batch 
206             out = self(images)                   # 生成預測
207             loss = F.cross_entropy(out, labels)  # 計算損失
208             return loss
209         
210         def validation_step(self, batch):
211             images, labels = batch 
212             out = self(images)                    # 生成預測
213             loss = F.cross_entropy(out, labels)   # 計算損失
214             acc = accuracy(out, labels)           # 計算精度
215             return {'val_loss': loss.detach(), 'val_acc': acc}
216             
217         def validation_epoch_end(self, outputs):
218             batch_losses = [x['val_loss'] for x in outputs]
219             epoch_loss = torch.stack(batch_losses).mean()   # 合并損失
220             batch_accs = [x['val_acc'] for x in outputs]
221             epoch_acc = torch.stack(batch_accs).mean()      # 結合精度
222             return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
223         
224         def epoch_end(self, epoch, result):
225             print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
226                 epoch, result['train_loss'], result['val_loss'], result['val_acc']))
227 
228     #構建CNN模型
229     class CnnModel(ImageClassificationBase):
230         def __init__(self):
231             super().__init__()
232             #cnn提取特征
233             self.network = nn.Sequential(
234                 nn.Conv2d(3, 100, kernel_size=3, padding=1),#Conv2D層
235                 nn.ReLU(),
236                 nn.Conv2d(100, 150, kernel_size=3, stride=1, padding=1),
237                 nn.ReLU(),
238                 nn.MaxPool2d(2, 2), #池化層
239 
240                 nn.Conv2d(150, 200, kernel_size=3, stride=1, padding=1),
241                 nn.ReLU(),
242                 nn.Conv2d(200, 200, kernel_size=3, stride=1, padding=1),
243                 nn.ReLU(),
244                 nn.MaxPool2d(2, 2), 
245 
246                 nn.Conv2d(200, 250, kernel_size=3, stride=1, padding=1),
247                 nn.ReLU(),
248                 nn.Conv2d(250, 250, kernel_size=3, stride=1, padding=1),
249                 nn.ReLU(),
250                 nn.MaxPool2d(2, 2), 
251 
252                 #全連接
253                 nn.Flatten(), 
254                 nn.Linear(6250, 256),  
255                 nn.ReLU(),            
256                 nn.Linear(256, 128),  
257                 nn.ReLU(),            
258                 nn.Linear(128, 64),           
259                 nn.ReLU(),
260                 nn.Linear(64, 32),
261                 nn.ReLU(),
262                 nn.Dropout(0.25),
263                 nn.Linear(32, len(classes)))
264             
265         def forward(self, xb):
266             return self.network(xb)
267 
268     # 將模型加載到GPU上去
269     model = CnnModel()
270 
271     # model.cuda()
272 
273     if __name__ == '__main__':
274         for images, labels in train_loader:
275             out = model(images)
276             print('images.shape:', images.shape)    
277             print('out.shape:', out.shape)
278             print('out[0]:', out[0])
279             break
280 
281     device = get_default_device()
282 
283     train_dl = DeviceDataLoader(train_loader, device)
284     val_dl = DeviceDataLoader(val_loader, device)
285     test_dl = DeviceDataLoader(test_loader, device)
286     to_device(model, device)
287 
288 
289     #訓練網路
290     def evaluate(model, val_loader):
291         model.eval()
292         outputs = [model.validation_step(batch) for batch in val_loader]
293         return model.validation_epoch_end(outputs)
294 
295     def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
296         history = []
297         optimizer = opt_func(model.parameters(), lr)
298         for epoch in range(epochs):
299             # 訓練階段
300             model.train()
301             train_losses = []
302             for batch in tqdm(train_loader,disable=True):
303                 loss = model.training_step(batch)
304                 train_losses.append(loss)
305                 loss.backward()
306                 optimizer.step()
307                 optimizer.zero_grad()
308             # 驗證階段
309             result = evaluate(model, val_loader)
310             result['train_loss'] = torch.stack(train_losses).mean().item()
311             model.epoch_end(epoch, result)
312             history.append(result)
313         return history
314 
315     model = to_device(CnnModel(), device)
316 
317 
318     history=[evaluate(model, val_loader)]
319     num_epochs = 5
320     opt_func = torch.optim.Adam
321     lr = 0.001
322 
323     history+= fit(num_epochs, lr, model, train_dl, val_dl, opt_func)
324 
325 
326     # # 繪制損失函式和準確率圖
327 
328     def plot_accuracies(history):
329         accuracies = [x['val_acc'] for x in history]
330         plt.plot(accuracies, '-x')
331         plt.xlabel('epoch')
332         plt.ylabel('accuracy')
333         plt.title('Accuracy vs. No. of epochs')
334         plt.show()
335         
336     def plot_losses(history):
337         train_losses = [x.get('train_loss') for x in history]
338         val_losses = [x['val_loss'] for x in history]
339         plt.plot(train_losses, '-bx')
340         plt.plot(val_losses, '-rx')
341         plt.xlabel('epoch')
342         plt.ylabel('loss')
343         plt.legend(['Training', 'Validation'])
344         plt.title('Loss vs. No. of epochs')
345         plt.show()
346 
347     plot_accuracies(history)
348     plot_losses(history)
349 
350     evaluate(model, test_loader)
351 
352 
353     #預測分類
354     y_true=[]
355     y_pred=[]
356     with torch.no_grad():
357         for test_data in test_loader:
358             test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
359             pred = model(test_images).argmax(dim=1)
360             for i in range(len(pred)):
361                 y_true.append(test_labels[i].item())
362                 y_pred.append(pred[i].item())
363 
364     from sklearn.metrics import classification_report
365     print(classification_report(y_true,y_pred,target_names=classes,digits=4))
366 
367     # 讀取圖片進行預測
368     import numpy as np
369     from PIL import Image
370     import matplotlib.pyplot as plt
371     import torchvision.transforms as transforms
372 
373     def predict(img_path):
374         img = Image.open(img_path)
375         plt.imshow(img)
376         plt.show()
377         img = img.resize((32,32))
378         img = transforms.ToTensor()(img)
379         img = img.unsqueeze(0)
380         img = img.to(device)
381         pred = model(img).argmax(dim=1)
382         print('預測結果為:',classes[pred.item()])
383         return classes[pred.item()]
384 
385     predict('./raw/apple/Image_1.jpg')

 

轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/540597.html

標籤:其他

上一篇:LeetCode刷題第七周

下一篇:Python從入門到精通(第2版)——pyuic5: error: no such option: -m的問題解決

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 【C++】Microsoft C++、C 和匯編程式檔案

    ......

    uj5u.com 2020-09-10 00:57:23 more
  • 例外宣告

    相比于斷言適用于排除邏輯上不可能存在的狀態,例外通常是用于邏輯上可能發生的錯誤。 例外宣告 Item 1:當函式不可能拋出例外或不能接受拋出例外時,使用noexcept 理由 如果不打算拋出例外的話,程式就會認為無法處理這種錯誤,并且應當盡早終止,如此可以有效地阻止例外的傳播與擴散。 示例 //不可 ......

    uj5u.com 2020-09-10 00:57:27 more
  • Codeforces 1400E Clear the Multiset(貪心 + 分治)

    鏈接:https://codeforces.com/problemset/problem/1400/E 來源:Codeforces 思路:給你一個陣列,現在你可以進行兩種操作,操作1:將一段沒有 0 的區間進行減一的操作,操作2:將 i 位置上的元素歸零。最終問:將這個陣列的全部元素歸零后操作的最少 ......

    uj5u.com 2020-09-10 00:57:30 more
  • UVA11610 【Reverse Prime】

    本人看到此題沒有翻譯,就附帶了一個自己的翻譯版本 思考 這一題,它的第一個要求是找出所有 $7$ 位反向質數及其質因數的個數。 我們應該需要質數篩篩選1~$10^{7}$的所有數,這里就不慢慢介紹了。但是,重讀題,我們突然發現反向質數都是 $7$ 位,而將它反過來后的數字卻是 $6$ 位數,這就說明 ......

    uj5u.com 2020-09-10 00:57:36 more
  • 統計區間素數數量

    1 #pragma GCC optimize(2) 2 #include <bits/stdc++.h> 3 using namespace std; 4 bool isprime[1000000010]; 5 vector<int> prime; 6 inline int getlist(int ......

    uj5u.com 2020-09-10 00:57:47 more
  • C/C++編程筆記:C++中的 const 變數詳解,教你正確認識const用法

    1、C中的const 1、區域const變數存放在堆疊區中,會分配記憶體(也就是說可以通過地址間接修改變數的值)。測驗代碼如下: 運行結果: 2、全域const變數存放在只讀資料段(不能通過地址修改,會發生寫入錯誤), 默認為外部聯編,可以給其他源檔案使用(需要用extern關鍵字修飾) 運行結果: ......

    uj5u.com 2020-09-10 00:58:04 more
  • 【C++犯錯記錄】VS2019 MFC添加資源不懂如何修改資源宏ID

    1. 首先在資源視圖中,添加資源 2. 點擊新添加的資源,復制自動生成的ID 3. 在解決方案資源管理器中找到Resource.h檔案,編輯,使用整個專案搜索和替換的方式快速替換 宏宣告 4. Ctrl+Shift+F 全域搜索,點擊查找全部,然后逐個替換 5. 為什么使用搜索替換而不使用屬性視窗直 ......

    uj5u.com 2020-09-10 00:59:11 more
  • 【C++犯錯記錄】VS2019 MFC不懂的批量添加資源

    1. 打開資源頭檔案Resource.h,在其中預先定義好宏 ID(不清楚其實ID值應該設定多少,可以先新建一個相同的資源項,再在這個資源的ID值的基礎上遞增即可) 2. 在資源視圖中選中專案資源,按F7編輯資源檔案,按 ID 型別 相對路徑的形式添加 資源。(別忘了先把檔案拷貝到專案中的res檔案 ......

    uj5u.com 2020-09-10 01:00:19 more
  • C/C++編程筆記:關于C++的參考型別,專供新手入門使用

    今天要講的是C++中我最喜歡的一個用法——參考,也叫別名。 參考就是給一個變數名取一個變數名,方便我們間接地使用這個變數。我們可以給一個變數創建N個參考,這N + 1個變數共享了同一塊記憶體區域。(參考型別的變數會占用記憶體空間,占用的記憶體空間的大小和指標型別的大小是相同的。雖然參考是一個物件的別名,但 ......

    uj5u.com 2020-09-10 01:00:22 more
  • 【C/C++編程筆記】從頭開始學習C ++:初學者完整指南

    眾所周知,C ++的學習曲線陡峭,但是花時間學習這種語言將為您的職業帶來奇跡,并使您與其他開發人員區分開。您會更輕松地學習新語言,形成真正的解決問題的技能,并在編程的基礎上打下堅實的基礎。 C ++將幫助您養成良好的編程習慣(即清晰一致的編碼風格,在撰寫代碼時注釋代碼,并限制類內部的可見性),并且由 ......

    uj5u.com 2020-09-10 01:00:41 more
最新发布
  • Rust中的智能指標:Box<T> Rc<T> Arc<T> Cell<T> RefCell<T> Weak

    Rust中的智能指標是什么 智能指標(smart pointers)是一類資料結構,是擁有資料所有權和額外功能的指標。是指標的進一步發展 指標(pointer)是一個包含記憶體地址的變數的通用概念。這個地址參考,或 ” 指向”(points at)一些其 他資料 。參考以 & 符號為標志并借用了他們所 ......

    uj5u.com 2023-04-20 07:24:10 more
  • Java的值傳遞和參考傳遞

    值傳遞不會改變本身,參考傳遞(如果傳遞的值需要實體化到堆里)如果發生修改了會改變本身。 1.基本資料型別都是值傳遞 package com.example.basic; public class Test { public static void main(String[] args) { int ......

    uj5u.com 2023-04-20 07:24:04 more
  • [2]SpinalHDL教程——Scala簡單入門

    第一個 Scala 程式 shell里面輸入 $ scala scala> 1 + 1 res0: Int = 2 scala> println("Hello World!") Hello World! 檔案形式 object HelloWorld { /* 這是我的第一個 Scala 程式 * 以 ......

    uj5u.com 2023-04-20 07:23:58 more
  • 理解函式指標和回呼函式

    理解 函式指標 指向函式的指標。比如: 理解函式指標的偽代碼 void (*p)(int type, char *data); // 定義一個函式指標p void func(int type, char *data); // 宣告一個函式func p = func; // 將指標p指向函式func ......

    uj5u.com 2023-04-20 07:23:52 more
  • Django筆記二十五之資料庫函式之日期函式

    本文首發于公眾號:Hunter后端 原文鏈接:Django筆記二十五之資料庫函式之日期函式 日期函式主要介紹兩個大類,Extract() 和 Trunc() Extract() 函式作用是提取日期,比如我們可以提取一個日期欄位的年份,月份,日等資料 Trunc() 的作用則是截取,比如 2022-0 ......

    uj5u.com 2023-04-20 07:23:45 more
  • 一天吃透JVM面試八股文

    什么是JVM? JVM,全稱Java Virtual Machine(Java虛擬機),是通過在實際的計算機上仿真模擬各種計算機功能來實作的。由一套位元組碼指令集、一組暫存器、一個堆疊、一個垃圾回收堆和一個存盤方法域等組成。JVM屏蔽了與作業系統平臺相關的資訊,使得Java程式只需要生成在Java虛擬機 ......

    uj5u.com 2023-04-20 07:23:31 more
  • 使用Java接入小程式訂閱訊息!

    更新完微信服務號的模板訊息之后,我又趕緊把微信小程式的訂閱訊息給實作了!之前我一直以為微信小程式也是要企業才能申請,沒想到小程式個人就能申請。 訊息推送平臺🔥推送下發【郵件】【短信】【微信服務號】【微信小程式】【企業微信】【釘釘】等訊息型別。 https://gitee.com/zhongfuch ......

    uj5u.com 2023-04-20 07:22:59 more
  • java -- 緩沖流、轉換流、序列化流

    緩沖流 緩沖流, 也叫高效流, 按照資料型別分類: 位元組緩沖流:BufferedInputStream,BufferedOutputStream 字符緩沖流:BufferedReader,BufferedWriter 緩沖流的基本原理,是在創建流物件時,會創建一個內置的默認大小的緩沖區陣列,通過緩沖 ......

    uj5u.com 2023-04-20 07:22:49 more
  • Java-SpringBoot-Range請求頭設定實作視頻分段傳輸

    老實說,人太懶了,現在基本都不喜歡寫筆記了,但是網上有關Range請求頭的文章都太水了 下面是抄的一段StackOverflow的代碼...自己大修改過的,寫的注釋挺全的,應該直接看得懂,就不解釋了 寫的不好...只是希望能給視頻網站開發的新手一點點幫助吧. 業務場景:視頻分段傳輸、視頻多段傳輸(理 ......

    uj5u.com 2023-04-20 07:22:42 more
  • Windows 10開發教程_編程入門自學教程_菜鳥教程-免費教程分享

    教程簡介 Windows 10開發入門教程 - 從簡單的步驟了解Windows 10開發,從基本到高級概念,包括簡介,UWP,第一個應用程式,商店,XAML控制元件,資料系結,XAML性能,自適應設計,自適應UI,自適應代碼,檔案管理,SQLite資料庫,應用程式到應用程式通信,應用程式本地化,應用程式 ......

    uj5u.com 2023-04-20 07:22:35 more