環境:TensorFlow==1.14.0 python3.7
Inception 模型優勢
(1)分解成小卷積很有效,可以降低引數量,減輕過擬合,增加網路非線性的表達能力,
(2) 卷積網路從輸入到輸出,應該讓圖片尺寸逐漸減小,輸出通道數逐漸增加,即讓空間結構化,將空間資訊轉化為高階抽象的特征資訊,
(3) Inception Module用多個分支提取不同抽象程度的高階特征的思路很有效,可以豐富網路的表達能力
詳細可看inceptionV3論文《Rethinking the Inception Architecture for Computer Vision》
遷移學習
遷移學習在實際應用中的意義非常大,它可以將之前已學過的知識(模型引數)遷移到一項新的任務上,使學習效率大大的提高,我們知道,要訓練一個復雜的深度學習模型,成本是十分巨大的,而遷移學習可以大大的降低我們的訓練成本,在短時間內就能達到很好的效果,
這次展示的是基于垃圾分類訓練,資料集來自Garbage Classification (12 classes) | Kaggle
采用對ImageNet訓練過的權重做預訓練模型,只對softmax訓練
一、準備階段
下載retrain.py
tensorflow/hub: A library for transfer learning by reusing parts of TensorFlow models. (github.com)
https://github.com/tensorflow/hub在以下目錄可以找到

之前位置在

如果之后位置變化可以到README.md中查看
資料集準備:
自己拍照想要識別的物體,建議是純色背景或者單一背景
也可以到Kaggle: Your Machine Learning and Data Science Community
然后建立以下目錄

data放置資料集,如

tmp放置訓練程序中產生的檔案
二、訓練模型
在retrain.py同級目錄下打開cmd
輸入python retrain.py --image_dir data --how_many_training_steps 1000 --model_dir inception_dec_2015 --output_graph output_graph.pd --output_labels output_labels.txt --bottleneck_dir tmp\bottleneck --summaries_dir tmp\retrain_logs
這里data是上述的data檔案夾位置,建議都改為絕對位置,防止意外報錯
具體引數可以查看retrain.py

訓練結果

結束后可以在retrain.py同級目錄下

inception_dec_2015下是下載的ImageNet訓練模型,bottleneck是放置樣本描述檔案,retrain_logs放置訓練日志,可以通過TensorBoard查看,output_graph.pd就是我們訓練好的模型,output_labels.txt是標簽
三、模型測驗
# coding: UTF-8
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import random
# 結果陣列與output_labels.txt檔案中的順序要一致
res = [
'battery', 'biological', 'brown glass', 'cardboard', 'clothes',
'green glass', 'metal', 'paper', 'plastic', 'shoes', 'trash', 'white glass'
]
path = r'test.jpg'
with tf.gfile.FastGFile('output_graph.pd', 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name(
'final_result:0'
) # 獲取新模型最后的輸出節點叫做final_result,可以從tensorboard中的graph中看到,其中名字后面的’:’之后接數字為EndPoints索引值(An operation allocates memory for its outputs, which are available on endpoints :0, :1, etc, and you can think of each of these endpoints as a Tensor.),通常情況下為0,因為大部分operation都只有一個輸出,
image_data = tf.gfile.FastGFile(
path, 'rb').read() # Returns the contents of a file as a string.
predictions = sess.run(
softmax_tensor,
{'DecodeJpeg/contents:0': image_data
}) # tensorboard中的graph中可以看到DecodeJpeg/contents是模型的輸入變數名字
predictions = np.squeeze(predictions)
top_k = predictions.argsort()[-2:][::-1][0]
print(res[top_k])
運行后會列印置信度最高的物體標簽
舊版本retrain.py的只針對對inceptionV3進行遷移學習訓練,如果需要訓練其他網路結構,得更新retrain.py一些引數,新版本就可以對其他網路結構訓練,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/321497.html
標籤:其他
