主頁 >  其他 > 使用TensorFlow訓練自己的圖片并應用(CNN卷積神經網路)

使用TensorFlow訓練自己的圖片并應用(CNN卷積神經網路)

2021-08-16 09:49:42 其他

一、環境配置

Anaconda:4.10.3
Python:3.6.2
TensorFlow:1.9.0
在這里插入圖片描述

二、圖片準備

在這個小專案中,我們首先需要自己在網上收集四類圖片(每類圖片30張,一共120張),這些圖片的格式最好是統一的JPG格式,對于解析度來說沒有特定的要求,我們的專案在預處理中可以進行解析度統一化的預處理(也就是把每一張圖片變成一樣的解析度64*64),
不過要根據你自己的目錄把圖片放在上面,不然代碼可是找不到的,我把圖片放在了如圖這個地方,
在這里插入圖片描述
每一張圖片都需要整理分類到每一個檔案夾中,程式才可以正常找到,比如我把土豆放在這個potato檔案夾下,
在這里插入圖片描述

三、效果展示

在測驗代碼中點擊運行:
在這里插入圖片描述
便會出現要預測的圖片(圖片顯示不清是因為這個圖片的像素只有64*64).
接著,把圖片關閉,即可顯示出預測是potato(土豆)的可能性是0.984120,
在這里插入圖片描述

四、源代碼

(1)preprocessing.py(圖片預處理)

# 將原始圖片轉換成需要的大小,并將其保存
import os  
import tensorflow as tf  
from PIL import Image  
  
# 原始圖片的存盤位置 E:/python-run-env/train-test/train-data/generate-simple/
orig_picture = 'E:/python-run-env/train-test/train-data/generate-simple/'
 
# 生成圖片的存盤位置 E:/python-run-env/train-test/Re_train/image_data/inputdata/
gen_picture = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'
 
# 需要的識別型別
classes = {'cabbage','carrot','nori','potato'} 
 
# 樣本總數
num_samples = 120 
   
# 制作TFRecords資料  
def create_record():  
    writer = tf.python_io.TFRecordWriter("dishes_train.tfrecords")  
    for index, name in enumerate(classes):  
        class_path = orig_picture +"/"+ name+"/"  
        # os.listdir() 方法用于回傳指定的檔案夾包含的檔案或檔案夾的名字的串列,
        for img_name in os.listdir(class_path):  
            img_path = class_path + img_name  
            img = Image.open(img_path)  
            img = img.resize((64, 64))    # 設定需要轉換的圖片大小
            img_raw = img.tobytes()      # 將圖片轉化為原生bytes  
            print (index,img_raw)  
            example = tf.train.Example(  
               features=tf.train.Features(feature={  
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),  
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))  
               }))  
            writer.write(example.SerializeToString())  
    writer.close()  
    
def read_and_decode(filename):  
    # 創建檔案佇列,不限讀取的數量  
    filename_queue = tf.train.string_input_producer([filename])  
    # create a reader from file queue  
    reader = tf.TFRecordReader()  
    # reader從檔案佇列中讀入一個序列化的樣本  
    _, serialized_example = reader.read(filename_queue)  
    # get feature from serialized example  
    # 決議符號化的樣本  
    features = tf.parse_single_example(  
        serialized_example,  
        features={  
            'label': tf.FixedLenFeature([], tf.int64),  
            'img_raw': tf.FixedLenFeature([], tf.string)  
        })  
    label = features['label']  
    img = features['img_raw']  
    img = tf.decode_raw(img, tf.uint8)  
    img = tf.reshape(img, [64, 64, 3])  
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  
    label = tf.cast(label, tf.int32)  
    return img, label  
 

if __name__ == '__main__':  
    create_record()  
    batch = read_and_decode('dishes_train.tfrecords')  
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())  
      
    with tf.Session() as sess: # 開始一個會話    
        sess.run(init_op)    
        coord=tf.train.Coordinator()    
        threads= tf.train.start_queue_runners(coord=coord)  
        
        for i in range(num_samples):    
            example, lab = sess.run(batch)  # 在會話中取出image和label    
            img=Image.fromarray(example, 'RGB') # 這里Image是之前提到的 
            img.save(gen_picture+'/'+str(i)+'samples'+str(lab)+'.jpg')#存下圖片;注意cwd后邊加上‘/’    
            print(example, lab)    
        coord.request_stop()    
        coord.join(threads)   
        sess.close()  

