dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
train_images = dataset['train']
test_images = dataset['test']
train_batches = (
train_images
.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)
現在我想將 test_images 大小減少到 100 張影像。我期待一些代碼,如:
test_images = test_images[100]
但這會報錯:
'ParallelMapDataset' object is not subscriptable
uj5u.com熱心網友回復:
使用take()方法,您可以從目標資料集中獲取批次或專案。
如果資料集是批處理的:
test_images.take((100 // BATCH_SIZE) 1)
當您對資料集進行批處理時,它將包含批次或組。
因此,假設您對大小為 32 的資料進行批處理,test_images.take(1)將回傳 32 個元素,即單個批處理。test_images.take(2)將回傳 64 個元素等。
如果不是批處理:
test_images.take(100)
與批處理資料集不同,資料集將回傳傳入take()方法的元素數量。
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/323903.html
上一篇:只有整數、切片(`:`)、省略號(`...`)、tf.newaxis(`None`)和標量tf.int32/tf.int64張量是有效索引,得到array([219,928,
