參考:
pytorch 矩陣維度 - 搜索結果 - 知乎
Pytorch 中的 dim操作介紹 - 大資料 - 億速云
1.如何理解dim?
- pytorch的dim和numpy的axis很類似
- 不同dim的資料長什么樣?
維度為0, 0維張量也叫標量 1
維度為1, 0維張量也叫矢量 [1,2]
維度為2, 0維張量也叫矩陣 [[1,2],[3,4]]
維度為3, 0維張量也叫矩陣陣列 [[[1,2],[3,4]],[[1,2],[3,4]]]
二維矩陣a:
a = torch.tensor([[1, 2], [3, 4]])
print(a)
tensor([[1, 2],
[3, 4]])
解釋:

三維張量b:
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
tensor([[[3, 2],
[1, 4]],
[[5, 6],
[7, 8]]])
解釋:

2.在不同dim的計算
核心:在不同dim上的計算就是對這個dim中的元素的計算,以sum為例,計算b在不同維度的sum,
- dim=0
s = torch.sum(b, dim=0)
print(s)
tensor([[ 8, 8],
[ 8, 12]])
解釋:

- dim=1
s = torch.sum(b, dim=1)
print(s)
tensor([[ 4, 6],
[12, 14]])
解釋:

- dim=2
s = torch.sum(b, dim=2)
print(s)
tensor([[ 5, 5],
[11, 15]])
在 b 的第 2 維求和,就是對標量 3 和 2, 1 和 4, 5 和 6 , 7 和 8 求和
note:在進行計算時,結果的維度發生了變換,如果不想改變,需要keepdim=True
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/389820.html
標籤:區塊鏈
上一篇:機器學習相關 解答
下一篇:【pytorch安裝】
