主頁 >  其他 > 深度學習100例 | 第33天:遷移學習-實戰案例教程(必須掌握的一個點)

深度學習100例 | 第33天:遷移學習-實戰案例教程(必須掌握的一個點)

2021-10-19 08:12:55 其他

我的知乎我的微信公眾號我的CSDN下載本文原始碼+資料需要幫助

有料,有料,微信搜索 【K同學啊】 關注這個分享干貨的博主,
📍 本文 GitHub https://github.com/kzbkzb/Python-AI 已收錄,有 Python、深度學習的資料以及我的系列文章,


在本教程中,你將學習如何使用遷移學習通過預訓練網路對貓和狗的影像進行分類,

預訓練模型是一個之前基于大型資料集(通常是大型影像分類任務)訓練的已保存網路,

遷移學習通常應用在資料集過少以至于無法有效完成模型的訓練,故而尋求在預訓練模型的基礎上進行訓練、微調來解決這個問題,當然,即使資料集不那么小,我們也可以通過預訓練模型來加快模型的訓練,

在本文中,我們無需(重新)訓練整個模型,基礎卷積網路已經包含通常用于圖片分類的特征,但是,預訓練模型的最終分類部分特定于原始分類任務,隨后特定于訓練模型所使用的類集,

  1. 微調:解凍已凍結模型庫的一些頂層,并共同訓練新添加的分類器層和基礎模型的最后幾層,這樣,我們便能“微調” base model 中的高階特征表示,以使其與特定任務(遷移后的任務)更相關,

將遵循通用的深度學習的作業流程,

  1. 檢查并理解資料
  2. 構建輸入流水線,在本例中使用 Keras ImageDataGenerator
  3. 構成模型
    • 加載預訓練的基礎模型(和預訓練權重)
    • 將分類層堆疊在頂部
  4. 訓練模型
  5. 評估模型

遷移學習和微調

    • 資料預處理
      • 資料下載
      • 配置資料集以提高性能
      • 使用資料擴充
      • 重新縮放像素值
    • 從預訓練卷積網路創建基礎模型
    • 特征提取
      • 凍結卷積基
      • 有關 BatchNormalization 層的重要說明
      • 添加分類頭
      • 編譯模型
      • 訓練模型
      • 學習曲線
    • 微調
      • 解凍模型的頂層
      • 編譯模型
      • 繼續訓練模型
      • 評估和預測
    • 總結

from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
import numpy as np
import os

#設定GPU顯存用量按需使用
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  
    tf.config.set_visible_devices([gpus[0]],"GPU")

#忽略警告資訊
import warnings
warnings.filterwarnings("ignore")             

資料預處理

資料下載

在本教程中,你將使用包含數千個貓和狗影像的資料集,下載并解壓縮包含影像的 zip 檔案,然后使用 tf.keras.preprocessing.image_dataset_from_directory 效用函式創建一個 tf.data.Dataset 進行訓練和驗證,

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 14s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

顯示訓練集中的前九個影像和標簽:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

由于原始資料集不包含測驗集,因此你需要創建一個,為此,請使用 tf.data.experimental.cardinality 確定驗證集中有多少批次的資料,然后將其中的 20% 移至測驗集,

val_batches        = tf.data.experimental.cardinality(validation_dataset)
test_dataset       = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

配置資料集以提高性能

使用緩沖預提取從磁盤加載影像,以免造成 I/O 阻塞,

AUTOTUNE = tf.data.AUTOTUNE

train_dataset      = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset       = test_dataset.prefetch(buffer_size=AUTOTUNE)

使用資料擴充

當你沒有較大的影像資料集時,最好將隨機但現實的轉換應用于訓練影像(例如旋轉或水平翻轉)來人為引入樣本多樣性,這有助于使模型暴露于訓練資料的不同方面并減少過擬合,你可以在此教程中詳細了解資料擴充,

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

注:當你呼叫 model.fit 時,這些層僅在訓練程序中才會處于有效狀態,在 model.evaulatemodel.fit 中的推斷模式下使用模型時,它們處于停用狀態,

我們將這些層重復應用于同一個影像,然后查看結果,

for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')

重新縮放像素值

