tf.data模塊包含:
- experimental 模塊
- Dataset 類
- FixedLengthRecordDataset 類
- TFRecordDataset 類
- TextLineDataset 類

1 # author by FH. 2 # OverView: 3 # tf.data 4 # experimental ---Modules 5 # Dataset ---class 6 # FixedLengthRecordDataset ---class 7 # TFRecordDataset ---class 8 # TextLineDataset ---class 9 import tensorflow as tf10 import numpy as np11 12 13 # 1. 使用靜態方法 tf.data.Dataset.from_tensor_slices14 # 將輸入的第一個維度切割,形成dataset15 # 2. 使用 Dataset的 make_one_shot_iterator() 實體化一個 iterator16 # 這個iterator 只能從頭到尾讀取一次,“one shot iterator”17 def test1():18 sess = tf.Session()19 dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))20 dataset2 = tf.data.Dataset.from_tensor_slices(np.array([[1,2],[3,4],[0,9]]))21 dataset3 = tf.data.Dataset.from_tensor_slices(22 {23 "a":np.array([1.0,2,3,4,5.0]),24 "b":np.random.uniform(size=(5,2))25 }26 )27 # 使用 Dataset的 make_one_shot_iterator() 實體化一個 iterator28 # 這個iterator 只能從頭到尾讀取一次,“one shot iterator”29 oneShotIterator1 = dataset1.make_one_shot_iterator()30 oneShotIterator2 = dataset2.make_one_shot_iterator()31 oneShotIterator3 = dataset3.make_one_shot_iterator()32 element1 = oneShotIterator1.get_next()33 element2 = oneShotIterator2.get_next()34 element3 = oneShotIterator3.get_next()35 for i in range(5):36 print(sess.run(element1))37 for i in range(3):38 print(sess.run(element2))39 for i in range(5):40 print(sess.run(element3))41 sess.close()42 43 # 1.Dataset 中的資料元素轉換,44 # map() :引數為一個函式,將dataset中的每個元素帶入獲取新的值45 # batch(): 引數為一個整數,將多個元素組合成一個batch46 def test2():47 sess = tf.Session()48 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0,6]))49 # map() 重新映射新的元素值50 dataset1 = dataset.map(lambda x: x * 3)51 # batch() 2個組成一個batch, 組成batch 之后size 為352 dataset2 = dataset.batch(2)53 # shuffle() 打亂dataset54 dataset3 = dataset.shuffle(buffer_size=3)55 # repeat() 將整個序列重復多次,重復4次 size 為2456 dataset4 = dataset.repeat(4)57 58 oneShotIterator1 = dataset1.make_one_shot_iterator()59 oneShotIterator2 = dataset2.make_one_shot_iterator()60 oneShotIterator3 = dataset3.make_one_shot_iterator()61 oneShotIterator4 = dataset4.make_one_shot_iterator()62 element1 = oneShotIterator1.get_next()63 element2 = oneShotIterator2.get_next()64 element3 = oneShotIterator3.get_next()65 element4 = oneShotIterator4.get_next()66 for i in range(6): # map()67 print(sess.run(element1))68 for i in range(3): # batch()69 print(sess.run(element2))70 for i in range(6): # shuffle()71 print(sess.run(element3))72 for i in range(24): # repeat()73 print(sess.run(element4))74 sess.close()75 76 # example1: 讀取圖片和相應的標簽并打亂,組成77 # batch_size=2 的資料集,重復10 epoch78 def _parse_function(imgfilename,label):79 image_value =https://www.cnblogs.com/feihu-h/p/ tf.read_file(imgfilename)80 img = tf.image.decode_image(image_value)81 img = tf.image.resize_images(img,[256,256])82 return img,label83 def example1():84 # 圖片串列85 filesnames = tf.constant(['name1.jpg','name3.jpg','name5.jpg','name6.jpg','name7.jpg','name8.jpg'])86 # 對應標簽87 labels = tf.constant([0,1,0,1,1,0])88 # dataset (名稱,標簽)89 dataset = tf.data.Dataset.from_tensor_slices((filesnames,labels))90 # map 映射成圖片和標簽91 dataset = dataset.map(_parse_function)92 # shuffle ,batch , repeat93 dataset = dataset.shuffle(buffersize=3).batch(2).repeat(10)94 return dataset95 96 if __name__ == '__main__':97 test2()View Code
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/72030.html
標籤:其他

