主頁 > 資料庫 > 適用于稀疏向量、獨熱編碼資料的損失函式回顧和PyTorch實作

適用于稀疏向量、獨熱編碼資料的損失函式回顧和PyTorch實作

2020-10-01 09:25:36 資料庫

在稀疏的、獨熱編碼編碼資料上構建自動編碼器

自1986年[1]問世以來,在過去的30年里,通用自動編碼器神經網路已經滲透到現代機器學習的大多數主要領域的研究中,在嵌入復雜資料方面,自動編碼器已經被證明是非常有效的,它提供了簡單的方法來將復雜的非線性依賴編碼為平凡的向量表示,但是,盡管它們的有效性已經在許多方面得到了證明,但它們在重現稀疏資料方面常常存在不足,特別是當列像一個熱編碼那樣相互關聯時,

在本文中,我將簡要地討論一種熱編碼(OHE)資料和一般的自動編碼器,然后,我將介紹使用在一個熱門編碼資料上受過訓練的自動編碼器所帶來的問題的用例,最后,我將深入討論稀疏OHE資料重構的問題,然后介紹我發現在這些條件下運行良好的3個損失函式:

  • CosineEmbeddingLoss
  • Sorenson-Dice Coefficient Loss
  • Multi-Task Learning Losses of Individual OHE Components

-解決了上述挑戰,包括在PyTorch中實作它們的代碼,

熱編碼資料

熱編碼資料是一種最簡單的,但在一般機器學習場景中經常被誤解的資料預處理技術,該程序將具有“N”不同類別的分類資料二值化為二進制0和1的N列,第N個類別中出現1表示該觀察屬于該類別,這個程序在Python中很簡單,使用Scikit-Learn OneHotEncoder模塊:

from sklearn.preprocessing import OneHotEncoder
import numpy as np# Instantiate a column of 10 random integers from 5 classes
x = np.random.randint(5, size=10).reshape(-1,1)print(x)
>>> [[2][3][2][2][1][1][4][1][0][4]]# Instantiate OHE() + Fit/Transform the data
ohe_encoder = OneHotEncoder(categories="auto")
encoded = ohe_encoder.fit_transform(x).todense()print(encoded)
>>> matrix([[0., 1., 0., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 0., 0., 1.]])print(list(ohe_encoder.get_feature_names()))
>>> ["x0_0", "x0_1", "x0_2", "x0_3", "x0_4"]

但是,盡管這個技巧很簡單,但如果不小心,它可能很快就會失效,它可以很容易地為資料添加多余的復雜性,并改變資料上某些分類方法的有效性,例如,轉換成OHE向量的列現在是相互依賴的,這種互動使得在某些型別的分類器中有效地表示資料方面變得困難,例如,如果您有一個包含15個不同類別的列,那么就需要一個深度為15的決策樹來處理該熱編碼列中的if-then模式(當然樹形模型的資料處理是不需要進行獨熱編碼的,這里只是舉例),類似地,由于列是相互依賴的,如果使用bagging (Bootstrap聚合)的分類策略并執行特性采樣,則可能會完全錯過單次編碼的列,或者只考慮它的部分組件類,

Autoencoders

自動編碼器是一種無監督的神經網路,其作業是將資料嵌入到一種有效的壓縮格式,它利用編碼和解碼程序將資料編碼為更小的格式,然后再將更小的格式解碼為原始的輸入表示,利用模型重構(譯碼)與原始資料之間的損失對模型進行訓練,

實際上,用代碼表示這個網路也很容易,我們從兩個函式開始:編碼器模型和解碼器模型,這兩個“模型”都被封裝在一個叫做Network的類中,它將包含我們的培訓和評估的整個系統,最后,我們定義了一個Forward函式,PyTorch將它用作進入網路的入口,用于包裝資料的編碼和解碼,

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Network(nn.Module):
   def __init__(self, input_shape: int):
      super().__init__()
      self.encode1 = nn.Linear(input_shape, 500)
      self.encode2 = nn.Linear(500, 250)
      self.encode3 = nn.Linear(250, 50)
      
      self.decode1 = nn.Linear(50, 250)
      self.decode2 = nn.Linear(250, 500)
      self.decode3 = nn.Linear(500, input_shape)   def encode(self, x: torch.Tensor):
      x = F.relu(self.encode1(x))
      x = F.relu(self.encode2(x))
      x = F.relu(self.encode3(x))
      return x   def decode(self, x: torch.Tensor):
      x = F.relu(self.decode1(x))
      x = F.relu(self.decode2(x))
      x = F.relu(self.decode3(x))
      return x   def forward(self, x: torch.Tensor):
      x = encode(x)
      x = decode(x)
      return x
