我正在執行以下代碼([https://pastebin.com/LK8tKZtN]),得到如下錯誤:
File "C:\Users\Admin\PycharmProjects\BugsClassfications\main2.py", line 45, in set_shapes
    label.set_shape([])
ValueError: Shapes must be equal rank, but are 1 and 0
如何在 image_dataset_from_directory 產生的資料集上正確使用 set_shape 函式?
這是我的代碼:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from albumentations import (Compose, HorizontalFlip,Rotate)
# Sentinel that lets tf.data tune parallelism / prefetch buffer sizes at
# runtime. tf.data.AUTOTUNE is the stable name for the older, deprecated
# tf.data.experimental.AUTOTUNE alias (same value).
AUTOTUNE = tf.data.AUTOTUNE
def process_image(image, label, img_size):
    """Convert, randomly flip and resize one image; pass the label through.

    Args:
        image: image tensor to process.
        label: returned unchanged.
        img_size: side length of the square output image.

    Returns:
        ``(image, label)`` with the image as float32 of size img_size x img_size.

    Note: tf.image.convert_image_dtype only rescales integer inputs (e.g.
    uint8 -> [0, 1]); a float input is merely cast, not normalized. So a
    float32 image in [0, 255] stays in [0, 255] here.
    """
    as_float = tf.image.convert_image_dtype(image, tf.float32)
    flipped = tf.image.random_flip_left_right(as_float)
    resized = tf.image.resize(flipped, [img_size, img_size])
    return resized, label
# Albumentations pipeline used by aug_fn: a random rotation with
# limit=40 (degrees) followed by a horizontal flip, each applied with the
# library's default probability.
transforms = Compose(
    [
        Rotate(limit=40),
        HorizontalFlip(),
    ]
)
def aug_fn(image, img_size):
    """Augment one image with ``transforms``, scale by 1/255 and resize.

    Meant to be called through tf.numpy_function (see process_data), so
    ``image`` arrives as a NumPy array and ``img_size`` as a NumPy scalar.
    Presumably ``image`` is a single HxWxC image (unbatched dataset) —
    TODO confirm against the pipeline feeding it.

    Returns:
        float32 image of size img_size x img_size, pixel values divided by 255.
    """
    augmented = transforms(image=image)["image"]
    scaled = tf.cast(augmented / 255.0, tf.float32)
    return tf.image.resize(scaled, size=[img_size, img_size])
def process_data(image, label, img_size):
    """Run aug_fn on ``image`` via tf.numpy_function; pass ``label`` through.

    tf.numpy_function lets the NumPy-based albumentations code run inside a
    tf.data map, but its output tensor has no static shape, which is why
    set_shapes is mapped over the dataset afterwards.
    """
    augmented = tf.numpy_function(aug_fn, [image, img_size], tf.float32)
    return augmented, label
def set_shapes(img, label, img_shape=(128,128,3)):
    """Attach static shapes to a single (image, label) example.

    The image produced by tf.numpy_function has an unknown static shape;
    this declares ``img`` as ``img_shape`` and ``label`` as a scalar (rank 0).

    Expects unbatched examples. On batched data (labels of shape (None,)),
    the scalar declaration fails with
    "Shapes must be equal rank, but are 1 and 0".

    Returns:
        The same ``(img, label)`` tensors, with their static shapes set.
    """
    for tensor, shape in ((img, img_shape), (label, [])):
        tensor.set_shape(shape)
    return img, label
def view_image(ds):
    """Plot up to 20 images from the first batch of ``ds`` with their labels.

    ``ds`` must yield batched ``(images, labels)`` pairs of eager tensors
    (``.numpy()`` is called on both). Images are shown in a 4x5 grid.
    """
    images, labels = next(iter(ds))  # take a single batch
    images = images.numpy()
    labels = labels.numpy()

    # aug_fn divides pixels by 255, so augmented batches are in [0, 1];
    # casting those straight to uint8 truncates everything to 0 (black).
    # Heuristic: any float batch whose max is <= 1 is treated as [0, 1]-scaled.
    if images.size and np.issubdtype(images.dtype, np.floating) and images.max() <= 1.0:
        images = images * 255.0
    images = np.clip(images, 0, 255).astype(np.uint8)

    fig = plt.figure(figsize=(22, 22))
    # The grid holds at most 20 images; a batch may contain fewer.
    for i in range(min(20, len(images))):
        ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(images[i])
        ax.set_title(f"Label: {labels[i]}")
    plt.show()
train_dir = './dataset/train'
img_size = 128
data = tf.keras.utils.image_dataset_from_directory(train_dir, image_size=(img_size, img_size))
print(data)
# image_dataset_from_directory batches by default (batch_size=32), so each
# element is (images, labels) with labels of shape (None,). The per-image
# augmentation (albumentations inside tf.numpy_function) and set_shapes
# (which declares a scalar label) both expect single examples; feeding them
# batches caused "Shapes must be equal rank, but are 1 and 0". Unbatch first.
data = data.unbatch()
# augmentation: per-image albumentations pipeline
ds_alb = data.map(partial(process_data, img_size=img_size), num_parallel_calls=AUTOTUNE)
# restore the static shapes lost in tf.numpy_function, re-batch, then prefetch
ds_alb = (
    ds_alb.map(partial(set_shapes, img_shape=(img_size, img_size, 3)), num_parallel_calls=AUTOTUNE)
    .batch(32)
    .prefetch(AUTOTUNE)
)
print(ds_alb)
uj5u.com熱心網友回復:
如果您更改標籤的形狀,它應該就能正常運作:
def set_shapes(img, label, img_shape=(128,128,3)):
    """Attach static shapes to a *batched* (images, labels) element.

    With image_dataset_from_directory's default batching, every element is a
    batch: labels have shape (None,), not a scalar. That is why
    ``label.set_shape([])`` failed with "Shapes must be equal rank, but are
    1 and 0". Declaring ``[1,]`` would silence the check but claim each batch
    holds exactly one label, which is false for batches of 32. Leaving the
    batch dimension as None matches what the dataset yields, including a
    smaller final batch.

    NOTE(review): per-image augmentation (albumentations via
    tf.numpy_function) normally expects single images. Prefer unbatching
    the dataset and keeping a scalar label shape instead.
    """
    img.set_shape((None, *img_shape))
    label.set_shape([None])
    return img, label
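但您應該先問自己為什麼要明確設定資料的形狀,可參考這篇文章。補充說明:此錯誤的根本原因是 image_dataset_from_directory 預設以 batch_size=32 將資料分批,因此在 map 中收到的 label 形狀是 (None,)(秩為 1),而 set_shape([]) 宣告的是純量(秩為 0)。若每張影像都要個別做資料增強,較穩健的做法是先對資料集呼叫 .unbatch()(較新版本的 TensorFlow 也可在 image_dataset_from_directory 中傳入 batch_size=None),讓 label 恢復為純量,處理完再重新 .batch(32)。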
轉載請註明出處,本文鏈接:https://www.uj5u.com/caozuo/409759.html
標籤:
上一篇:圖執行模式下tensorflowtf.data資料集的拆分示例
下一篇:我怎樣才能對這個輸出進行排序?
