7.8 用LSTM預測股票行情
7.8.1 匯入資料
# Tushare 是一個免費、開源的 Python 財經資料介面包,主要實現股票等金融資料從資料採集、清洗加工到資料儲存的整個過程
import tushare as ts
# Open the tushare API connection pair required by ts.bar().
cons = ts.get_apis()
try:
    # Daily bars of the CSI 300 index (code 000300) starting 2010-01-01; an empty
    # end_date presumably means "up to the latest available day" (verify in tushare docs).
    # Resulting columns: code, open, close, high, low, vol, amount, p_change,
    # indexed by trading date.
    df = ts.bar('000300', conn=cons, asset='INDEX', start_date='2010-01-01', end_date='')
finally:
    # Release the connections even if the download raises.
    ts.close_apis(cons)
本介面即將停止更新,請盡快使用Pro版介面:https://waditu.com/document/2
# Discard every row that has at least one missing value, then cache the cleaned
# frame to sh300.csv (readData() loads the data back from that file).
df = df.dropna(how='any')
df.to_csv(path_or_buf='sh300.csv')
# Bare expression: a notebook cell displays the column labels.
df.columns
Index(['code', 'open', 'close', 'high', 'low', 'vol', 'amount', 'p_change'], dtype='object')
7.8.2 資料概覽
# Summary statistics (count, mean, std, min, quartiles, max) of the numeric columns.
df.describe(include=None)
| | open | close | high | low | vol | amount | p_change |
|---|---|---|---|---|---|---|---|
| count | 2751.000000 | 2751.000000 | 2751.000000 | 2751.000000 | 2.751000e+03 | 2.751000e+03 | 2751.000000 |
| mean | 3312.708859 | 3315.500174 | 3341.218680 | 3284.866252 | 1.142116e+06 | 1.474558e+11 | 0.024391 |
| std | 782.131796 | 782.340288 | 788.871807 | 773.029955 | 8.836562e+05 | 1.300980e+11 | 1.454752 |
| min | 2079.870000 | 2086.970000 | 2118.790000 | 2023.170000 | 2.190120e+05 | 2.120044e+10 | -8.750000 |
| 25% | 2611.760000 | 2613.520000 | 2632.355000 | 2591.375000 | 6.063705e+05 | 6.562710e+10 | -0.640000 |
| 50% | 3273.890000 | 3276.670000 | 3304.260000 | 3247.690000 | 8.833630e+05 | 1.065559e+11 | 0.040000 |
| 75% | 3822.735000 | 3827.870000 | 3847.855000 | 3790.325000 | 1.329321e+06 | 1.751813e+11 | 0.720000 |
| max | 5922.070000 | 5807.720000 | 5930.910000 | 5747.660000 | 6.864391e+06 | 9.494980e+11 | 6.710000 |
7.8.3 預處理資料
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
%matplotlib inline
# ---- Hyper-parameters ----
n = 30             # window length: past days per sample (also passed to RNN as input_size)
LR = 0.001         # Adam learning rate
EPOCH = 200        # number of passes over the training set
batch_size = 20    # samples per mini-batch
train_end = -600   # negative: the last 600 rows are held out as the test period
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#將一個序列轉換為 (len(series)-n) 行 × (n+1) 列的矩陣(用於處理時序資料),
#其中最後一列為標籤資料:以當天之前 n 天的資料作為特徵,當天的資料作為 label
def generate_data_by_n_days(series, n, index=False):
    """Turn a 1-D series into a supervised-learning table of sliding windows.

    Row ``k`` holds ``series[k], ..., series[k+n-1]`` in columns ``c0 .. c{n-1}``
    and the following value ``series[k+n]`` in column ``y``. The result has
    ``len(series) - n`` rows and ``n + 1`` columns.

    Args:
        series: a pandas Series (anything with ``tolist()``; ``.index`` is also
            needed when ``index`` is true).
        n: window length.
        index: if true, label each row with the index of its target value
            (``series.index[n:]``); otherwise a default RangeIndex is used.

    Returns:
        pandas.DataFrame with columns ``c0 .. c{n-1}`` and ``y``.

    Raises:
        ValueError: if ``len(series) <= n``, i.e. no complete window plus
            target exists. (Subclass of Exception, so existing handlers still work.)
    """
    if len(series) <= n:
        raise ValueError(
            "The Length of series is %d, while affect by (n=%d)." % (len(series), n)
        )
    # Convert once; the original re-ran tolist() for every column.
    values = series.tolist()
    # Column c{i} is the series shifted by i, trimmed so every row has a target:
    # values[i : i - n] == values[i : -(n - i)].
    columns = {'c%d' % i: values[i:i - n] for i in range(n)}
    columns['y'] = values[n:]
    return pd.DataFrame(columns, index=series.index[n:] if index else None)
