主頁 > 企業開發 > 用于視頻的Keras自定義DataGenerator;如何將正確的輸出傳遞給我的模型?

用于視頻的Keras自定義DataGenerator;如何將正確的輸出傳遞給我的模型?

2022-04-20 23:26:02 企業開發

我正在創建一個 RNN 模型來處理一定長度(10 幀)的視頻。每個視頻在其各自的檔案夾中存盤為多個影像(長度不同)。然而,在將這批幀傳遞給 RNN 模型之前,我正在使用 ResNet 特征提取器對每個幀的影像進行預處理。我正在使用自定義資料生成器來獲取包含影像的檔案夾路徑,預處理影像,然后將其傳遞給模型。

在沒有資料生成器的情況下,我一直很笨拙地這樣做,但這并不實用,因為我有一個超過 10,000 個視頻的訓練集,并且后來還希望執行資料增強。

這是我的自定義資料生成器的代碼

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, labels, video_paths,
                 batch_size=32, video_length=10, dim=(224,224),
                 n_channels=3, n_classes=4, IMG_SIZE = 224, MAX_SEQ_LENGTH = 10,
                 NUM_FEATURES = 2048, shuffle=True):
        'Initialization'
        
        self.list_IDs = list_IDs
        self.labels = labels
        self.video_paths = video_paths        
        self.batch_size = batch_size
        self.dim = dim
        self.video_length = video_length
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.IMG_SIZE = IMG_SIZE
        self.MAX_SEQ_LENGTH = MAX_SEQ_LENGTH
        self.NUM_FEATURES = NUM_FEATURES
        self.shuffle = shuffle
        self.on_epoch_end()
    
    def crop_center_square(frame):
        y, x = frame.shape[0:2]
        min_dim = min(y, x)
        start_x = (x // 2) - (min_dim // 2)
        start_y = (y // 2) - (min_dim // 2)
        return frame[start_y : start_y   min_dim, start_x : start_x   min_dim]
    
    def load_series(self, videopath):
        frames = []
        image_paths = [os.path.join(videopath, o) for o in os.listdir(videopath)]
        frame_num = np.linspace(0,len(image_paths)-1, num=10)   
        frame_num = frame_num.astype(int)
        resize=(self.IMG_SIZE, self.IMG_SIZE)
        # resize=(IMG_SIZE, IMG_SIZE)
        
        for ix in frame_num:
            image = Image.open(image_paths[ix])
            im_array = np.asarray(image)
            im_array = self.crop_center_square(im_array)
            # im_array = crop_center_square(im_array)
            im_array = cv2.resize(im_array, resize)
            stacked_im_array = np.stack((im_array,)*3, axis=-1)
            frames.append(stacked_im_array)
            # plt.imshow(stacked_im_array)
            # plt.show()
            
        return np.array(frames)
    
    def build_feature_extractor(self):
        feature_extractor = keras.applications.resnet_v2.ResNet152V2(
            weights="imagenet",
            include_top=False,
            pooling="avg",
            input_shape=(self.IMG_SIZE, self.IMG_SIZE, 3),
        )
        preprocess_input = keras.applications.resnet_v2.preprocess_input

        inputs = keras.Input((self.IMG_SIZE, self.IMG_SIZE, 3))
        preprocessed = preprocess_input(inputs)

        outputs = feature_extractor(preprocessed)
        return keras.Model(inputs, outputs, name="feature_extractor")


    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size: (index 1)*self.batch_size]
        
        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        
        # Generate data
        [frame_features, frame_masks], frame_labels = self._generate_X(list_IDs_temp)
      
        return [frame_features, frame_masks], frame_labels
    
    def _generate_X(self, list_IDs_temp):
        'Generates data containing batch_size videos'
        # Initialization
        frame_masks = np.zeros(shape=(self.batch_size, self.MAX_SEQ_LENGTH), dtype="bool")
        frame_features = np.zeros(shape=(self.batch_size, self.MAX_SEQ_LENGTH, self.NUM_FEATURES), dtype="float32")
        frame_labels = np.zeros(shape=(self.batch_size), dtype="int")
        feature_extractor = self.build_feature_extractor()
        tt = time.time()
        # frame_masks = np.zeros(shape=(batch_size, MAX_SEQ_LENGTH), dtype="bool")
        # frame_features = np.zeros(shape=(batch_size, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32")
        # frame_labels = np.zeros(shape=(batch_size), dtype="int")
        
        for idx, ID in enumerate(list_IDs_temp):
            videopath = self.video_paths[ID]
            # videopath = video_paths[ID]
            video_frame_label = self.labels[ID]
            # Gather all its frames and add a batch dimension.       
            frames = self.load_series(Path(videopath))
            # frames = load_series(Path(videopath))
            
            # At this point frames.shape = (10, 224, 224, 3)
            frames = frames[None, ...]
            # After this, frames.shape = (1, 10, 224, 224, 3)

            # Initialize placeholders to store the masks and features of the current video.
            temp_frame_mask = np.zeros(shape=(1, self.MAX_SEQ_LENGTH,), dtype="bool")
            # temp_frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype="bool")
            # temp_frame_mask.shape = (1,60)
            
            temp_frame_features = np.zeros(shape=(1, self.MAX_SEQ_LENGTH, self.NUM_FEATURES), dtype="float32")
            # temp_frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32")
            # temp_frame_features.shape = (1, 60, 2048)
            
            # Extract features from the frames of the current video.
            for i, batch in enumerate(frames):
                video_length = batch.shape[0]
                length = min(self.MAX_SEQ_LENGTH, video_length)
                # length = min(MAX_SEQ_LENGTH, video_length)
                for j in range(length):
                    temp_frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
                    # temp_frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
                temp_frame_mask[i, :length] = 1  # 1 = not masked, 0 = masked
                
            frame_features[idx,] = temp_frame_features.squeeze()
            frame_masks[idx,] = temp_frame_mask.squeeze()
            frame_labels[idx] = video_frame_label
        tf = time.time() - tt
        print(f'Pre-process length: {tf}')
        
        return [frame_features, frame_masks], frame_labels

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

這是RNN模型的代碼

label_processor = keras.layers.StringLookup(num_oov_indices=0, vocabulary=np.unique(train_df["view"]))

print(label_processor.get_vocabulary())

train_list_IDs = train_df.index
train_labels = train_df["view"].values
train_labels = label_processor(train_labels[..., None]).numpy()
train_video_paths = train_df['series']

training_generator = DataGenerator(train_list_IDs, train_labels, train_video_paths)

test_list_IDs = test_df.index
test_labels = test_df["view"].values
test_labels = label_processor(test_labels[..., None]).numpy()
test_video_paths = test_df['series']

testing_generator = DataGenerator(test_list_IDs, test_labels, test_video_paths)

# Utility for our sequence model.
def get_sequence_model():
    class_vocab = label_processor.get_vocabulary()

    frame_features_input = keras.Input((MAX_SEQ_LENGTH, NUM_FEATURES))
    mask_input = keras.Input((MAX_SEQ_LENGTH,), dtype="bool")

    # Refer to the following tutorial to understand the significance of using `mask`:
    # https://keras.io/api/layers/recurrent_layers/gru/
    x = keras.layers.GRU(16, return_sequences=True)(
        frame_features_input, mask=mask_input
    )
    x = keras.layers.GRU(8)(x)
    x = keras.layers.Dropout(0.4)(x)
    x = keras.layers.Dense(8, activation="relu")(x)
    output = keras.layers.Dense(len(class_vocab), activation="softmax")(x)
    
    rnn_model = keras.Model([frame_features_input, mask_input], output)

    rnn_model.compile(
        loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )
    return rnn_model


# Utility for running experiments.
def run_experiment():
    now = datetime.now()
    current_time = now.strftime("%d_%m_%Y_%H_%M_%S")
    filepath = os.path.join(Path('F:/RNN'), f'RNN_ResNet_Model_{current_time}')
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath, save_weights_only=True, save_best_only=True, verbose=1
    )

    seq_model = get_sequence_model()
    history = seq_model.fit(training_generator,
        epochs=EPOCHS,
        callbacks=[checkpoint],
    )
    seq_model.load_weights(filepath)
    _, accuracy = seq_model.evaluate(testing_generator)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    return history, accuracy, seq_model


