TensorFlow Client對接模型服務
- 1. Tensorflow Client代碼撰寫對接Web
- 1.1 Client端代碼
- 2. 步驟程序
- 使用TensorFlow Serving Client撰寫並運行對接模型服務的代碼
1. Tensorflow Client代碼撰寫對接Web
- main.py當中呼叫
# Fetch the image the user uploaded; reject the request with HTTP 400 if absent.
image = request.files.get('image')
if not image:
    abort(400)
# Run detection; the result is presumably an in-memory file (BytesIO) holding
# the tagged picture -- confirm against make_prediction's return value.
result_img = make_prediction(image.read())
# The with-block guarantees the buffer is closed even if reading fails.
with result_img:
    data = result_img.read()
1.1 Client端代碼
需要用到tensorflow_serving.apis中的兩個模塊
from tensorflow_serving.apis import prediction_service_pb2_grpc
from tensorflow_serving.apis import predict_pb2
- prediction_service_pb2_grpc:提供gRPC客戶端存根PredictionServiceStub,用來呼叫模型服務的Predict介面
- predict_pb2:提供預測請求的訊息類型PredictRequest
- 在prediction.py檔案中定義make_prediction函式,實現預測邏輯
- 步驟分析
- 1、獲取後臺讀取到的圖片資料
- 2、圖片大小處理,轉換陣列
- 3、打開通道channel,構建stub,預測結果
- 4、predict_pb2進行預測請求創建
- 步驟分析
2. 步驟程序
- 1、獲取後臺讀取到的圖片,調整圖片大小,並轉換為陣列
def make_prediction(image):
    """Step 1: decode the uploaded bytes and build the model's input batch.

    (Tutorial excerpt -- the remaining steps of this function follow in the
    next snippets.)

    :param image: raw encoded image bytes read from the upload.
    """
    def resize_img(raw, target_size):
        # Decode the bytes with PIL and force three RGB channels.
        picture = Image.open(io.BytesIO(raw)).convert("RGB")
        if not target_size:
            return picture
        # target_size is (height, width); PIL's resize takes (width, height).
        return picture.resize((target_size[1], target_size[0]))

    resized = resize_img(image, (300, 300))
    image_array = img_to_array(resized)
    # Wrap the single image into a batch of one before preprocessing.
    img_tensor = preprocess_input(np.array([image_array]))
    print(img_tensor.shape)
- 2、打開通道channel,構建stub,用predict_pb2創建預測請求,並取得預測結果
# Build the PredictRequest for the 'commodity' model's default serving signature.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'commodity'
request.model_spec.signature_name = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_proto = tf.contrib.util.make_tensor_proto(img_tensor, shape=[1, 300, 300, 3])
request.inputs['images'].CopyFrom(input_proto)

# Open a plaintext gRPC channel to TensorFlow Serving and run the prediction.
with grpc.insecure_channel('127.0.0.1:8500') as channel:
    result = prediction_service_pb2_grpc.PredictionServiceStub(channel).Predict(request)
- 模型的輸出簽名如下,預測結果即從 result.outputs['concat_3:0'] 中取得:
{'concat_3:0': <tf.Tensor 'concat_3:0' shape=(?, 7308, 21) dtype=float32>}
- 3、過濾並解析預測結果,在圖片上進行標記
with tf.Session() as sess:
_res = sess.run(tf.convert_to_tensor(result.outputs['concat_3:0']))
# 3、測驗階段 進行NMS 過濾
butil = BBoxUtility(9)
outputs = butil.detection_out(_res)
return tag_picture(image_array, outputs)
- tag_picture的邏輯
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
classes_name = ['clothes', 'pants', 'shoes', 'watch', 'phone',
                'audio', 'computer', 'books']