#參數 n 與上相同;train_end 為負數,表示最後多少筆資料作為測試集
def readData(column='high', n=30, all_too=True, index=False, train_end=-500, path="sh300.csv"):
    """Load the cached index data and build sliding-window training samples.

    Args:
        column: CSV column to model (default ``'high'``).
        n: window length forwarded to ``generate_data_by_n_days``.
        all_too: if true, also return the full column and the list of dates.
        index: forwarded to ``generate_data_by_n_days`` (label rows by date).
        train_end: negative count; the last ``-train_end`` rows are held out and
            only the rows before them are used to build training samples.
        path: CSV file written earlier by ``df.to_csv``. Its first column must
            hold dates formatted as ``%Y-%m-%d``.

    Returns:
        ``(train_df, full_series, dates)`` when ``all_too`` is true, where
        ``full_series`` is the whole (unsplit) column and ``dates`` is the index
        as a list. Otherwise just ``train_df``.
    """
    df = pd.read_csv(path, index_col=0)
    # Parse the "YYYY-MM-DD" strings of the first column into a DatetimeIndex.
    df.index = pd.to_datetime(df.index, format="%Y-%m-%d")
    df_column = df[column].copy()
    # Positional split: every row before the held-out tail is training data.
    df_column_train = df_column.iloc[:train_end]
    df_generate_train = generate_data_by_n_days(df_column_train, n, index=index)
    if all_too:
        return df_generate_train, df_column, df.index.tolist()
    return df_generate_train