點擊運行后,可以在終端看到很多輸出
在這里插入圖片描述
然后在這里可以看到很多圖片,要把這些圖片進行分類,然后裝到這些檔案夾里面:
在這里插入圖片描述
除此之外,還會產生一個TFrecord的二進制檔案:
在這里插入圖片描述
(2)batchdealing.py(輸入圖片處理)

import os
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
 

# -----------------生成圖片路徑和標簽的List------------------------------------
# 生成圖片的存盤位置 E:/python-run-env/train-test/Re_train/image_data/inputdata/
train_dir = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'
 
cabbage = []
label_cabbage = []
carrot = []
label_carrot = []
nori = []
label_nori = []
potato = []
label_potato = []
 
# step1:獲取'E:/Re_train/image_data/training_image'下所有的圖片路徑名,存放到
# 對應的串列中,同時貼上標簽,存放到label串列中,
# ratio是測驗集的比例
def get_files(file_dir, ratio):
    for file in os.listdir(file_dir+'/cabbage'):
        cabbage.append(file_dir +'/cabbage'+'/'+ file) 
        label_cabbage.append(0)
    for file in os.listdir(file_dir+'/carrot'):
        carrot.append(file_dir +'/carrot'+'/'+file)
        label_carrot.append(1)
    for file in os.listdir(file_dir+'/nori'):
        nori.append(file_dir +'/nori'+'/'+ file) 
        label_nori.append(2)
    for file in os.listdir(file_dir+'/potato'):
        potato.append(file_dir +'/potato'+'/'+file)
        label_potato.append(3)
 
    # step2:對生成的圖片路徑和標簽List做打亂處理把所有的資料合起來組成一個list(img和lab)
    # np.hstack水平(按列)按順序堆疊陣列,
    # >>> a = np.array((1,2,3))
    # >>> b = np.array((2,3,4))
    # >>> np.hstack((a,b))
    # array([1, 2, 3, 2, 3, 4])
    image_list = np.hstack((cabbage, carrot, nori, potato))
    label_list = np.hstack((label_cabbage, label_carrot, label_nori, label_potato))
 
    # 利用shuffle打亂順序
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)
    
    # 從打亂的temp中再取出list(img和lab)
    # image_list = list(temp[:, 0])
    # label_list = list(temp[:, 1])
    # label_list = [int(i) for i in label_list]
    # return image_list, label_list
    
    # 將所有的img和lab轉換成list
    all_image_list = list(temp[:, 0])
    all_label_list = list(temp[:, 1])
 
    # 將所得List分為兩部分,一部分用來訓練tra,一部分用來測驗val
    # ratio是測驗集的比例
    # n_sample全部樣本數
    n_sample = len(all_label_list)
    n_val = int(math.ceil(n_sample*ratio))   # 測驗樣本數
    n_train = n_sample - n_val   # 訓練樣本數
 
    # 訓練的圖片和標簽
    tra_images = all_image_list[0:n_train]
    tra_labels = all_label_list[0:n_train]
    tra_labels = [int(float(i)) for i in tra_labels]

    # 測驗圖片和標簽
    val_images = all_image_list[n_train:-1]
    val_labels = all_label_list[n_train:-1]
    val_labels = [int(float(i)) for i in val_labels]
 
    return tra_images, tra_labels, val_images, val_labels
    
    

# --------------------生成Batch----------------------------------------------
 
# step1:將上面生成的List傳入get_batch() ,轉換型別,產生一個輸入佇列queue,因為img和lab
# 是分開的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()從佇列中讀取影像
#   image_W, image_H, :設定好固定的影像高度和寬度
#   設定batch_size:每個batch要放多少張圖片
#   capacity:一個佇列最大多少
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    # 轉換型別
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)
 
    # make an input queue
    # tf.train.slice_input_producer是一個tensor生成器,作用是按照設定,
    # 每次從一個tensor串列中按順序或者隨機抽取出一個tensor放入檔案名佇列,
    input_queue = tf.train.slice_input_producer([image, label])
 
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0]) # read img from a queue  
    
    # step2:將影像解碼,不同型別的影像不能混在一起,要么只用jpeg,要么只用png等,
    image = tf.image.decode_jpeg(image_contents, channels=3) 
    
    # step3:資料預處理,對影像進行旋轉、縮放、裁剪、歸一化等操作,讓計算出的模型更健壯,
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    image = tf.image.per_image_standardization(image)
 
    # step4:生成batch
    # image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32 
    # label_batch: 1D tensor [batch_size], dtype=tf.int32
    image_batch, label_batch = tf.train.batch([image, label],
                                                batch_size= batch_size,
                                                num_threads= 32, 
                                                capacity = capacity)
    # 重新排列label,行數為[batch_size]
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch            

