我正在嘗試通過索引矩陣訪問 pytorch 張量,我最近發現了這段代碼,但我找不到它不起作用的原因。
下面的代碼分為兩部分。前半部分證明有效,而后半部分出現錯誤。我看不出原因。有人可以對此有所了解嗎?
import torch
import numpy as np
a = torch.rand(32, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # WORKS for a torch.tensor of size M >= 32. It doesn't work otherwise.
a = torch.rand(16, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # IndexError: too many indices for tensor of dimension 2
如果我改變a = np.random.rand(16, 16)它也可以。
uj5u.com熱心網友回復:
首先,讓我讓您快速了解使用 numpy 陣列和另一個張量索引張量的想法。
示例:這是我們要索引的目標張量
numpy_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # numpy array
tensor_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # 2D tensor
t = torch.tensor([[1, 2, 3, 4], # targeted tensor
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24],
[25, 26, 27, 28],
[29, 30, 31, 32]])
numpy_result = t[numpy_indices]
tensor_result = t[tensor_indices]
使用 2D numpy 陣列進行索引:索引被讀取為對 (x,y) tensor[row,column] 例如
t[0,0], t[1,1], t[2,2], and t[7,3]。print(numpy_result) # tensor([ 1, 6, 11, 32])使用 2D 張量進行索引:以逐行方式遍歷索引張量,每個值都是目標張量中一行的索引。例如看下面的例子,索引后
[ [t[0],t[1],t[2],[7]] , [[0],[1],[2],[3]] ]的新形狀是.tensor_result(tensor_indices.shape[0],tensor_indices.shape[1],t.shape[1])=(2,4,4)print(tensor_result) # tensor([[[ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12], # [29, 30, 31, 32]], # [[ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12], # [ 13, 14, 15, 16]]])
如果您嘗試在 中添加第三行numpy_indices,您將得到相同的錯誤,因為索引將由 3D 表示,例如 (0,0,0)...(7,3,3)。
indices = np.array([[0, 1, 2, 7],
[0, 1, 2, 3],
[0, 1, 2, 3]])
print(numpy_result) # IndexError: too many indices for tensor of dimension 2
然而,這不是張量索引的情況,形狀會更大(3,4,4)。
最后,如您所見,兩種索引的輸出完全不同。為了解決您的問題,您可以使用
xx = torch.tensor(xx).long() # convert a numpy array to a tensor
在高級索引( numpy_indices > 3 的行)的情況下會發生什么,因為您的情況仍然模棱兩可且未解決,您可以檢查1、2、3。
uj5u.com熱心網友回復:
對于任何來尋找答案的人:它看起來像是 pyTorch 中的一個錯誤。
使用 numpy 陣列的索引沒有很好的定義,它只有在使用張量索引張量時才有效。因此,在我的示例代碼中,這完美無缺:
a = torch.rand(M, N)
m, n = a.shape
xx, yy = torch.meshgrid(torch.arange(m), torch.arange(m), indexing='xy')
result = a[xx] # WORKS
我做了一個要點來檢查它,它可以在這里找到
轉載請註明出處,本文鏈接:https://www.uj5u.com/caozuo/432275.html
上一篇:即使我清理了快取并重新打開,瀏覽器DevTools也不允許我在某些行上設定斷點
下一篇:Laravel動態驗證
