運行anchor生成的腳本之后報錯,請問怎么修改代碼不會報錯,謝謝
# -*- coding: utf-8 -*-
import numpy as np
import random
import argparse
import os
#引數名稱
parser = argparse.ArgumentParser(description='使用該腳本生成YOLO-V3的anchor boxes\n')
parser.add_argument('--input_annotation_txt_dir',required=True,type=str,help='檔案')
parser.add_argument('--output_anchors_txt',required=True,type=str,help='檔案')
parser.add_argument('--input_num_anchors',required=True,default=6,type=int,help='anchor個數')
parser.add_argument('--input_cfg_width',required=True,type=int,help="寬")
parser.add_argument('--input_cfg_height',required=True,type=int,help="高")
args = parser.parse_args()
'''
centroids 聚類點 尺寸是 numx2,型別是ndarray
annotation_array 其中之一的標注框
'''
def IOU(annotation_array,centroids):
#
similarities = []
#其中一個標注框
w,h = annotation_array
for centroid in centroids:
c_w,c_h = centroid
if c_w >=w and c_h >= h:#第1中情況
similarity = w*h/(c_w*c_h)
elif c_w >= w and c_h <= h:#第2中情況
similarity = w*c_h/(w*h + (c_w - w)*c_h)
elif c_w <= w and c_h >= h:#第3種情況
similarity = c_w*h/(w*h +(c_h - h)*c_w)
else:#第3種情況
similarity = (c_w*c_h)/(w*h)
similarities.append(similarity)
#將串列轉換為ndarray
return np.array(similarities,np.float32) #回傳的是一維陣列,尺寸為(num,)
'''
k_means:k均值聚類
annotations_array 所有的標注框的寬高,N個標注框,尺寸是Nx2,型別是ndarray
centroids 聚類點 尺寸是 numx2,型別是ndarray
'''
def k_means(annotations_array,centroids,eps=0.00005,iterations=2):
#
N = annotations_array.shape[0]#C=2
num = centroids.shape[0]
#損失函式
distance_sum_pre = -1
assignments_pre = -1*np.ones(N,dtype=np.int64)
#
iteration = 0
#回圈處理
while(True):
#
iteration += 1
#
distances = []
#回圈計算每一個標注框與所有的聚類點的距離(IOU)
for i in range(N):
distance = 1 - IOU(annotations_array[i],centroids)
distances.append(distance)
#串列轉換成ndarray
distances_array = np.array(distances,np.float32)#該ndarray的尺寸為 Nxnum
#找出每一個標注框到當前聚類點最近的點
assignments = np.argmin(distances_array,axis=1)#計算每一行的最小值的位置索引
#計算距離的總和,相當于k均值聚類的損失函式
distances_sum = np.sum(distances_array)
#計算新的聚類點
centroid_sums = np.zeros(centroids.shape,np.float32)
for i in range(N):
centroid_sums[assignments[i]] += annotations_array[i]#計算屬于每一聚類類別的和
for j in range(num):
centroids[j] = centroid_sums[j]/(np.sum(assignments==j))
#前后兩次的距離變化
diff = abs(distances_sum-distance_sum_pre)
#列印結果
print("iteration: {},distance: {}, diff: {}, avg_IOU: {}\n".format(iteration,distances_sum,diff,np.sum(1-distances_array)/(N*num)))
#三種情況跳出while回圈:1:回圈20000次,2:eps計算平均的距離很小 3:以上的情況
if (assignments==assignments_pre).all():
print("按照前后兩次的得到的聚類結果是否相同結束回圈\n")
break
if diff < eps:
print("按照eps結束回圈\n")
break
if iteration > iterations:
print("按照迭代次數結束回圈\n")
break
#記錄上一次迭代
distance_sum_pre = distances_sum
assignments_pre = assignments.copy()
if __name__=='__main__':
#聚類點的個數,anchor boxes的個數
num_clusters = args.input_num_anchors
#索引出檔案夾中的每一個標注檔案的名字(.txt)
names = os.listdir(args.input_annotation_txt_dir)
#標注的框的寬和高
annotations_w_h = []
for name in names:
txt_path = os.path.join(args.input_annotation_txt_dir,name)
#讀取txt檔案中的每一行
f = open(txt_path,'r')
for line in f.readlines():
line = line.rstrip('\n')
w,h = line.split(' ')[3:]#這時讀到的w,h是字串型別
#eval()函式用來將字串轉換為數值型
annotations_w_h.append((eval(w),eval(h)))
f.close()
#將串列annotations_w_h轉換為numpy中的array,尺寸是(N,2),N代表多少框
annotations_array = np.array(annotations_w_h,dtype=np.float32)
N = annotations_array.shape[0]
#對于k-means聚類,隨機初始化聚類點
random_indices = [random.randrange(N) for i in range(num_clusters)]#產生亂數
centroids = annotations_array[random_indices]
#k-means聚類
k_means(annotations_array,centroids,0.00005,2)
#對centroids按照寬排序,并寫入檔案
widths = centroids[:,0]
sorted_indices = np.argsort(widths)
anchors = centroids[sorted_indices]
#將anchor寫入檔案并保存
f_anchors = open(args.output_anchors_txt,'w')
#
for anchor in anchors:
f_anchors.write('%d,%d'%(int(anchor[0]*args.input_cfg_width),int(anchor[1]*args.input_cfg_height)))
f_anchors.write('\n')
運行完之后就會報錯
Traceback (most recent call last):
File "anchor.py", line 123, in <module>
f_anchors.write('%d,%d'%(int(anchor[0]*args.input_cfg_width),int(anchor[1]*args.input_cfg_height)))
ValueError: cannot convert float NaN to integer
求指點,該怎么改才不會報錯
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/16602.html
上一篇:資訊安全實踐一之密碼與隱藏技術1【凱撒密碼&仿射密碼】
下一篇:python pandas to_excle在現有sheet中追加資料時,運行結果為什么是在新增sheet中追加的