(3)forward.py

 # 建立神經網路
import tensorflow as tf

# 網路結構定義
    # 輸入引數:images,image batch、4D tensor、tf.float32、[batch_size, width, height, channels]
    # 回傳引數:logits, float、 [batch_size, n_classes]
def inference(images, batch_size, n_classes):
    # 構建一個簡單的卷積神經網路,其中(卷積+池化層)x2,全連接層x2,最后一個softmax層做分類,

    # 卷積層1
    # 64個3x3的卷積核(3通道),padding=’SAME’,表示padding后卷積的圖與原圖尺寸一致,激活函式relu()
    # tf.variable_scope 可以讓變數有相同的命名,包括tf.get_variable得到的變數,還有tf.Variable變數
    # 它回傳的是一個用于定義創建variable(層)的op的背景關系管理器,
    with tf.variable_scope('conv1') as scope:
        
        # tf.truncated_normal截斷的產生正態分布的亂數,即亂數與均值的差值若大于兩倍的標準差,則重新生成,
        # shape,生成張量的維度 
        # mean,均值 
        # stddev,標準差
        weights = tf.Variable(tf.truncated_normal(shape=[3,3,3,64], stddev = 1.0, dtype = tf.float32), 
                              name = 'weights', dtype = tf.float32)
        
        # 偏置biases計算
        # shape表示生成張量的維度
        # 生成初始值為0.1的偏執biases
        biases = tf.Variable(tf.constant(value = 0.1, dtype = tf.float32, shape = [64]),
                             name = 'biases', dtype = tf.float32)
        
        # 卷積層計算
        # 輸入圖片x和所用卷積核w
        # x是對輸入的描述,是一個4階張量:
        # 比如:[batch,5,5,3]
        # 第一階給出一次喂入多少張圖片也就是batch
        # 第二階給出圖片的行解析度
        # 第三階給出圖片的列解析度
        # 第四階給出輸入的通道數
        # w是對卷積核的描述,也是一個4階張量:
        # 比如:[3,3,3,16]
        # 第一階第二階分別給出卷積行列解析度
        # 第三階是通道數
        # 第四階是有多少個卷積核
        # strides卷積核滑動步長:[1,1,1,1]
        # 第二階第三階表示橫向縱向滑動的步長
        # 第一和第四階固定為1
        # 使用0填充,所以padding值為SAME
        conv = tf.nn.conv2d(images, weights, strides=[1,1,1,1], padding='SAME')

        # 非線性激活
        # 對卷積后的conv1添加偏執,通過relu激活函式
        pre_activation = tf.nn.bias_add(conv, biases)

        conv1 = tf.nn.relu(pre_activation, name= scope.name)
        
    # 池化層1
    # 3x3最大池化,步長strides為2,池化后執行lrn()操作,區域回應歸一化,對訓練有利,
    # 最大池化層計算
    # x是對輸入的描述,是一個四階張量:
    # 比如:[batch,28,28,6]
    # 第一階給出一次喂入多少張圖片batch
    # 第二階給出圖片的行解析度
    # 第三階給出圖片的列解析度
    # 第四階輸入通道數
    # 池化核大小2*2的
    # 行列步長都是2
    # 使用SAME的padding
    with tf.variable_scope('pooling1_lrn') as scope:

        pool1 = tf.nn.max_pool(conv1, ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME', name='pooling1')

        # 區域回應歸一化函式tf.nn.lrn
        norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
 
    # 卷積層2
    # 16個3x3的卷積核(16通道),padding=’SAME’,表示padding后卷積的圖與原圖尺寸一致,激活函式relu()
    with tf.variable_scope('conv2') as scope:

        weights = tf.Variable(tf.truncated_normal(shape=[3,3,64,16], stddev = 0.1, dtype = tf.float32), 
                              name = 'weights', dtype = tf.float32)
        
        biases = tf.Variable(tf.constant(value = 0.1, dtype = tf.float32, shape = [16]),
                             name = 'biases', dtype = tf.float32)
        
        conv = tf.nn.conv2d(norm1, weights, strides = [1,1,1,1],padding='SAME')

        pre_activation = tf.nn.bias_add(conv, biases)

        conv2 = tf.nn.relu(pre_activation, name='conv2')
 
    # 池化層2
    # 3x3最大池化,步長strides為2,池化后執行lrn()操作,
    # pool2 and norm2
    with tf.variable_scope('pooling2_lrn') as scope:

        norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001/9.0,beta=0.75,name='norm2')

        pool2 = tf.nn.max_pool(norm2, ksize=[1,3,3,1], strides=[1,1,1,1],padding='SAME',name='pooling2')
 
    # 全連接層3
    # 128個神經元,將之前pool層的輸出reshape成一行,激活函式relu()
    with tf.variable_scope('local3') as scope:
        
        # 函式的作用是將tensor變換為引數shape的形式, 其中shape為一個串列形式,特殊的一點是串列中可以存在-1,
        # -1代表的含義是不用我們自己指定這一維的大小,函式會自動計算,但串列中只能存在一個-1,
        reshape = tf.reshape(pool2, shape=[batch_size, -1])

        # get_shape回傳的是一個元組
        dim = reshape.get_shape()[1].value

        weights = tf.Variable(tf.truncated_normal(shape=[dim,128], stddev = 0.005, dtype = tf.float32),
                             name = 'weights', dtype = tf.float32)
        
        biases = tf.Variable(tf.constant(value = 0.1, dtype = tf.float32, shape = [128]), 
                             name = 'biases', dtype=tf.float32)
        
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
        
    # 全連接層4
    # 128個神經元,激活函式relu() 
    with tf.variable_scope('local4') as scope:

        weights = tf.Variable(tf.truncated_normal(shape=[128,128], stddev = 0.005, dtype = tf.float32),
                              name = 'weights',dtype = tf.float32)
        
        biases = tf.Variable(tf.constant(value = 0.1, dtype = tf.float32, shape = [128]),
                             name = 'biases', dtype = tf.float32)
        
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')
 
    # dropout層        
    #    with tf.variable_scope('dropout') as scope:
    #        drop_out = tf.nn.dropout(local4, 0.8)
            
        
    # Softmax回歸層
    # 將前面的FC層輸出,做一個線性回歸,計算出每一類的得分,在這里是2類,所以這個層輸出的是兩個得分,
    with tf.variable_scope('softmax_linear') as scope:

        weights = tf.Variable(tf.truncated_normal(shape=[128, n_classes], stddev = 0.005, dtype = tf.float32),
                              name = 'softmax_linear', dtype = tf.float32)
        
        biases = tf.Variable(tf.constant(value = 0.1, dtype = tf.float32, shape = [n_classes]),
                             name = 'biases', dtype = tf.float32)
        
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')
 
    return softmax_linear
 