7.8.4 定義模型
class RNN(nn.Module):
    """LSTM followed by a linear layer producing one value per time step.

    Args:
        input_size: features per time step (the script passes the window length n).
        hidden_size: LSTM hidden units (default 64, the original hard-coded value).
        num_layers: stacked LSTM layers (default 1, the original hard-coded value).
    """

    def __init__(self, input_size, hidden_size=64, num_layers=1):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Kept inside nn.Sequential so parameter names (out.0.weight / out.0.bias)
        # match state_dicts saved from the original model.
        self.out = nn.Sequential(
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        """Map ``x`` of shape [batch, seq_len, input_size] to [batch, seq_len, 1]."""
        # No initial state is passed, so nn.LSTM zero-initialises (h_0, c_0).
        r_out, _ = self.rnn(x)
        return self.out(r_out)
class mytrainset(Dataset):
    """Map-style Dataset over a 2-D tensor whose last column is the target.

    Every column except the last becomes the float32 feature vector
    (``self.data``); the last column becomes the float32 label (``self.label``).
    """

    def __init__(self, data):
        features = data[:, :-1]
        targets = data[:, -1]
        self.data = features.float()
        self.label = targets.float()

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target

    def __len__(self):
        # Number of rows (samples).
        return self.data.size(0)
7.8.5 訓練模型
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()

# Sliding-window training table, the full 'high' series, and the list of dates.
df, df_all, df_index = readData('high', n=n, train_end=train_end)

# Plot the raw daily-high series over the whole period.
df_all = np.array(df_all.tolist())
plt.plot(df_index, df_all, label='real-data')
plt.legend(loc='upper right')

# Standardise every entry of the training table with ONE scalar mean/std
# (taken over all n+1 columns, targets included), then wrap it for batching.
df_numpy = df.to_numpy()
df_numpy_mean = df_numpy.mean()
df_numpy_std = df_numpy.std()
df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std
df_tensor = torch.Tensor(df_numpy)
trainset = mytrainset(df_tensor)
# shuffle=False: batches follow the row order of the training table.
trainloader = DataLoader(trainset, shuffle=False, batch_size=batch_size)
![滬深300指數每日最高價(原始資料)](https://img.uj5u.com/2021/05/02/240874020632481.png)
# Train the LSTM and log the per-epoch loss with tensorboardX (viewable in the web UI).
from tensorboardX import SummaryWriter
writer = SummaryWriter(log_dir='logs')
rnn = RNN(n).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.MSELoss()
rnn.train()
for step in range(EPOCH):
    epoch_loss_sum = 0.0
    epoch_samples = 0
    for tx, ty in trainloader:
        tx = tx.to(device)
        ty = ty.to(device)
        # Insert a sequence axis at dim 1: [batch, n] -> [batch, seq_len=1, input_size=n].
        output = rnn(torch.unsqueeze(tx, dim=1))
        # Flatten [batch, 1, 1] -> [batch] so the prediction matches ty even when a
        # batch holds a single sample (torch.squeeze would return a 0-d tensor
        # there and trigger MSELoss's broadcasting warning).
        loss = loss_func(output.reshape(-1), ty)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_len = ty.size(0)
        epoch_loss_sum += loss.item() * batch_len
        epoch_samples += batch_len
    # Log the sample-weighted mean loss of the whole epoch, not just the last batch.
    if epoch_samples:
        writer.add_scalar('sh300_loss', epoch_loss_sum / epoch_samples, step)
# Flush pending events and release the log file.
writer.close()
D:\sofewore\anaconda\lib\site-packages\torch\nn\modules\loss.py:432: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
7.8.6 測試模型
generate_data_train = []
generate_data_test = []
# Predictions for positions >= test_index fall in the held-out period.
test_index = len(df_all) + train_end
# Normalise the full series with the training-set statistics.
df_all_normal = (df_all - df_numpy_mean) / df_numpy_std
df_all_normal_tensor = torch.Tensor(df_all_normal)
rnn.eval()
# Inference only: no_grad avoids building an autograd graph for every forward pass.
with torch.no_grad():
    for i in range(n, len(df_all)):
        # The n values preceding position i, shaped [batch=1, seq_len=1, input_size=n]
        # because the LSTM expects a 3-D input.
        x = df_all_normal_tensor[i - n:i].to(device).view(1, 1, -1)
        y = rnn(x)
        # Undo the normalisation to get back to the price scale.
        prediction = y.item() * df_numpy_std + df_numpy_mean
        if i < test_index:
            generate_data_train.append(prediction)
        else:
            generate_data_test.append(prediction)
plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')
plt.plot(df_index[train_end:], generate_data_test, label='generate_test')
plt.plot(df_index[train_end:], df_all[train_end:], label='real-data')
plt.legend()
plt.show()
![訓練期與測試期的預測值及測試期真實值對比](https://img.uj5u.com/2021/05/02/240874020632482.png)
# Zoom in on the start of the test period: from train_end up to zoom_end
# (500 days before the end). Requires zoom_end > train_end.
zoom_end = -500
plt.clf()
plt.plot(df_index[train_end:zoom_end], df_all[train_end:zoom_end], label='real-data')
# generate_data_test has exactly -train_end entries aligned with df_index[train_end:],
# so the same slice bounds select the matching predictions for any train_end
# (the former hard-coded [-600:-500] only lined up when train_end == -600).
plt.plot(df_index[train_end:zoom_end], generate_data_test[train_end:zoom_end], label='generate_test')
plt.legend()
plt.show()
![測試期前段預測值與真實值的放大對比](https://img.uj5u.com/2021/05/02/240874020632483.png)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/282083.html
標籤:AI