_, accuracy, sequence_model = run_experiment()

我正在努力弄清楚如何將自定義資料生成器的結果傳遞給我的 RNN 模型?如何最好地重寫我的代碼以使用 model.fit() 或 model.fit_generator()?

先感謝您!

uj5u.com熱心網友回復:

請在您的問題中具體說明您正在努力解決的問題。你期待不同的結果,你的代碼是慢還是你得到錯誤?根據您的代碼,我發現了一些問題,并建議進行以下調整:

__getitem__()每次從生成器中檢索一批資料時,都會呼叫 DataGenerator 中的函式。在您呼叫_generate_X()的該函式中,該函式還會在每次批量生成時再次初始化預訓練的 ResNet 特征提取器feature_extractor = self.build_feature_extractor()這是非常低效的。

作為替代方案,我建議洗掉生成器類中的模型創建,而是在主筆記本中創建特征提取器并將其作為 DataGenerator 實體的引數:

在您的主檔案中:

def build_feature_extractor(self): [...]

feature_extractor = build_feature_extractor()

testing_generator = DataGenerator(test_list_IDs, test_labels, test_video_paths, feature_extractor)

對于生成器類:

class DataGenerator(keras.utils.Sequence):
'Generates data for Keras'
def __init__(self, list_IDs, labels, video_paths, feature_extractor,
             batch_size=32, video_length=10, dim=(224,224),
             n_channels=3, n_classes=4, IMG_SIZE = 224, MAX_SEQ_LENGTH = 10,
             NUM_FEATURES = 2048, shuffle=True):
    'Initialization'
    
    self.list_IDs = list_IDs
    [...]
    self.feature_extractor = feature_extractor [...]