# loss計算
    # 傳入引數:logits,網路計算輸出值,labels,真實值,在這里是0或者1
    # 回傳引數:loss,損失值
def losses(logits, labels):

    with tf.variable_scope('loss') as scope:
        
        # 傳入的logits為神經網路輸出層的輸出,shape為[batch_size,num_classes],
        # 傳入的label為一個一維的vector,長度等于batch_size,
        # 每一個值的取值區間必須是[0,num_classes),其實每一個值就是代表了batch中對應樣本的類別
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels, name='xentropy_per_example')

        # tf.reduce_mean 函式用于計算張量tensor沿著指定的數軸(tensor的某一維度)上的的平均值,
        # 主要用作降維或者計算tensor(影像)的平均值,
        loss = tf.reduce_mean(cross_entropy, name='loss')

        # tf.summary.scalar用來顯示標量資訊
        # 一般在畫loss,accuary時會用到這個函式,
        tf.summary.scalar(scope.name+'/loss', loss)

    return loss
 

# loss損失值優化
    # 輸入引數:loss,learning_rate,學習速率,
    # 回傳引數:train_op,訓練op,這個引數要輸入sess.run中讓模型去訓練,
def trainning(loss, learning_rate):

    with tf.name_scope('optimizer'):
        
        # tf.train.AdamOptimizer()函式是Adam優化演算法:是一個尋找全域最優點的優化演算法,引入了二次方梯度校正,
        # learning_rate:張量或浮點值,學習速率
        optimizer = tf.train.AdamOptimizer(learning_rate= learning_rate)

        global_step = tf.Variable(0, name='global_step', trainable=False)

        # minimize() 實際上包含了兩個步驟,即 compute_gradients 和 apply_gradients,
        # 前者用于計算梯度,后者用于使用計算得到的梯度來更新對應的variable
        train_op = optimizer.minimize(loss, global_step= global_step)

    return train_op
 

