我有一個包含 10 個類的分割圖(一個大小為 (m,n,1) 的 numpy 陣列,其中每個元素都是 1~10 之間的一個數字,指定像素所屬的一個類)。我想將它轉換為一個大小為 (m,n,10) 的陣列,其中每個通道都是該特定類元素的掩碼。我可以使用這樣的 for 回圈來做到這一點:
for i in range(10):
mask[:,:,i] = (seg_map==i)[:,:,0]
但我需要一種更快的方法來做到這一點。for 回圈花費太多時間。是否有任何內置函式可以勝過 for 回圈。
提前致謝。
uj5u.com熱心網友回復:
一種方法:
import numpy as np
np.random.seed(42)
# toy data
data = np.random.randint(0, 10, 20).reshape((5, 4, 1))
# https://stackoverflow.com/a/37323404/4001592
n_values = 10
values = data.flatten()
encoded = np.eye(n_values)[data.ravel()].reshape((5, 4, 10))
match = np.allclose(data.reshape(5, 4), encoded.argmax(-1))
print(match)
驗證輸出是否正確的一種方法是驗證 one-hot 編碼值是否與索引匹配,如下所示:
match = np.allclose(data.reshape(5, 4), encoded.argmax(-1))
print(match)
輸出
True
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/342688.html
上一篇:這兩個陣列有什么區別?
下一篇:以下兩種字串情況有什么區別