然后調整到這個:

temp_frame_features[i, j, :] = self.feature_extractor.predict(batch[None, j, :])

您已在您的 中正確使用了生成器.fit call,使用model.fit(training_generator, ...)將為您的模型提供創建的批次__getitem__()

uj5u.com熱心網友回復:

我得到的錯誤是

raise NotImplementedError keras

相當愚蠢的是,我忘記將以下函式放在 DataGenerator 函式中

def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

之后錯誤就消失了。

obsolete_hegemony 確實給了我一個很好的建議來優化我的代碼并分離特征提取預處理!

轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/459067.html

標籤:Python 张量流 喀拉斯

上一篇:預訓練的BERT不是LSTM層的正確形狀:值錯誤,新陣列的總大小必須保持不變

下一篇:通過將方法附加到python物件中的自身來呼叫方法n次

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • IEEE1588PTP在數字化變電站時鐘同步方面的應用

    IEEE1588ptp在數字化變電站時鐘同步方面的應用 京準電子科技官微——ahjzsz 一、電力系統時間同步基本概況 隨著對IEC 61850標準研究的不斷深入,國內外學者提出基于IEC61850通信標準體系建設數字化變電站的發展思路。數字化變電站與常規變電站的顯著區別在于程序層傳統的電流/電壓互 ......

    uj5u.com 2020-09-10 03:51:52 more
  • HTTP request smuggling CL.TE

    CL.TE 簡介 前端通過Content-Length處理請求,通過反向代理或者負載均衡將請求轉發到后端,后端Transfer-Encoding優先級較高,以TE處理請求造成安全問題。 檢測 發送如下資料包 POST / HTTP/1.1 Host: ac391f7e1e9af821806e890 ......

    uj5u.com 2020-09-10 03:52:11 more
  • 網路滲透資料大全單——漏洞庫篇

    網路滲透資料大全單——漏洞庫篇漏洞庫 NVD ——美國國家漏洞庫 →http://nvd.nist.gov/。 CERT ——美國國家應急回應中心 →https://www.us-cert.gov/ OSVDB ——開源漏洞庫 →http://osvdb.org Bugtraq ——賽門鐵克 →ht ......

    uj5u.com 2020-09-10 03:52:15 more
  • 京準講述NTP時鐘服務器應用及原理

    京準講述NTP時鐘服務器應用及原理京準講述NTP時鐘服務器應用及原理 安徽京準電子科技官微——ahjzsz 北斗授時原理 授時是指接識訓通過某種方式獲得本地時間與北斗標準時間的鐘差,然后調整本地時鐘使時差控制在一定的精度范圍內。 衛星導航系統通常由三部分組成:導航授時衛星、地面檢測校正維護系統和用戶 ......

    uj5u.com 2020-09-10 03:52:25 more
  • 利用北斗衛星系統設計NTP網路時間服務器

    利用北斗衛星系統設計NTP網路時間服務器 利用北斗衛星系統設計NTP網路時間服務器 安徽京準電子科技官微——ahjzsz 概述 NTP網路時間服務器是一款支持NTP和SNTP網路時間同步協議,高精度、大容量、高品質的高科技時鐘產品。 NTP網路時間服務器設備采用冗余架構設計,高精度時鐘直接來源于北斗 ......

    uj5u.com 2020-09-10 03:52:35 more
  • 詳細解讀電力系統各種對時方式

    詳細解讀電力系統各種對時方式 詳細解讀電力系統各種對時方式 安徽京準電子科技官微——ahjzsz,更多資料請添加VX 衛星同步時鐘是我京準公司開發研制的應用衛星授時時技術的標準時間顯示和發送的裝置,該裝置以M國全球定位系統(GLOBAL POSITIONING SYSTEM,縮寫為GPS)或者我國北 ......

    uj5u.com 2020-09-10 03:52:45 more
  • 如何保證外包團隊接入企業內網安全

    不管企業規模的大小,只要企業想省錢,那么企業的某些服務就一定會采用外包的形式,然而看似美好又經濟的策略,其實也有不好的一面。下面我通過安全的角度來聊聊使用外包團的安全隱患問題。 先看看什么服務會使用外包的,最常見的就是話務/客服這種需要大量重復性、無技術性的服務,或者是一些銷售外包、特殊的職能外包等 ......

    uj5u.com 2020-09-10 03:52:57 more
  • PHP漏洞之【整型數字型SQL注入】

    0x01 什么是SQL注入 SQL是一種注入攻擊,通過前端帶入后端資料庫進行惡意的SQL陳述句查詢。 0x02 SQL整型注入原理 SQL注入一般發生在動態網站URL地址里,當然也會發生在其它地發,如登錄框等等也會存在注入,只要是和資料庫打交道的地方都有可能存在。 如這里http://192.168. ......

    uj5u.com 2020-09-10 03:55:40 more
  • [GXYCTF2019]禁止套娃

    git泄露獲取原始碼 使用GET傳參,引數為exp 經過三層過濾執行 第一層過濾偽協議,第二層過濾帶引數的函式,第三層過濾一些函式 preg_replace('/[a-z,_]+\((?R)?\)/', NULL, $_GET['exp'] (?R)參考當前正則運算式,相當于匹配函式里的引數 因此傳遞 ......

    uj5u.com 2020-09-10 03:56:07 more
  • 等保2.0實施流程

    流程 結論 ......

    uj5u.com 2020-09-10 03:56:16 more
最新发布
  • 使用Django Rest framework搭建Blog

    在前面的Blog例子中我們使用的是GraphQL, 雖然GraphQL的使用處于上升趨勢,但是Rest API還是使用的更廣泛一些. 所以還是決定回到傳統的rest api framework上來, Django rest framework的官網上給了一個很好用的QuickStart, 我參考Qu ......

    uj5u.com 2023-04-20 08:17:54 more
  • 記錄-new Date() 我忍你很久了!

    這里給大家分享我在網上總結出來的一些知識,希望對大家有所幫助 大家平時在開發的時候有沒被new Date()折磨過?就是它的諸多怪異的設定讓你每每用的時候,都可能不小心踩坑。造成程式意外出錯,卻一下子找不到問題出處,那叫一個煩透了…… 下面,我就列舉它的“四宗罪”及應用思考 可惡的四宗罪 1. Sa ......

    uj5u.com 2023-04-20 08:17:47 more
  • 使用Vue.js實作文字跑馬燈效果

    實作文字跑馬燈效果,首先用到 substring()截取 和 setInterval計時器 clearInterval()清除計時器 效果如下: 實作代碼如下: <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <meta ......

    uj5u.com 2023-04-20 08:12:31 more
  • JavaScript 運算子

    JavaScript 運算子/運算子 在 JavaScript 中,有一些運算子可以使代碼更簡潔、易讀和高效。以下是一些常見的運算子: 1、可選鏈運算子(optional chaining operator) ?.是可選鏈運算子(optional chaining operator)。?. 可選鏈操 ......

    uj5u.com 2023-04-20 08:02:25 more
  • CSS—相對單位rem

    一、概述 rem是一個相對長度單位,它的單位長度取決于根標簽html的字體尺寸。rem即root em的意思,中文翻譯為根em。瀏覽器的文本尺寸一般默認為16px,即默認情況下: 1rem = 16px rem布局原理:根據CSS媒體查詢功能,更改根標簽的字體尺寸,實作rem單位隨螢屏尺寸的變化,如 ......

    uj5u.com 2023-04-20 08:02:21 more
  • 我的第一個NPM包:panghu-planebattle-esm(胖虎飛機大戰)使用說明

    好家伙,我的包終于開發完啦 歡迎使用胖虎的飛機大戰包!! 為你的主頁添加色彩 這是一個有趣的網頁小游戲包,使用canvas和js開發 使用ES6模塊化開發 效果圖如下: (覺得圖片太sb的可以自己改) 代碼已開源!! Git: https://gitee.com/tang-and-han-dynas ......

    uj5u.com 2023-04-20 08:01:50 more
  • 如何在 vue3 中使用 jsx/tsx?

    我們都知道,通常情況下我們使用 vue 大多都是用的 SFC(Signle File Component)單檔案組件模式,即一個組件就是一個檔案,但其實 Vue 也是支持使用 JSX 來撰寫組件的。這里不討論 SFC 和 JSX 的好壞,這個仁者見仁智者見智。本篇文章旨在帶領大家快速了解和使用 Vu ......

    uj5u.com 2023-04-20 08:01:37 more
  • 【Vue2.x原始碼系列06】計算屬性computed原理

    本章目標:計算屬性是如何實作的?計算屬性快取原理以及洋蔥模型的應用?在初始化Vue實體時,我們會給每個計算屬性都創建一個對應watcher,我們稱之為計算屬性watcher ......

    uj5u.com 2023-04-20 08:01:31 more
  • http1.1與http2.0

    一、http是什么 通俗來講,http就是計算機通過網路進行通信的規則,是一個基于請求與回應,無狀態的,應用層協議。常用于TCP/IP協議傳輸資料。目前任何終端之間任何一種通信方式都必須按Http協議進行,否則無法連接。tcp(三次握手,四次揮手)。 請求與回應:客戶端請求、服務端回應資料。 無狀態 ......

    uj5u.com 2023-04-20 08:01:10 more
  • http1.1與http2.0

    一、http是什么 通俗來講,http就是計算機通過網路進行通信的規則,是一個基于請求與回應,無狀態的,應用層協議。常用于TCP/IP協議傳輸資料。目前任何終端之間任何一種通信方式都必須按Http協議進行,否則無法連接。tcp(三次握手,四次揮手)。 請求與回應:客戶端請求、服務端回應資料。 無狀態 ......

    uj5u.com 2023-04-20 08:00:32 more