# 評價/準確率計算
    # 輸入引數:logits,網路計算值,labels,標簽,也就是真實值,在這里是0或者1,
    # 回傳引數:accuracy,當前step的平均準確率,也就是在這些batch中多少張圖片被正確分類了,
def evaluation(logits, labels):

    with tf.variable_scope('accuracy') as scope:

        # tf.nn.in_top_k用于計算預測的結果和實際結果的是否相等,并回傳一個bool型別的張量
        correct = tf.nn.in_top_k(logits, labels, 1)

        # tf.cast()函式的作用是執行 tensorflow 中張量資料型別轉換
        correct = tf.cast(correct, tf.float16)

        # tf.reduce_mean計算張量的各個維度的元素的量
        accuracy = tf.reduce_mean(correct)

        # tf.summary.scalar用來顯示標量資訊
        tf.summary.scalar(scope.name+'/accuracy', accuracy)

    return accuracy
 

(4)backward.py(訓練模型)

# 匯入檔案
import os
import numpy as np
import tensorflow as tf
import batchdealing
import forward
 
# 變數宣告
N_CLASSES = 4  # 4類 分別是:'cabbage','carrot','nori','potato'
IMG_W = 64   # resize影像,太大的話訓練時間久
IMG_H = 64
BATCH_SIZE =20 # 一次喂入多少
CAPACITY = 200  # 容量
MAX_STEP = 200 # 一般大于10K
learning_rate = 0.0001 # 一般小于0.0001
 
# 獲取批次batch  E:/python-run-env/train-test/Re_train/image_data/inputdata/
train_dir = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'   # 訓練樣本的讀入路徑
logs_train_dir = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'    # logs存盤路徑
# logs_test_dir =  'E:/Re_train/image_data/test'        # logs存盤路徑
 
# train, train_label = batchdealing.get_files(train_dir)
train, train_label, val, val_label = batchdealing.get_files(train_dir, 0.3)
# 訓練資料及標簽
train_batch,train_label_batch = batchdealing.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# 測驗資料及標簽
val_batch, val_label_batch = batchdealing.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY) 
 
# 訓練操作定義
train_logits = forward.inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = forward.losses(train_logits, train_label_batch)        
train_op = forward.trainning(train_loss, learning_rate)
train_acc = forward.evaluation(train_logits, train_label_batch)
 
# 測驗操作定義
test_logits = forward.inference(val_batch, BATCH_SIZE, N_CLASSES)
test_loss = forward.losses(test_logits, val_label_batch)        
test_acc = forward.evaluation(test_logits, val_label_batch)
 
# 這個是log匯總記錄
summary_op = tf.summary.merge_all() 
 
# 產生一個會話
sess = tf.Session()  

# 產生一個writer來寫log檔案
train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph) 

# val_writer = tf.summary.FileWriter(logs_test_dir, sess.graph) 
# 產生一個saver來存盤訓練好的模型
saver = tf.train.Saver()

# 所有節點初始化
sess.run(tf.global_variables_initializer())  

# 佇列監控
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 
# 進行batch的訓練
try:
    # 執行MAX_STEP步的訓練,一步一個batch
    for step in np.arange(MAX_STEP):
        if coord.should_stop():
            break
        # 啟動以下操作節點
        _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])
        
        # 每隔50步列印一次當前的loss以及acc,同時記錄log,寫入writer   
        if step % 10  == 0:
            print('Step %d, train loss = %.2f, train accuracy = %.2f%%' %(step, tra_loss, tra_acc*100.0))
            summary_str = sess.run(summary_op)
            train_writer.add_summary(summary_str, step)
        # 每隔100步,保存一次訓練好的模型
        if (step + 1) == MAX_STEP:
            checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)
       
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
 