稍后,您將下載 tf.keras.applications.MobileNetV2 作為基礎模型,此模型期望像素值處于 [-1, 1] 范圍內,但此時,影像中的像素值處于 [0, 255] 范圍內,要重新縮放這些像素值,請使用模型隨附的預處理方法,

"""
關于tf.keras.applications.mobilenet_v2.preprocess_input

回傳值的官方原文:The inputs pixel values are scaled between -1 and 1, sample-wise.
函式功能:將像素值縮放到[-1,1]之間
"""
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

注:另外,您也可以使用 Rescaling 層將像素值從 [0,255] 重新縮放為 [-1, 1]

rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

"""
如果你想縮放到[0,1]之間,可以這樣寫
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(scale=1./255)
"""

從預訓練卷積網路創建基礎模型

您將根據 Google 開發的 MobileNet V2 模型來創建基礎模型,此模型已基于 ImageNet 資料集進行預訓練,ImageNet 資料集是一個包含 140 萬個影像和 1000 個類的大型資料集,ImageNet 是一個研究訓練資料集,具有各種各樣的類別,例如 jackfruitsyringe,此知識庫將幫助我們對特定資料集中的貓和狗進行分類,

首先,您需要選擇將 MobileNet V2 的哪一層用于特征提取,最后的分類層(在“頂部”,因為大多數機器學習模型的圖表是從下到上的)不是很有用,相反,您將按照常見做法依賴于展平操作之前的最后一層,此層被稱為“瓶頸層”,與最后一層/頂層相比,瓶頸層的特征保留了更多的通用性,

首先,實體化一個已預加載基于 ImageNet 訓練的權重的 MobileNet V2 模型,通過指定 include_top=False 引數,可以加載不包括頂部分類層的網路,這對于特征提取十分理想,

# 使用官方權重創建一個base model
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

此特征提取程式將每個 160x160x3 影像轉換為 5x5x1280 的特征塊,我們看看它對一批示例影像做了些什么:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

特征提取

在此步驟中,您將凍結在上一步中創建的卷積基,并用作特征提取程式,此外,您還可以在其頂部添加分類器以及訓練頂級分類器,

凍結卷積基

在編譯和訓練模型之前,凍結卷積基至關重要,凍結(通過設定 layer.trainable = False)可避免在訓練期間更新給定層中的權重,MobileNet V2 具有許多層,因此將整個模型的 trainable 標記設定為 False 會凍結所有這些層,

base_model.trainable = False

有關 BatchNormalization 層的重要說明

許多模型都包含 tf.keras.layers.BatchNormalization 層,此層是一個特例,應在微調的背景關系中采取預防措施,如本教程后面所示,

設定 layer.trainable = False 時,BatchNormalization 層將以推斷模式運行,并且不會更新其均值和方差統計資訊,

解凍包含 BatchNormalization 層的模型以進行微調時,應在呼叫 base model 時通過傳遞 training = False 來使 BatchNormalization 層保持在推斷模式下,否則,應用于不可訓練權重的更新將破壞模型已經學習到的內容,

# 列印base model結構
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
......
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

添加分類頭

要從特征塊生成預測,請使用 tf.keras.layers.GlobalAveragePooling2D 層在 5x5 空間位置內取平均值,以將特征轉換成每個影像一個向量(包含 1280 個元素),

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

應用 tf.keras.layers.Dense 層將這些特征轉換成每個影像一個預測,您在此處不需要激活函式,因為此預測將被視為 logit 或原始預測值,正數預測 1 類,負數預測 0 類,

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

通過使用 Keras 函式式 API 將資料擴充、重新縮放、base_model 和特征提取程式層鏈接在一起來構建模型,如前面所述,由于我們的模型包含 BatchNormalization 層,因此請使用 training = False,

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)

model = tf.keras.Model(inputs, outputs)

編譯模型

在訓練模型前,需要先編譯模型,由于存在兩個類,并且模型提供線性輸出,請將二進制交叉熵損失與 from_logits=True 結合使用,

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

MobileNet 中的 250 萬個引數被凍結,但在密集層中有 1200 個可訓練引數,它們分為兩個 tf.Variable 物件,即權重和偏差,

len(model.trainable_variables)
2

訓練模型

