對資料集的shuffle處理需要設定相應的buffer_size引數,相當于需要將相應數目的樣本讀入記憶體,且這部分記憶體會在訓練程序中一直保持占用,完全的shuffle需要將整個資料集讀入記憶體,這在大規模資料集的情況下是不現實的,故需要結合設備記憶體以及Batch大小將TFRecord檔案隨機劃分為多個子檔案,再對資料集做local shuffle(即設定相對較小的buffer_size,不小于單個子檔案的樣本數),
Shuffle和劃分
下文以一個例外檢測資料集(正負樣本不平衡)為例,在生成第一批TFRecord時,我將正負樣本分別寫入單獨的TFrecord檔案以備后續在對正負樣本有不同處理策略的情況下無需再決議example_proto,比如在以下代碼中,我對正負樣本有不同的驗證集比例,并將他們寫入不同的驗證集檔案,
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm as tqdm
# TFRecord劃分
raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords","GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords","GZIP")
normal_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
train_writer_list = [tf.io.TFRecordWriter(r'ex_1/'+'train_16_256_{}.tfrecords'.format(i),"GZIP") for i in range(SUBFILE_NUM+1)]
with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar:
for example_proto in raw_normal_dataset:
# 劃分訓練集和測驗集
if np.random.random() > 0.99: # 正樣本測驗集的比例
normal_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
for example_proto in raw_anomaly_dataset:
# 劃分訓練集和測驗集
if np.random.random() > 0.7: # 負樣本測驗集的比例
anomaly_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
normal_val_writer.close()
anomaly_val_writer.close()
for train_writer in train_writer_list:
train_writer.close()
讀取
raw_train_dataset = tf.data.TFRecordDataset([r'ex_1/'+'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM+1)],"GZIP")
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func)
raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/500431.html
標籤:其他
上一篇:關于堆疊遷移的那些事兒