def tag_picture(img, outputs, conf_threshold=0.3):
    """Draw predicted boxes, class names and scores onto an image.

    :param img: image array indexed as ``img[row, col, ...]`` (height first).
        It is divided by 255 for display, so pixel values are presumably in
        [0, 255] -- TODO confirm with the caller.
    :param outputs: per-image detection results; only ``outputs[0]`` is drawn.
        Each row is read as ``[label, confidence, xmin, ymin, xmax, ymax]``,
        with box corners given as fractions of the image width/height.
    :param conf_threshold: detections with confidence below this value are
        not drawn (defaults to 0.3).
    :return: a ``BytesIO`` holding the rendered figure as PNG, rewound to
        position 0 so it can be read immediately.
    """
    detections = np.asarray(outputs[0])
    if detections.size == 0:
        # Nothing to draw (e.g. an empty list -- presumably what the NMS step
        # returns when no box survives; confirm). Keep a 6-column layout so
        # the column indexing below still works.
        detections = np.empty((0, 6))
    top = detections[detections[:, 1] >= conf_threshold]

    height, width = img.shape[:2]
    colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()

    # Draw on a figure owned by this call and always close it. Reusing
    # pyplot's implicit current figure would keep every earlier call's boxes
    # and labels on the canvas, and figures that are never closed leak memory
    # in a long-running server.
    fig, ax = plt.subplots()
    try:
        ax.imshow(img / 255.)
        for row in top:
            label = int(row[0])
            score = row[1]
            xmin = int(round(row[2] * width))
            ymin = int(round(row[3] * height))
            xmax = int(round(row[4] * width))
            ymax = int(round(row[5] * height))
            # Labels appear to be 1-based (0 presumably reserved for the
            # background class); a label of 0 would wrap to classes_name[-1].
            label_name = classes_name[label - 1]
            display_txt = '{:0.2f}, {}'.format(score, label_name)
            color = colors[label]
            # Bounding box outline.
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin + 1, ymax - ymin + 1,
                                       fill=False, edgecolor=color, linewidth=2))
            # Probability and class name in the box's top-left corner.
            ax.text(xmin, ymin, display_txt, bbox={'facecolor': color, 'alpha': 0.5})
        image_io = BytesIO()
        fig.savefig(image_io, format='png')
    finally:
        plt.close(fig)
    image_io.seek(0)
    return image_io
完整代碼:
import tensorflow as tf
import grpc
from tensorflow_serving.apis import prediction_service_pb2_grpc
from tensorflow_serving.apis import predict_pb2
from tensorflow.python.saved_model import signature_constants
from keras.preprocessing.image import img_to_array
from keras.applications.imagenet_utils import preprocess_input
from utils.ssd_utils import BBoxUtility
from utils.tag_img import tag_picture
import io
from PIL import Image
import numpy as np
def make_prediction(image, *, server='127.0.0.1:8500', model_name='commodity',
                    timeout=None):
    """Detect commodities in an image via TensorFlow Serving and tag them.

    The raw image bytes are decoded, resized to 300x300 and preprocessed,
    sent to the served model over gRPC, decoded and NMS-filtered with
    ``BBoxUtility``, and finally drawn onto the picture by ``tag_picture``.

    :param image: raw encoded image bytes (e.g. an uploaded file's content).
    :param server: ``host:port`` of the TensorFlow Serving gRPC endpoint.
    :param model_name: name of the served model to query.
    :param timeout: optional deadline in seconds for the Predict RPC;
        ``None`` (the default) waits without a deadline.
    :return: whatever ``tag_picture`` returns for the annotated image.
    """
    def resize_img(image, target_size):
        img = io.BytesIO()
        img.write(image)
        img = Image.open(img).convert("RGB")
        if target_size:
            # target_size is (height, width); PIL's resize expects (width, height).
            img = img.resize((target_size[1], target_size[0]))
        return img

    image = resize_img(image, (300, 300))
    image_array = img_to_array(image)
    # np.array([...]) builds a batch of one as a *copy*, so preprocess_input
    # cannot alter image_array even if it works in place; the untouched
    # pixels are passed to tag_picture below.
    img_tensor = preprocess_input(np.array([image_array]))

    # Plaintext gRPC channel to TensorFlow Serving, closed when the block exits.
    with grpc.insecure_channel(server) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_name
        request.model_spec.signature_name = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        request.inputs['images'].CopyFrom(
            tf.contrib.util.make_tensor_proto(img_tensor, shape=[1, 300, 300, 3]))
        result = stub.Predict(request, timeout=timeout)

    # Decode the returned TensorProto directly into a NumPy array. This builds
    # no graph op and no Session (tf.convert_to_tensor here would add a new
    # constant to the default graph on every request).
    _res = tf.contrib.util.make_ndarray(result.outputs['concat_3:0'])
    # Test stage: decode the raw predictions and apply NMS filtering.
    # 9 presumably = 8 commodity classes + background -- confirm.
    butil = BBoxUtility(9)
    outputs = butil.detection_out(_res)
    return tag_picture(image_array, outputs)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/294385.html
標籤:AI