經過 10 個周期的訓練后,您應該在驗證集上看到約 94% 的準確率,

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 25ms/step - loss: 0.8702 - accuracy: 0.4022
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.87
initial accuracy: 0.40
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 36ms/step - loss: 0.7913 - accuracy: 0.5070 - val_loss: 0.6019 - val_accuracy: 0.5928
Epoch 2/10
63/63 [==============================] - 2s 32ms/step - loss: 0.5866 - accuracy: 0.6730 - val_loss: 0.4355 - val_accuracy: 0.7574
Epoch 3/10
63/63 [==============================] - 2s 32ms/step - loss: 0.4451 - accuracy: 0.7695 - val_loss: 0.3383 - val_accuracy: 0.8243
Epoch 4/10
63/63 [==============================] - 2s 33ms/step - loss: 0.3875 - accuracy: 0.8225 - val_loss: 0.2799 - val_accuracy: 0.8639
Epoch 5/10
63/63 [==============================] - 2s 33ms/step - loss: 0.3350 - accuracy: 0.8345 - val_loss: 0.2273 - val_accuracy: 0.9097
Epoch 6/10
63/63 [==============================] - 2s 35ms/step - loss: 0.3062 - accuracy: 0.8640 - val_loss: 0.2028 - val_accuracy: 0.9097
Epoch 7/10
63/63 [==============================] - 2s 32ms/step - loss: 0.2765 - accuracy: 0.8840 - val_loss: 0.1758 - val_accuracy: 0.9319
Epoch 8/10
63/63 [==============================] - 2s 32ms/step - loss: 0.2538 - accuracy: 0.8925 - val_loss: 0.1613 - val_accuracy: 0.9418
Epoch 9/10
63/63 [==============================] - 2s 32ms/step - loss: 0.2478 - accuracy: 0.8845 - val_loss: 0.1472 - val_accuracy: 0.9455
Epoch 10/10
63/63 [==============================] - 2s 32ms/step - loss: 0.2301 - accuracy: 0.9015 - val_loss: 0.1351 - val_accuracy: 0.9493

學習曲線

我們看一下使用 MobileNet V2 基礎模型作為固定特征提取程式時訓練和驗證準確率/損失的學習曲線,

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

注:如果您想知道為什么驗證指標明顯優于訓練指標,主要原因是 tf.keras.layers.BatchNormalizationtf.keras.layers.Dropout 等層會影響訓練期間的準確率,在計算驗證損失時,它們處于關閉狀態,

在較小程度上,這也是因為訓練指標報告的是某個周期的平均值,而驗證指標則在經過該周期后才進行評估,因此驗證指標會看到訓練時間略長一些的模型,

微調

在特征提取實驗中,您僅在 MobileNet V2 基礎模型的頂部訓練了一些層,預訓練網路的權重在訓練程序中更新,

進一步提高性能的一種方式是在訓練(或“微調”)預訓練模型頂層的權重的同時,另外訓練您添加的分類器,訓練程序將強制權重從通用特征映射調整為專門與資料集相關聯的特征,

注:只有在您使用設定為不可訓練的預訓練模型訓練頂級分類器之后,才能嘗試這樣做,如果您在預訓練模型的頂部添加一個隨機初始化的分類器并嘗試共同訓練所有層,則梯度更新的幅度將過大(由于分類器的隨機權重所致),這將導致您的預訓練模型忘記它已經學習的內容,

另外,您還應嘗試微調少量頂層而不是整個 MobileNet 模型,在大多數卷積網路中,層越高,它的專門程度就越高,前幾層學習非常簡單且通用的特征,這些特征可以泛化到幾乎所有型別的影像,隨著您向上層移動,這些特征越來越特定于訓練模型所使用的資料集,微調的目標是使這些專用特征適應新的資料集,而不是覆寫通用學習,

解凍模型的頂層

您需要做的是解凍 base_model 并將底層設定為不可訓練,隨后,您應該重新編譯模型(使這些更改生效的必需操作),然后恢復訓練,

base_model.trainable = True
# 列印 base model 中的 layer 總數
print("Number of layers in the base model: ", len(base_model.layers))

# 微調 fine_tune_at 之后的layer
fine_tune_at = 100

# 凍結 fine_tune_at 之前的所有 layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable =  False
Number of layers in the base model:  154

