我有一個包含 1000 張紙的 excel 檔案,每張紙都包含一個資料框。為了向我的模型提供這些資料,我嘗試將其轉換為 1000 批張量,這是我的代碼:
df = pd.read_excel('file.xlsx', sheet_name=None)
file_names = list(df.keys())
columns = ['A','B','C']
features = []
labels = []
for n in file_names:
df = pd.read_excel('file.xlsx', sheet_name=n)
features.append(df[columns].to_numpy())
labels.append(df['D'].to_numpy())
Y = tf.convert_to_tensor(np.stack(labels), dtype=tf.float32)
X = tf.convert_to_tensor(np.stack(features), dtype=tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
我的代碼作業正常,但迭代它需要一個多小時。將來我將擁有 1000 多個批次的資料,因此擁有數千個 csv 檔案似乎不是一個好主意。我怎樣才能加快這個程序?
uj5u.com熱心網友回復:
您可以檢索file.xlsx一次,它將所有作業表讀入資料框字典,然后您可以從該字典中獲取作業表:
import tensorflow as tf
import pandas as pd
import numpy as np
from random import sample
### Create data
writer = pd.ExcelWriter('file.xlsx', engine='xlsxwriter')
for i in range(1000):
df = pd.DataFrame({'A': [1, i, 1, 2, 9], 'B': [3, 4, i, 1, 4], 'C': [3, 4, 3, i, 4], 'D': [1, 2, 6, 1, 4], 'E': [0, 1, 1, 0, 1]})
df.to_excel(writer, sheet_name='Sheet' str(i))
writer.save()
df = pd.read_excel('file.xlsx', sheet_name=None)
file_names = list(df.keys())
columns = ['A','B','C']
features = []
labels = []
for n in file_names:
temp_df = df[n]
features.append(temp_df[columns].to_numpy())
labels.append(temp_df['D'].to_numpy())
Y = tf.convert_to_tensor(np.stack(labels), dtype=tf.float32)
X = tf.convert_to_tensor(np.stack(features), dtype=tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
此外,您可以嘗試創建自己的自定義資料生成器并從 Excel 檔案中檢索隨機樣本,這也應該會加快速度:
df = pd.read_excel('file.xlsx', sheet_name=None)
file_names = list(df.keys())
columns = ['A','B','C']
def generator_function(samples = 64):
def generator():
for n in sample(file_names, samples):
temp_df = df[n]
x = temp_df[columns].to_numpy()
y = temp_df['D'].to_numpy()
yield x, y
return generator
gen = generator_function()
dataset = tf.data.Dataset.from_generator(
generator=gen,
output_types=(np.float32, np.int32),
output_shapes=((5, 3), (5))
)
batch_size = 16
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/347954.html