finally:
    coord.request_stop()

訓練結束后,在終端顯示:
在這里插入圖片描述
(5)test.py

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import forward
from batchdealing import get_files
 
# 獲取一張圖片
def get_one_image(train):
    # 輸入引數:train,訓練圖片的路徑
    # 回傳引數:image,從訓練圖片中隨機抽取一張圖片
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]   # 隨機選擇測驗的圖片
 
    img = Image.open(img_dir)

    # 顯示圖片,在jupyter notebook下當然也可以不用plt.show()
    plt.imshow(img)
    plt.show(img)
    imag = img.resize([64, 64])  # 由于圖片在預處理階段以及resize,因此該命令可略
    image = np.array(imag)
    return image
 
# 測驗圖片
def evaluate_one_image(image_array):
    with tf.Graph().as_default():
       BATCH_SIZE = 1
       N_CLASSES = 4
 
       image = tf.cast(image_array, tf.float32)

       # 線性縮放影像以具有零均值和單位范數,
       image = tf.image.per_image_standardization(image)
       image = tf.reshape(image, [1, 64, 64, 3])
 
       # 構建卷積神經網路
       logit = forward.inference(image, BATCH_SIZE, N_CLASSES)
 
       # softmax函式的作用就是歸一化
       # 輸入: 全連接層(往往是模型的最后一層)的值,一般代碼中叫做logits,
       # 輸出: 歸一化的值,含義是屬于該位置的概率,一般代碼叫做probs,
       logit = tf.nn.softmax(logit)
 
       x = tf.placeholder(tf.float32, shape=[64, 64, 3])
 
       # you need to change the directories to yours. E:/python-run-env/train-test/Re_train/image_data/inputdata/
       logs_train_dir = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'
 
       # tf.train.Saver() 保存和加載模型
       saver = tf.train.Saver()
 
       with tf.Session() as sess:
 
           print("Reading checkpoints...")
           ckpt = tf.train.get_checkpoint_state(logs_train_dir)
           if ckpt and ckpt.model_checkpoint_path:
               global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
               saver.restore(sess, ckpt.model_checkpoint_path)
               print('Loading success, global_step is %s' % global_step)
           else:
               print('No checkpoint file found')

           # feed_dict的作用是給使用placeholder創建出來的tensor賦值,
           # 其實,他的作用更加廣泛:feed使用一個值臨時替換一個op的輸出結果,
           # 你可以提供feed資料作為run()呼叫的引數,
           # feed只在呼叫它的方法內有效,方法結束,feed就會消失,
           # 當我們構建完圖后,需要在一個會話中啟動圖,啟動的第一步是創建一個Session物件,
           # 為了取回(Fetch)操作的輸出內容,可以在使用Session物件的run()呼叫執行圖時,
           # 傳入一些tensor,這些tensor會幫助你取回結果,
           prediction = sess.run(logit, feed_dict={x: image_array})

           #  取出prediction中元素最大值所對應的索引,也就是最大的可能
           max_index = np.argmax(prediction)

           if max_index==0:
               print('This is a cabbage with possibility %.6f' %prediction[:, 0])
           elif max_index==1:
               print('This is a carrot with possibility %.6f' %prediction[:, 1])
           elif max_index==2:
               print('This is a nori with possibility %.6f' %prediction[:, 2])
           else:
               print('This is a potato with possibility %.6f' %prediction[:, 3])
 

               
if __name__ == '__main__':
    
    train_dir = 'E:/python-run-env/train-test/Re_train/image_data/inputdata/'
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    img = get_one_image(val)  # 通過改變引數train or val,進而驗證訓練集或測驗集
    evaluate_one_image(img)

點擊測驗之后,就可以顯示如三效果展示那樣的效果了,每一次都是隨機選取我們自己的圖片來測驗的,

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/293872.html

標籤:AI

上一篇:深度學習100例-卷積神經網路(LeNet-5)深度學習里的“Hello Word” | 第22天

下一篇:什么是CPU,GPU,TPU,DPU,NPU,BPU

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more