我正在將 CSV 檔案轉換為 TFRecords 檔案,如下所示:
檔案: ./dataset/csv/ file.csv
feature_1, feture_2, output
1, 1, 1
2, 2, 2
3, 3, 3
import tensorflow as tf
import csv
import os
print(tf.__version__)
def create_csv_iterator(csv_file_path, skip_header):
with tf.io.gfile.GFile(csv_file_path) as csv_file:
reader = csv.reader(csv_file)
if skip_header: # Skip the header
next(reader)
for row in reader:
yield row
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def create_example(row):
"""
Returns a tensorflow.Example Protocol Buffer object.
"""
features = {}
for feature_index, feature_name in enumerate(["feature_1", "feture_2", "output"]):
feature_value = row[feature_index]
features[feature_name] = _int64_feature(int(feature_value))
return tf.train.Example(features=tf.train.Features(feature=features))
def create_tfrecords_file(input_csv_file):
"""
Creates a TFRecords file for the given input data
"""
output_tfrecord_file = input_csv_file.replace("csv", "tfrecords")
writer = tf.io.TFRecordWriter(output_tfrecord_file)
print("Creating TFRecords file at", output_tfrecord_file, "...")
for i, row in enumerate(create_csv_iterator(input_csv_file, skip_header=True)):
if len(row) == 0:
continue
example = create_example(row)
content = example.SerializeToString()
writer.write(content)
writer.close()
print("Finish Writing", output_tfrecord_file)
create_tfrecords_file("./dataset/csv/file.csv")
然后我將使用ImportExampleGen類讀取生成的 TFRecords 檔案:
import os
import absl
import tensorflow_model_analysis as tfma
tf.get_logger().propagate = False
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip
context = InteractiveContext()
example_gen = tfx.components.ImportExampleGen(input_base="./dataset/tfrecords")
context.run(example_gen, enable_cache=True)
statistics_gen = tfx.components.StatisticsGen(
examples=example_gen.outputs['examples'])
context.run(statistics_gen, enable_cache=True)
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False)
context.run(schema_gen, enable_cache=True)
檔案:./transform.py
def preprocessing_fn(inputs):
"""tf.transform's callback function for preprocessing inputs.
Args:
inputs: map from feature keys to raw not-yet-transformed features.
Returns:
Map from string feature key to transformed feature operations.
"""
print(inputs)
return inputs
transform = tfx.components.Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=os.path.abspath("./transform.py"))
context.run(transform, enable_cache=True)
在preprocessing_fn函式中顯示,inputs是一個 SparseTensor 物件。我的問題是為什么?據我所知,我的資料集的樣本很密集,它們應該是 Tensor。難道我做錯了什么?
uj5u.com熱心網友回復:
對于其他可能在同一問題上苦苦掙扎的人,我找到了罪魁禍首。是SchemaGen班級。這就是我實體化它的物件的方式:
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False)
我不知道要求SchemaGen類不要推斷特征的形狀的用例是什么,但是我遵循的教程將其設定為False并且我剛剛復制并粘貼了相同的內容。與其他一些教程相比,我意識到這可能是我獲得SparseTensor.
因此,如果您讓SchemaGen推斷特征的形狀或加載您自己設定形狀的手工制作的模式,您將Tensor在preprocessing_fn. 但如果未設定形狀,則特征將是 的實體SparseTensor。
為了完整起見,這是固定片段:
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=True)
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/515790.html
標籤:张量流tfx
