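I want to retrieve the first N items from a BatchDataSet. I have tried many different ways of doing this, and every time the result is re-evaluated it retrieves different items. I want N actual items, not an iterator that keeps retrieving new ones.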
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
ds = tf.keras.utils.image_dataset_from_directory(
    "Images",
    validation_split=0.2,
    seed=123,
    subset="training")
# Class names are inferred from the sub-directory names
class_names = ds.class_names
# Attempt to retrieve 9 items
test_ds = ds.take(9)
# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is
# that it should show the same images as above.
#
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
Answer:
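image_dataset_from_directory shuffles the data by default (shuffle=True), and every new iteration over the resulting tf.data.Dataset triggers a fresh shuffle, so each pass over ds.take(9) can return different images. Set shuffle to False to get deterministic results. The example below also uses batch_size=1, so that ds.take(9) yields nine single-image batches: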
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(64, 64),
    batch_size=1,
    shuffle=False)
# Attempt to retrieve 9 items
test_ds = ds.take(9)
# The five flower class names, inferred from the sub-directory names
class_names = ds.class_names
# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for i, (images, labels) in enumerate(test_ds):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[0, ...].numpy().astype("uint8"))
    plt.title(class_names[labels.numpy()[0]])
    plt.axis("off")
# Plot the same 9 items again: with shuffle=False they are identical
plt.figure(figsize=(4, 4))
for i, (images, labels) in enumerate(test_ds):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[0, ...].numpy().astype("uint8"))
    plt.title(class_names[labels.numpy()[0]])
    plt.axis("off")
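For reference, the reshuffling comes from tf.data.Dataset.shuffle, whose reshuffle_each_iteration argument defaults to True. image_dataset_from_directory does not expose that argument, so the minimal sketch below illustrates the behaviour on a synthetic dataset instead (the integers 0..99 are an assumption standing in for the image files). With the default, each pass over take(9) may return different elements; with reshuffle_each_iteration=False, the same shuffled order is reused on every pass.

```python
import tensorflow as tf

# Synthetic stand-in for the image dataset: the integers 0..99.
base = tf.data.Dataset.range(100)

# Default (reshuffle_each_iteration=True): each pass over the dataset
# draws a new random order, so take(9) can return different elements.
reshuffled = base.shuffle(buffer_size=100, seed=123).take(9)

# reshuffle_each_iteration=False: the shuffled order is fixed for every pass.
fixed = base.shuffle(
    buffer_size=100, seed=123, reshuffle_each_iteration=False
).take(9)

print("reshuffled, pass 1:", [int(x) for x in reshuffled])
print("reshuffled, pass 2:", [int(x) for x in reshuffled])

first_pass = [int(x) for x in fixed]
second_pass = [int(x) for x in fixed]
print("fixed passes equal:", first_pass == second_pass)
```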

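If you are interested in other samples from the dataset, you can use the tf.data.Dataset.skip and tf.data.Dataset.take methods, for example ds.skip(9).take(9) for the next nine batches.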
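Because the dataset returned by image_dataset_from_directory is batched (batch_size defaults to 32), skip and take count batches rather than images. The minimal sketch below uses a small synthetic batched dataset as a hypothetical stand-in for the image data, so it runs without the flower photos. It calls unbatch() so that skip and take count individual items, then copies the selected items into NumPy arrays once. Since the arrays are plain in-memory values, reusing them cannot trigger another shuffle, which gives N actual items rather than an iterator.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the batched image dataset (hypothetical data):
# 20 "images" (just the integers 0..19) with labels 0..4, in batches of 4.
features = np.arange(20)
labels = features % 5
batched = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

# On a batched dataset, take(2) returns 2 batches, i.e. 8 items.
n_batches = len(list(batched.take(2)))

# unbatch() first, so that skip/take count individual items:
# skip the first 9 items, then keep the next 9 (items 9..17).
selected = batched.unbatch().skip(9).take(9)

# Materialize once into NumPy arrays; later code reuses these values
# instead of re-running (and possibly reshuffling) the input pipeline.
pairs = list(selected.as_numpy_iterator())
xs = np.array([x for x, _ in pairs])
ys = np.array([y for _, y in pairs])

print("batches taken:", n_batches)
print("images:", xs)
print("labels:", ys)
```

Applied to the real dataset, the same pattern would be ds.unbatch().take(9) followed by a single conversion to arrays. With shuffle=True, that first conversion picks a random nine images, but after that the arrays stay fixed.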
