主頁 > 企業開發 > 擬合我的模型時,Tensorflow“需要可廣播的形狀”

擬合我的模型時,Tensorflow“需要可廣播的形狀”

2022-01-09 07:55:42 企業開發

我對此很陌生,我不理解/看到任何解決此問題的方法。該錯誤特別指出:

InvalidArgumentError:  required broadcastable shapes
     [[node Equal
 (defined at c:\Users\Connor\Documents\StockIQ\venv\lib\site-packages\keras\metrics.py:3609)
]] [Op:__inference_train_function_141873]

似乎它可能與我輸入的形狀有關:

print(train_x.shape)
print(train_y.shape)
print(validation_x.shape)
print(validation_y.shape)
----
(77922, 60, 8)
(77922,)
(3860, 60, 8)
(3860,)

這是我的訓練宣告:

#save model
tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))

filepath = "RNN_Final-{epoch:02d}-{val_acc:.3f}"  # unique file name that will include the epoch and the validation acc for that epoch
checkpoint = ModelCheckpoint("models/{}.model".format(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')) # saves only the best ones

history = model.fit(
    train_x, train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(validation_x, validation_y),

)

對此的任何幫助都會很棒-我確定我缺少一些簡單的東西。如果有幫助,我在下面附上了完整的檔案:

import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM, BatchNormalization
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
from functools import reduce
from sklearn import preprocessing
from collections import deque
import random
import time

# Prediction Variables


SEQ_LEN = 60
FUTURE_PERIOUD_PREDICT = 3
RATIO_TO_PREDICT = 'LTC-USD'
EPOCHS = 10
BATCH_SIZE = 64
NAME = f"{SEQ_LEN}-SEQ-{FUTURE_PERIOUD_PREDICT}-PRED-{int(time.time())}"


# joining 4 tables to show user potential relatioships b/w coins


main_df = pd.DataFrame() # begin empty

ratios = ["BTC-USD", "LTC-USD", "BCH-USD", "ETH-USD"]  # the 4 ratios we want to consider
for ratio in ratios:  # begin iteration
    dataset = f'crypto_data/{ratio}.csv'  # get the full path to the file.
    df = pd.read_csv(dataset, names=['time', 'low', 'high', 'open', 'close', 'volume'])  # read in specific file

    # rename volume and close to include the ticker so we can still which close/volume is which:
    df.rename(columns={"close": f"{ratio}_close", "volume": f"{ratio}_volume"}, inplace=True)

    df.set_index("time", inplace=True)  # set time as index so we can join them on this shared time
    df = df[[f"{ratio}_close", f"{ratio}_volume"]]  # ignore the other columns besides price and volume

    if len(main_df)==0:  # if the dataframe is empty
        main_df = df  # then it's just the current df
    else:  # otherwise, join this data to the main one
        main_df = main_df.join(df)

main_df.fillna(method="ffill", inplace=True)  # if there are gaps in data, use previously known values
main_df.dropna(inplace=True)
main_df.head()






# Add result and target columns

# %%
def classify(current, future):
    if float(future) > float(current):
        return 1
    else:
        return 0



main_df['future'] = main_df[f'{RATIO_TO_PREDICT}_close'].shift(-FUTURE_PERIOUD_PREDICT) #adding result column

main_df['target'] = list(map(classify, main_df[f'{RATIO_TO_PREDICT}_close'], main_df['future'])) #adding target column
main_df.head()


# seperating training and testing data
# *note the eval data is in the FUTURE of training since it's the same dataset


times = sorted(main_df.index.values)
last_5pct = times[-int(0.05*len(times))]

validation_main_df = main_df[(main_df.index >= last_5pct)]
train_main_df = main_df[(main_df.index < last_5pct)]


# normalize, sequence, and balance data


def preprocess_df(df):
    df = df.drop('future', axis=1)
# normalize data
    for col in df.columns: # normalizing all columns (except target since that's already binary)
        if col != 'target':
            df[col] = df[col].pct_change()
            df.dropna(inplace=True)
            df[col] = preprocessing.scale(df[col].values)

    df.dropna(inplace=True)

# sequencing (scaling) data
    sequential_data = []
    prev_days = deque(maxlen=SEQ_LEN) # this builds a list and then pops new records in and out to keep the same list length
    
    for i in df.values:
        prev_days.append([n for n in i[:-1]]) # take all columns excpet target
        if len(prev_days) == SEQ_LEN:
            sequential_data.append([np.array(prev_days), i[-1]])
    
    random.shuffle(sequential_data)

# balance data
    buys = []
    sells = []

    for seq, target in sequential_data: # generate two lists to compare buys vs sells
        if target == 0:
            sells.append([seq, target])
        elif target == 1:
            buys.append([seq, target])

    random.shuffle(buys)
    random.shuffle(sells)

    lower = min(len(buys), len(sells))

    buys = buys[:lower]
    sells = sells[:lower]

    sequential_data = buys sells
    random.shuffle(sequential_data)

    X = []
    Y = []

    for seq, target in sequential_data: #seperate features and labels into X & Y lists
        X.append(seq)
        Y.append(target)

    return np.array(X), np.array(Y)


train_x, train_y = preprocess_df(train_main_df)
validation_x, validation_y = preprocess_df(validation_main_df)


# Build model

#adding layers
model = Sequential()
model.add(LSTM(128, input_shape=(train_x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(LSTM(128, input_shape=(train_x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(LSTM(128, input_shape=(train_x.shape[1:]), return_sequences=True))
model.add(Dropout(0.1))
model.add(BatchNormalization())

model.add(Dense(32, activation="tanh"))
model.add(Dropout(0.2))

model.add(Dense(2, activation='softmax'))

# set optimizers
opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)

# compile model
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)


# Train Model


#save model
tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))

filepath = "RNN_Final-{epoch:02d}-{val_acc:.3f}"  # unique file name that will include the epoch and the validation acc for that epoch
checkpoint = ModelCheckpoint("models/{}.model".format(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')) # saves only the best ones

history = model.fit(
    train_x, train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(validation_x, validation_y),

)

uj5u.com熱心網友回復:

由于您的標簽是標量,因此將最后一層更改為輸出定標器:

model.add(Dense(1, activation='sigmoid'))

并更改損失函式:

loss='binary_crossentropy'

你的模型將是:

#adding layers
model = Sequential()
model.add(LSTM(128, input_shape=(train_x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(LSTM(128, return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(LSTM(128, return_sequences=False))
model.add(Dropout(0.1))
model.add(BatchNormalization())

model.add(Dense(32, activation="tanh"))
model.add(Dropout(0.2))

model.add(Dense(1, activation='sigmoid'))

uj5u.com熱心網友回復:

我發現這解決了這個問題。我的標簽格式不正確:

“如果你正在做二元交叉熵,那么你的資料集可能有 2 個類并且錯誤即將到來,因為你的標簽向量(在測驗和訓練中)具有 [0,1,0,1,1,1, 0,0,1,...]。要對二進制標簽進行一次熱編碼,可以使用以下函式:Labels = tf.one_hot(Labels, depth=2)"

我在我的火車和驗證標簽上都使用了它,它至少正在運行。不知道模型是否表現良好哈哈!

感謝大家對此的投入:)

轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/406137.html

標籤:

上一篇:如何為這個python檔案使用GPU

下一篇:從tensorflow資料集中一致地提取資料

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • IEEE1588PTP在數字化變電站時鐘同步方面的應用

    IEEE1588ptp在數字化變電站時鐘同步方面的應用 京準電子科技官微——ahjzsz 一、電力系統時間同步基本概況 隨著對IEC 61850標準研究的不斷深入,國內外學者提出基于IEC61850通信標準體系建設數字化變電站的發展思路。數字化變電站與常規變電站的顯著區別在于程序層傳統的電流/電壓互 ......

    uj5u.com 2020-09-10 03:51:52 more
  • HTTP request smuggling CL.TE

    CL.TE 簡介 前端通過Content-Length處理請求,通過反向代理或者負載均衡將請求轉發到后端,后端Transfer-Encoding優先級較高,以TE處理請求造成安全問題。 檢測 發送如下資料包 POST / HTTP/1.1 Host: ac391f7e1e9af821806e890 ......

    uj5u.com 2020-09-10 03:52:11 more
  • 網路滲透資料大全單——漏洞庫篇

    網路滲透資料大全單——漏洞庫篇漏洞庫 NVD ——美國國家漏洞庫 →http://nvd.nist.gov/。 CERT ——美國國家應急回應中心 →https://www.us-cert.gov/ OSVDB ——開源漏洞庫 →http://osvdb.org Bugtraq ——賽門鐵克 →ht ......

    uj5u.com 2020-09-10 03:52:15 more
  • 京準講述NTP時鐘服務器應用及原理

    京準講述NTP時鐘服務器應用及原理京準講述NTP時鐘服務器應用及原理 安徽京準電子科技官微——ahjzsz 北斗授時原理 授時是指接識訓通過某種方式獲得本地時間與北斗標準時間的鐘差,然后調整本地時鐘使時差控制在一定的精度范圍內。 衛星導航系統通常由三部分組成:導航授時衛星、地面檢測校正維護系統和用戶 ......

    uj5u.com 2020-09-10 03:52:25 more
  • 利用北斗衛星系統設計NTP網路時間服務器

    利用北斗衛星系統設計NTP網路時間服務器 利用北斗衛星系統設計NTP網路時間服務器 安徽京準電子科技官微——ahjzsz 概述 NTP網路時間服務器是一款支持NTP和SNTP網路時間同步協議,高精度、大容量、高品質的高科技時鐘產品。 NTP網路時間服務器設備采用冗余架構設計,高精度時鐘直接來源于北斗 ......

    uj5u.com 2020-09-10 03:52:35 more
  • 詳細解讀電力系統各種對時方式

    詳細解讀電力系統各種對時方式 詳細解讀電力系統各種對時方式 安徽京準電子科技官微——ahjzsz,更多資料請添加VX 衛星同步時鐘是我京準公司開發研制的應用衛星授時時技術的標準時間顯示和發送的裝置,該裝置以M國全球定位系統(GLOBAL POSITIONING SYSTEM,縮寫為GPS)或者我國北 ......

    uj5u.com 2020-09-10 03:52:45 more
  • 如何保證外包團隊接入企業內網安全

    不管企業規模的大小,只要企業想省錢,那么企業的某些服務就一定會采用外包的形式,然而看似美好又經濟的策略,其實也有不好的一面。下面我通過安全的角度來聊聊使用外包團的安全隱患問題。 先看看什么服務會使用外包的,最常見的就是話務/客服這種需要大量重復性、無技術性的服務,或者是一些銷售外包、特殊的職能外包等 ......

    uj5u.com 2020-09-10 03:52:57 more
  • PHP漏洞之【整型數字型SQL注入】

    0x01 什么是SQL注入 SQL是一種注入攻擊,通過前端帶入后端資料庫進行惡意的SQL陳述句查詢。 0x02 SQL整型注入原理 SQL注入一般發生在動態網站URL地址里,當然也會發生在其它地發,如登錄框等等也會存在注入,只要是和資料庫打交道的地方都有可能存在。 如這里http://192.168. ......

    uj5u.com 2020-09-10 03:55:40 more
  • [GXYCTF2019]禁止套娃

    git泄露獲取原始碼 使用GET傳參,引數為exp 經過三層過濾執行 第一層過濾偽協議,第二層過濾帶引數的函式,第三層過濾一些函式 preg_replace('/[a-z,_]+\((?R)?\)/', NULL, $_GET['exp'] (?R)參考當前正則運算式,相當于匹配函式里的引數 因此傳遞 ......

    uj5u.com 2020-09-10 03:56:07 more
  • 等保2.0實施流程

    流程 結論 ......

    uj5u.com 2020-09-10 03:56:16 more
最新发布
  • 使用Django Rest framework搭建Blog

    在前面的Blog例子中我們使用的是GraphQL, 雖然GraphQL的使用處于上升趨勢,但是Rest API還是使用的更廣泛一些. 所以還是決定回到傳統的rest api framework上來, Django rest framework的官網上給了一個很好用的QuickStart, 我參考Qu ......

    uj5u.com 2023-04-20 08:17:54 more
  • 記錄-new Date() 我忍你很久了!

    這里給大家分享我在網上總結出來的一些知識,希望對大家有所幫助 大家平時在開發的時候有沒被new Date()折磨過?就是它的諸多怪異的設定讓你每每用的時候,都可能不小心踩坑。造成程式意外出錯,卻一下子找不到問題出處,那叫一個煩透了…… 下面,我就列舉它的“四宗罪”及應用思考 可惡的四宗罪 1. Sa ......

    uj5u.com 2023-04-20 08:17:47 more
  • 使用Vue.js實作文字跑馬燈效果

    實作文字跑馬燈效果,首先用到 substring()截取 和 setInterval計時器 clearInterval()清除計時器 效果如下: 實作代碼如下: <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <meta ......

    uj5u.com 2023-04-20 08:12:31 more
  • JavaScript 運算子

    JavaScript 運算子/運算子 在 JavaScript 中,有一些運算子可以使代碼更簡潔、易讀和高效。以下是一些常見的運算子: 1、可選鏈運算子(optional chaining operator) ?.是可選鏈運算子(optional chaining operator)。?. 可選鏈操 ......

    uj5u.com 2023-04-20 08:02:25 more
  • CSS—相對單位rem

    一、概述 rem是一個相對長度單位,它的單位長度取決于根標簽html的字體尺寸。rem即root em的意思,中文翻譯為根em。瀏覽器的文本尺寸一般默認為16px,即默認情況下: 1rem = 16px rem布局原理:根據CSS媒體查詢功能,更改根標簽的字體尺寸,實作rem單位隨螢屏尺寸的變化,如 ......

    uj5u.com 2023-04-20 08:02:21 more
  • 我的第一個NPM包:panghu-planebattle-esm(胖虎飛機大戰)使用說明

    好家伙,我的包終于開發完啦 歡迎使用胖虎的飛機大戰包!! 為你的主頁添加色彩 這是一個有趣的網頁小游戲包,使用canvas和js開發 使用ES6模塊化開發 效果圖如下: (覺得圖片太sb的可以自己改) 代碼已開源!! Git: https://gitee.com/tang-and-han-dynas ......

    uj5u.com 2023-04-20 08:01:50 more
  • 如何在 vue3 中使用 jsx/tsx?

    我們都知道,通常情況下我們使用 vue 大多都是用的 SFC(Signle File Component)單檔案組件模式,即一個組件就是一個檔案,但其實 Vue 也是支持使用 JSX 來撰寫組件的。這里不討論 SFC 和 JSX 的好壞,這個仁者見仁智者見智。本篇文章旨在帶領大家快速了解和使用 Vu ......

    uj5u.com 2023-04-20 08:01:37 more
  • 【Vue2.x原始碼系列06】計算屬性computed原理

    本章目標:計算屬性是如何實作的?計算屬性快取原理以及洋蔥模型的應用?在初始化Vue實體時,我們會給每個計算屬性都創建一個對應watcher,我們稱之為計算屬性watcher ......

    uj5u.com 2023-04-20 08:01:31 more
  • http1.1與http2.0

    一、http是什么 通俗來講,http就是計算機通過網路進行通信的規則,是一個基于請求與回應,無狀態的,應用層協議。常用于TCP/IP協議傳輸資料。目前任何終端之間任何一種通信方式都必須按Http協議進行,否則無法連接。tcp(三次握手,四次揮手)。 請求與回應:客戶端請求、服務端回應資料。 無狀態 ......

    uj5u.com 2023-04-20 08:01:10 more
  • http1.1與http2.0

    一、http是什么 通俗來講,http就是計算機通過網路進行通信的規則,是一個基于請求與回應,無狀態的,應用層協議。常用于TCP/IP協議傳輸資料。目前任何終端之間任何一種通信方式都必須按Http協議進行,否則無法連接。tcp(三次握手,四次揮手)。 請求與回應:客戶端請求、服務端回應資料。 無狀態 ......

    uj5u.com 2023-04-20 08:00:32 more