編譯模型

當您正在訓練一個大得多的模型并且想要重新調整預訓練權重時,請務必在此階段使用較低的學習率,否則,您的模型可能會很快過擬合,

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

繼續訓練模型

如果你已提前訓練至收斂,則此步驟將使您的準確率提高幾個百分點,

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 8s 62ms/step - loss: 0.1459 - accuracy: 0.9345 - val_loss: 0.0524 - val_accuracy: 0.9814
Epoch 11/20
63/63 [==============================] - 3s 50ms/step - loss: 0.1244 - accuracy: 0.9495 - val_loss: 0.0416 - val_accuracy: 0.9864
Epoch 12/20
63/63 [==============================] - 3s 49ms/step - loss: 0.1027 - accuracy: 0.9570 - val_loss: 0.0463 - val_accuracy: 0.9777
Epoch 13/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0884 - accuracy: 0.9605 - val_loss: 0.0461 - val_accuracy: 0.9814
Epoch 14/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0939 - accuracy: 0.9585 - val_loss: 0.0434 - val_accuracy: 0.9814
Epoch 15/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0898 - accuracy: 0.9650 - val_loss: 0.0492 - val_accuracy: 0.9790
Epoch 16/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0796 - accuracy: 0.9650 - val_loss: 0.0353 - val_accuracy: 0.9889
Epoch 17/20
63/63 [==============================] - 3s 51ms/step - loss: 0.0834 - accuracy: 0.9670 - val_loss: 0.0425 - val_accuracy: 0.9864
Epoch 18/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0786 - accuracy: 0.9685 - val_loss: 0.0384 - val_accuracy: 0.9839
Epoch 19/20
63/63 [==============================] - 3s 50ms/step - loss: 0.0580 - accuracy: 0.9765 - val_loss: 0.0454 - val_accuracy: 0.9851
Epoch 20/20
63/63 [==============================] - 3s 51ms/step - loss: 0.0700 - accuracy: 0.9735 - val_loss: 0.0326 - val_accuracy: 0.9901

在微調 MobileNet V2 基礎模型的最后幾層并在這些層上訓練分類器時,我們來看一下訓練和驗證準確率/損失的學習曲線,驗證損失比訓練損失高得多,因此可能存在一些過擬合,

當新的訓練集相對較小且與原始 MobileNet V2 資料集相似時,也可能存在一些過擬合,

經過微調后,模型在驗證集上的準確率幾乎達到 98%,

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

評估和預測

最后,您可以使用測驗集在新資料上驗證模型的性能,

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 20ms/step - loss: 0.0204 - accuracy: 0.9948
Test accuracy : 0.9947916865348816

現在,你可以使用此模型來預測你的寵物是貓還是狗,

#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")
Predictions:
 [1 1 1 0 0 0 0 0 1 0 1 0 1 0 0 0 0 1 0 1 1 1 0 1 0 0 1 0 1 0 0 0]
Labels:
 [1 1 1 0 0 0 0 0 1 0 1 0 1 0 0 0 0 1 0 1 1 1 0 1 0 0 1 0 1 0 0 0]

總結

  • 使用預訓練模型進行特征提取:使用小型資料集時,常見做法是利用基于相同域中的較大資料集訓練的模型所學習的特征,為此,您需要實體化預訓練模型并在頂部添加一個全連接分類器,預訓練模型處于“凍結狀態”,訓練程序中僅更新分類器的權重,在這種情況下,卷積基提取了與每個影像關聯的所有特征,而您剛剛訓練了一個根據給定的提取特征集確定影像類的分類器,

  • 微調預訓練模型:為了進一步提高性能,可能需要通過微調將預訓練模型的頂層重新用于新的資料集,在本例中,你調整了權重,以使模型學習特定于資料集的高級特征,當訓練資料集較大且與訓練預訓練模型所使用的原始資料集非常相似時,通常建議使用這種技術,

注:本文取自TensorFlow官網,做了部分修改

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/323274.html

標籤:其他

上一篇:OpenCV4.5.4 DNN人臉識別模塊使用介紹--如何快速搭建一個人臉識別系統

下一篇:《Python 深度學習》刷書筆記 Chapter 8 Part-3 神經網路風格轉移

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more