我正在使用 MNIST 和 tensorflow 訓練自動編碼器。
(ds_train_original, ds_test_original), ds_info = tfds.load(
"mnist",
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
batch_size = 2014
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255.0, label
我希望我x是影像,我y是具有與唯一索引值(整數/浮點數)相關聯的相同影像的元組。原因是我想將該 id 傳遞給我的損失函式。我不想手動迭代并創建一個新的資料集,但如果這是唯一的解決方案,那么我將采用它。
我嘗試了多種方法,例如將 map 方法與全域變數一起使用:
lab = -1
def add_label(x, _):
global lab
lab = 1
return (x, (x, [lab]))
ds_train_original = ds_train_original.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train_original.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
# replace labels by image itself and unique id for decoder/encoder
ds_train = ds_train.map(add_label)
但是,這將回傳 0 作為所有輸入的索引而不是唯一值。
我還嘗試通過列舉資料集來手動添加標簽,但它永遠都是這樣。
當應用于它的函式在資料集上不一致時,是否有一種有效的方法來修改 TensorFlow 資料集。
uj5u.com熱心網友回復:
所以在這種情況下我會做的是只使用ref()目標張量的方法。每個張量已經有一個唯一的識別符號,這個方法允許你訪問它。
你可以試試:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
(ds_train_original, ds_test_original), ds_info = tfds.load(
"mnist",
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
# save the refferences to your tensors
ids = np.array([y.ref() for _, y in ds_train_original])
# you can check that they are all unique
print(ids.shape, np.unique(ids).shape)
# find the 42th tensor using the deref()
t = ids[42].deref()
print(t)
# use np.where to find the index of a tensor refference
np.where( ids == t.ref())[0]
轉載請註明出處,本文鏈接:https://www.uj5u.com/qianduan/318301.html
上一篇:如何將sample_weights與3D醫療資料一起使用,而不會有model.fit(x=tf.data.Dataset)導致無法擠壓最后一個暗淡之類的錯誤