def train_model(data: pd.DataFrame):
   net = Network()
   optimizer = optim.Adagrad(net.parameters(), lr=1e-3, weight_decay=1e-4)
   losses = []   for epoch in range(250):
     for batch in get_batches(data)
        net.zero_grad()
        
        # Pass batch through 
        output = net(batch)
        
        # Get Loss + Backprop
        loss = loss_fn(output, batch).sum() # 
        losses.append(loss)
        loss.backward()
        optimizer.step()
     return net, losses

正如我們在上面看到的,我們有一個編碼函式,它從輸入資料的形狀開始,然后隨著它向下傳播到形狀為50而降低它的維數,從那里,解碼層接受嵌入,然后將其擴展回原來的形狀,在訓練中,我們從譯碼器中取出重構的結果,并取出重構與原始輸入的損失,

損失函式的問題

所以現在我們已經討論了自動編碼器的結構和一個熱編碼程序,我們終于可以討論與使用一個熱編碼在自動編碼器相關的問題,以及如何解決這個問題,當一個自動編碼器比較重建到原始輸入資料,必須有一些估值之間的距離提出重建和真實的價值,通常,在輸出值被認為互不相干的情況下,將使用交叉熵損失或MSE損失,但在我們的一個熱編碼的情況下,有幾個問題,使系統更復雜:

  • 一列出現1意味著對應的OHE列必須有一個0,即列不是不相交的
  • OHE向量輸入的稀疏性會導致系統選擇簡單地將大多數列回傳0以減少誤差

這些問題結合起來導致上述兩個損失(MSE,交叉熵)在重構稀疏OHE資料時無效,下面我將介紹三種損失,它們提供了一個解決方案,或上述問題,并在PyTorch實作它們的代碼:

余弦嵌入損失

余弦距離是一種經典的向量距離度量,常用于NLP問題中比較字包表示,通過求兩個向量之間的余弦來計算距離,計算方法為:

由于該方法能夠考慮到各列中二進制值的偏差來評估兩個向量之間的距離,因此在稀疏嵌入重構中,該方法能夠很好地量化誤差,這種損失是迄今為止在PyTorch中最容易實作的,因為它在 Torch.nn.CosineEmbeddingLoss中有一個預先構建的解決方案

loss_function = torch.nn.CosineEmbeddingLoss(reduction='none')# . . . Then during training . . . loss = loss_function(reconstructed, input_data).sum()
loss.backward()

Dice Loss

Dice Loss是一個實作S?rensen-Dice系數[2],這是非常受歡迎的計算機視覺領域的分割任務,簡單地說,它是兩個集合之間重疊的度量,并且與兩個向量之間的Jaccard距離有關,骰子系數對向量中列值的差異高度敏感,利用這種敏感性有效地區分影像中像素的邊緣,因此在影像分割中非常流行,Dice Loss為:

PyTorch沒有內部實作的Dice Loss,但是在Kaggle上可以在其丟失函式庫- Keras & PyTorch[3]中找到一個很好的實作:

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid acitvation
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/
               (inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

不同OHE列的單個損失函式

最后,您可以將每個熱編碼列視為其自身的分類問題,并承擔每個分類的損失,這是一個多任務學習問題的用例,其中autoencoder正在解決重構輸入向量的各個分量的問題,當你有幾個/所有的列在你的輸入資料時,這個作業最好,例如,如果您有一個編碼列,前7列是7個類別:您可以將其視為一個多類分類問題,并將損失作為子問題的交叉熵損失,然后,您可以將子問題的損失合并在一起,并將其作為整個批的損失向后傳遞,

下面您將看到這個程序的示例,其中示例有三個熱編碼的列,每個列有50個類別,

from torch.nn.modules import _Loss
from torch import argmaxclass CustomLoss(_Loss):
  def __init__(self):
    super(CustomLoss, self).__init__()  def forward(self, input, target):
    """ loss function called at runtime """
   
    # Class 1 - Indices [0:50]
    class_1_loss = F.nll_loss(
        F.log_softmax(input[:, 0:50], dim=1), 
        argmax(target[:, 0:50])
    )    # Class 2 - Indices [50:100]
    class_2_loss = F.nll_loss(
        F.log_softmax(input[:, 50:100], dim=1), 
        argmax(target[:, 50:100])
    )    # Class 3 - Indices [100:150]
    class_3_loss = F.nll_loss(
        F.log_softmax(input[:, 100:150], dim=1), 
        argmax(target[:, 100:150])
    )    return class_1_loss + class_2_loss + class_3_loss

在上面的代碼中,您可以看到重構輸出的子集是如何承受個體損失的,然后在最后將其合并為一個總和,這里我們使用了一個負對數似然損失(nll_loss),它是一個很好的損失函式用于多類分類方案,并與交叉熵損失有關,

總結

在本文中,我們瀏覽了一個獨熱編碼分類變數的概念,以及自動編碼器的一般結構和目標,我們討論了一個熱編碼向量的缺點,以及在嘗試訓練稀疏的、一個獨熱編碼資料的自編碼器模型時的主要問題,最后,我們討論了解決稀疏一熱編碼問題的3個損失函式,訓練這些網路并沒有更好或更壞的損失,在我所介紹的功能中,沒有辦法知道哪個是適合您的用例的,除非您嘗試它們!

下面我提供了一些深入討論上述主題的資源,以及一些我提供的關于丟失函式的資源,

資源

  1. D.E. Rumelhart, G.E. Hinton, and R.J. Williams, “Learning internal representations by error propagation.” Parallel Distributed Processing. Vol 1: Foundations. MIT Press, Cambridge, MA, 1986.
  2. S?rensen, T. (1948). “A method of establishing groups of equal amplitude in plant sociology based on similarity of species and its application to analyses of the vegetation on Danish commons”. Kongelige Danske Videnskabernes Selskab. 5 (4): 1–34. *AND* Dice, Lee R. (1945). “Measures of the Amount of Ecologic Association Between Species”. Ecology. 26 (3): 297–302.
  3. Kaggle’s Loss Function Library: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch

作者:Nick Hespe

deephub翻譯組

轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/144727.html

標籤:其他

上一篇:玩轉華為云,ModelArts深度學習建模最全搭建手冊

下一篇:NLP頂會論文寫作技巧個人總結!

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • GPU虛擬機創建時間深度優化

    **?桔妹導讀:**GPU虛擬機實體創建速度慢是公有云面臨的普遍問題,由于通常情況下創建虛擬機屬于低頻操作而未引起業界的重視,實際生產中還是存在對GPU實體創建時間有苛刻要求的業務場景。本文將介紹滴滴云在解決該問題時的思路、方法、并展示最終的優化成果。 從公有云服務商那里購買過虛擬主機的資深用戶,一 ......

    uj5u.com 2020-09-10 06:09:13 more
  • 可編程網卡芯片在滴滴云網路的應用實踐

    **?桔妹導讀:**隨著云規模不斷擴大以及業務層面對延遲、帶寬的要求越來越高,采用DPDK 加速網路報文處理的方式在橫向縱向擴展都出現了局限性。可編程芯片成為業界熱點。本文主要講述了可編程網卡芯片在滴滴云網路中的應用實踐,遇到的問題、帶來的收益以及開源社區貢獻。 #1. 資料中心面臨的問題 隨著滴滴 ......

    uj5u.com 2020-09-10 06:10:21 more
  • 滴滴資料通道服務演進之路

    **?桔妹導讀:**滴滴資料通道引擎承載著全公司的資料同步,為下游實時和離線場景提供了必不可少的源資料。隨著任務量的不斷增加,資料通道的整體架構也隨之發生改變。本文介紹了滴滴資料通道的發展歷程,遇到的問題以及今后的規劃。 #1. 背景 資料,對于任何一家互聯網公司來說都是非常重要的資產,公司的大資料 ......

    uj5u.com 2020-09-10 06:11:05 more
  • 滴滴AI Labs斬獲國際機器翻譯大賽中譯英方向世界第三

    **桔妹導讀:**深耕人工智能領域,致力于探索AI讓出行更美好的滴滴AI Labs再次斬獲國際大獎,這次獲獎的專案是什么呢?一起來看看詳細報道吧! 近日,由國際計算語言學協會ACL(The Association for Computational Linguistics)舉辦的世界最具影響力的機器 ......

    uj5u.com 2020-09-10 06:11:29 more
  • MPP (Massively Parallel Processing)大規模并行處理

    1、什么是mpp? MPP (Massively Parallel Processing),即大規模并行處理,在資料庫非共享集群中,每個節點都有獨立的磁盤存盤系統和記憶體系統,業務資料根據資料庫模型和應用特點劃分到各個節點上,每臺資料節點通過專用網路或者商業通用網路互相連接,彼此協同計算,作為整體提供 ......

    uj5u.com 2020-09-10 06:11:41 more
  • 滴滴資料倉庫指標體系建設實踐

    **桔妹導讀:**指標體系是什么?如何使用OSM模型和AARRR模型搭建指標體系?如何統一流程、規范化、工具化管理指標體系?本文會對建設的方法論結合滴滴資料指標體系建設實踐進行解答分析。 #1. 什么是指標體系 ##1.1 指標體系定義 指標體系是將零散單點的具有相互聯系的指標,系統化的組織起來,通 ......

    uj5u.com 2020-09-10 06:12:52 more
  • 單表千萬行資料庫 LIKE 搜索優化手記

    我們經常在資料庫中使用 LIKE 運算子來完成對資料的模糊搜索,LIKE 運算子用于在 WHERE 子句中搜索列中的指定模式。 如果需要查找客戶表中所有姓氏是“張”的資料,可以使用下面的 SQL 陳述句: SELECT * FROM Customer WHERE Name LIKE '張%' 如果需要 ......

    uj5u.com 2020-09-10 06:13:25 more
  • 滴滴Ceph分布式存盤系統優化之鎖優化

    **桔妹導讀:**Ceph是國際知名的開源分布式存盤系統,在工業界和學術界都有著重要的影響。Ceph的架構和演算法設計發表在國際系統領域頂級會議OSDI、SOSP、SC等上。Ceph社區得到Red Hat、SUSE、Intel等大公司的大力支持。Ceph是國際云計算領域應用最廣泛的開源分布式存盤系統, ......

    uj5u.com 2020-09-10 06:14:51 more
  • es~通過ElasticsearchTemplate進行聚合~嵌套聚合

    之前寫過《es~通過ElasticsearchTemplate進行聚合操作》的文章,這一次主要寫一個嵌套的聚合,例如先對sex集合,再對desc聚合,最后再對age求和,共三層嵌套。 Aggregations的部分特性類似于SQL語言中的group by,avg,sum等函式,Aggregation ......

    uj5u.com 2020-09-10 06:14:59 more
  • 爬蟲日志監控 -- Elastc Stack(ELK)部署

    傻瓜式部署,只需替換IP與用戶 導讀: 現ELK四大組件分別為:Elasticsearch(核心)、logstash(處理)、filebeat(采集)、kibana(可視化) 下載均在https://www.elastic.co/cn/downloads/下tar包,各組件版本最好一致,配合fdm會 ......

    uj5u.com 2020-09-10 06:15:05 more
最新发布
  • day02-2-商鋪查詢快取

    功能02-商鋪查詢快取 3.商鋪詳情快取查詢 3.1什么是快取? 快取就是資料交換的緩沖區(稱作Cache),是存盤資料的臨時地方,一般讀寫性能較高。 快取的作用: 降低后端負載 提高讀寫效率,降低回應時間 快取的成本: 資料一致性成本 代碼維護成本 運維成本 3.2需求說明 如下,當我們點擊商店詳 ......

    uj5u.com 2023-04-20 08:33:24 more
  • MySQL中binlog備份腳本分享

    關于MySQL的二進制日志(binlog),我們都知道二進制日志(binlog)非常重要,尤其當你需要point to point災難恢復的時侯,所以我們要對其進行備份。關于二進制日志(binlog)的備份,可以基于flush logs方式先切換binlog,然后拷貝&壓縮到到遠程服務器或本地服務器 ......

    uj5u.com 2023-04-20 08:28:06 more
  • day02-短信登錄

    功能實作02 2.功能01-短信登錄 2.1基于Session實作登錄 2.1.1思路分析 2.1.2代碼實作 2.1.2.1發送短信驗證碼 發送短信驗證碼: 發送驗證碼的介面為:http://127.0.0.1:8080/api/user/code?phone=xxxxx<手機號> 請求方式:PO ......

    uj5u.com 2023-04-20 08:27:27 more
  • 快取與資料庫雙寫一致性幾種策略分析

    本文將對幾種快取與資料庫保證資料一致性的使用方式進行分析。為保證高并發性能,以下分析場景不考慮執行的原子性及加鎖等強一致性要求的場景,僅追求最終一致性。 ......

    uj5u.com 2023-04-20 08:26:48 more
  • sql陳述句優化

    問題查找及措施 問題查找 需要找到具體的代碼,對其進行一對一優化,而非一直把關注點放在服務器和sql平臺 降低簡化每個事務中處理的問題,盡量不要讓一個事務拖太長的時間 例如檔案上傳時,應將檔案上傳這一步放在事務外面 微軟建議 4.啟動sql定時執行計劃 怎么啟動sqlserver代理服務-百度經驗 ......

    uj5u.com 2023-04-20 08:26:35 more
  • 云時代,MySQL到ClickHouse資料同步產品對比推薦

    ClickHouse 在執行分析查詢時的速度優勢很好的彌補了MySQL的不足,但是對于很多開發者和DBA來說,如何將MySQL穩定、高效、簡單的同步到 ClickHouse 卻很困難。本文對比了 NineData、MaterializeMySQL(ClickHouse自帶)、Bifrost 三款產品... ......

    uj5u.com 2023-04-20 08:26:29 more
  • sql陳述句優化

    問題查找及措施 問題查找 需要找到具體的代碼,對其進行一對一優化,而非一直把關注點放在服務器和sql平臺 降低簡化每個事務中處理的問題,盡量不要讓一個事務拖太長的時間 例如檔案上傳時,應將檔案上傳這一步放在事務外面 微軟建議 4.啟動sql定時執行計劃 怎么啟動sqlserver代理服務-百度經驗 ......

    uj5u.com 2023-04-20 08:25:13 more
  • Redis 報”OutOfDirectMemoryError“(堆外記憶體溢位)

    Redis 報錯“OutOfDirectMemoryError(堆外記憶體溢位) ”問題如下: 一、報錯資訊: 使用 Redis 的業務介面 ,產生 OutOfDirectMemoryError(堆外記憶體溢位),如圖: 格式化后的報錯資訊: { "timestamp": "2023-04-17 22: ......

    uj5u.com 2023-04-20 08:24:54 more
  • day02-2-商鋪查詢快取

    功能02-商鋪查詢快取 3.商鋪詳情快取查詢 3.1什么是快取? 快取就是資料交換的緩沖區(稱作Cache),是存盤資料的臨時地方,一般讀寫性能較高。 快取的作用: 降低后端負載 提高讀寫效率,降低回應時間 快取的成本: 資料一致性成本 代碼維護成本 運維成本 3.2需求說明 如下,當我們點擊商店詳 ......

    uj5u.com 2023-04-20 08:24:03 more
  • day02-短信登錄

    功能實作02 2.功能01-短信登錄 2.1基于Session實作登錄 2.1.1思路分析 2.1.2代碼實作 2.1.2.1發送短信驗證碼 發送短信驗證碼: 發送驗證碼的介面為:http://127.0.0.1:8080/api/user/code?phone=xxxxx<手機號> 請求方式:PO ......

    uj5u.com 2023-04-20 08:23:11 more