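I am using a custom generator to feed my data. However, I keep getting an error saying that my input ran out of data and that I should use repeat() when building the dataset. I am using a plain Python generator, so I cannot call repeat() on it. Can someone help me solve this?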
import os
import numpy as np
import cv2

def generator(idir, odir, batch_size, shuffle):
    i_list = os.listdir(idir)
    o_list = os.listdir(odir)
    batch_index = 0
    batch_size = batch_size
    sample_count = len(i_list)
    while True:
        input_image_batch = []
        output_image_batch = []
        for i in range(batch_index * batch_size, (batch_index + 1) * batch_size):
            # iterate for a batch
            j = i % sample_count  # cycle j value over range of available images
            k = j % batch_size  # cycle k value over batch size
            if shuffle == True:  # if shuffle select a random integer between 0 and sample_count-1 to pick as the image=label pair
                m = np.random.randint(low=0, high=sample_count-1, size=None, dtype=int)
            else:
                m = j
            path_to_in_img = os.path.join(idir, i_list[m])
            path_to_out_img = os.path.join(odir, o_list[m])
            print(path_to_in_img, path_to_out_img)
            input_image = cv2.imread(path_to_in_img)
            input_image = cv2.resize(input_image, (3200, 3200))  # create the target image from the input image
            output_image = cv2.imread(path_to_out_img)
            output_image = cv2.resize(output_image, (3200, 3200))
            input_image_batch.append(input_image)
            output_image_batch.append(output_image)
        input_val1image_array = np.array(input_image_batch)
        input_val1image_array = input_val1image_array / 255.0
        print(input_val1image_array)
        output_val2image_array = np.array(output_image_batch)
        output_val2image_array = output_val2image_array / 255.0
        batch_index = batch_index + 1
        yield (input_val1image_array, output_val2image_array)
        if batch_index * batch_size > sample_count:
            break
Calling the function:
idir = r"D:\\image\\"
odir=r"D:\\image1\\"
train = generator(idir,odir,4,True)
model.compile(optimizer="adam", loss='mean_squared_error', metrics=['mean_squared_error'])
model.fit(train,validation_data = (valin_images,valout_images),batch_size= 5,epochs = 20,steps_per_epoch = int(560/batch_size))
Error:
Epoch 1/20
186/186 [==============================] - 475s 3s/step - loss: 1779.7604 - mean_squared_error: 1779.7601 - val_loss: 28278.5488 - val_mean_squared_error: 28278.5488
Epoch 2/20
1/186 [..............................] - ETA: 1:41 - loss: 275.7113 - mean_squared_error: 275.7113WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 3720 batches). You may need to use the repeat() function when building your dataset.
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 187 batches). You may need to use the repeat() function when building your dataset.
186/186 [==============================] - 1s 235us/step - loss: 275.7113 - mean_squared_error: 275.7113
Answer:
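If you are not using repeat (and even if you are, this helps with debugging), the first thing to check is how many elements your generator actually produces. A quick way is something like: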
len([g for g in generator(idir,odir,4,True)])
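Then make sure that your steps per epoch multiplied by the number of epochs does not exceed that number. The generator yields whole batches, and as the warning says, it must be able to produce at least `steps_per_epoch * epochs` batches.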
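As a rough sketch of that check: the warning in the log asks for at least `steps_per_epoch * epochs` batches, so the snippet below counts what a finite generator yields and compares the two numbers. `finite_batches` and `batches_needed` are hypothetical names, `finite_batches` is only a stand-in for the real image generator (it yields lists of integers instead of image arrays), and the sample count, batch size, and epoch count are illustrative values.

```python
import math

def finite_batches(sample_count, batch_size):
    """Hypothetical stand-in for the image generator: one pass of batches, then stop."""
    for start in range(0, sample_count, batch_size):
        yield list(range(start, min(start + batch_size, sample_count)))

def batches_needed(steps_per_epoch, epochs):
    """Total number of batches that training draws from the generator."""
    return steps_per_epoch * epochs

# Illustrative values, not taken from a real run.
sample_count, batch_size, epochs = 560, 4, 20

# Count lazily so the batches are not all held in memory at once.
available = sum(1 for _ in finite_batches(sample_count, batch_size))
steps_per_epoch = math.ceil(sample_count / batch_size)
needed = batches_needed(steps_per_epoch, epochs)

print(f"available={available}, needed={needed}")
if available < needed:
    print("The generator will run out of data before training finishes.")
```

Counting with `sum(1 for _ in ...)` rather than building a list is a deliberate choice here: with 3200x3200 images, keeping every batch in memory just to count them could be expensive.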
You can still use repeat with that generator. You just need to wrap it in a tf.data dataset like this:
gen = lambda : generator(idir,odir,4,True)
dataset = tf.data.Dataset.from_generator(gen, output_types=(<types>), output_shapes=(<shapes>)).repeat()
You have to specify the output types and shapes, but it works fine.
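If wrapping the generator in `tf.data` is not convenient, a similar effect can be sketched in plain Python by starting a fresh generator whenever the previous one is exhausted. This is an alternative to the `repeat()` approach above, not part of it. `repeat_forever` and `one_pass` are hypothetical names, and `one_pass` only stands in for a single pass of the real image generator.

```python
import itertools

def repeat_forever(make_generator):
    """Yield from a fresh generator each time the previous one is exhausted.

    make_generator is a zero-argument callable (for example a lambda) that
    returns a new generator; a generator object itself cannot be restarted.
    """
    while True:
        produced_any = False
        for item in make_generator():
            produced_any = True
            yield item
        if not produced_any:
            # Guard against looping forever on an empty generator.
            raise ValueError("make_generator() produced no items")

def one_pass():
    """Hypothetical stand-in for one pass of the image generator (3 batches)."""
    for batch_index in range(3):
        yield f"batch-{batch_index}"

endless = repeat_forever(one_pass)
first_seven = list(itertools.islice(endless, 7))
print(first_seven)
```

With the generator from the question, this would presumably be used as `repeat_forever(lambda: generator(idir, odir, 4, True))`, with `steps_per_epoch` deciding where each epoch ends. Since that generator already wraps its indices with `% sample_count`, dropping its final `if ... break` should also make it endless.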
