How can I randomly create a 10000 x 1000 mask matrix in which every row contains 3 contiguous masked blocks of length 100? A naive approach looks like this:
import numpy as np

mask = np.ones((10000, 1000))
# Draw 3 random block start columns per row (blocks may overlap).
idx = np.random.choice(mask.shape[1] - 100, 3 * mask.shape[0]).reshape([mask.shape[0], 3])
for i, starts in enumerate(idx):
    for j in range(3):
        for k in range(100):
            mask[i][starts[j] + k] = 0
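However, this is extremely inefficient and takes a long time. What would be an efficient implementation? It would also be nice if the three blocks did not overlap.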
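uj5u.com user reply:
I got a fairly good performance gain (30-40x faster than the original).
I make sure the zero blocks do not overlap:
- Each sample (row) contains 700 ones. I split 700 into 4 random integers that sum to 700, which gives me the sizes of the runs of ones.
- I compute the start indices of the zero blocks from the sizes of the runs of ones.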
def faster_than_original():
    zeros_size = 100
    n_zeros = 3
    mask = np.ones((10000, 1000))
    # One random weight per run of ones (n_zeros + 1 runs per row).
    indices_weights = np.random.random((mask.shape[0], n_zeros + 1))
    number_of_ones = mask.shape[1] - zeros_size * n_zeros
    # Scale the first n_zeros weights into sizes of the runs of ones.
    ones_sizes = np.round(indices_weights[:, :n_zeros].T
                          * (number_of_ones / np.sum(indices_weights, axis=-1))).T.astype(np.int32)
    # Each later block starts after the previous zero block as well.
    ones_sizes[:, 1:] += zeros_size
    zeros_start_indices = np.cumsum(ones_sizes, axis=-1)
    for sample_idx in range(len(mask)):
        for zeros_idx in zeros_start_indices[sample_idx]:
            mask[sample_idx, zeros_idx: zeros_idx + zeros_size] = 0
    return mask
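The two remaining Python loops can also be removed with broadcasting and integer-array indexing. The following is only a sketch of that idea, not the answerer's method: the function name `vectorized_mask`, its parameters, and the way the starts are drawn (sorted offsets into the 700 free columns, then shifted by 100 per block) are assumptions made for illustration, and the resulting distribution of block positions differs from the one above. Blocks never overlap but may be directly adjacent.

```python
import numpy as np

def vectorized_mask(n_rows=10000, n_cols=1000, block=100, n_blocks=3, seed=None):
    """Hypothetical helper: mask with n_blocks disjoint zero blocks per row."""
    rng = np.random.default_rng(seed)
    free = n_cols - block * n_blocks  # columns that stay 1 (700 by default)
    # Sorted offsets into the free space, one per block and row.
    offsets = np.sort(rng.integers(0, free + 1, size=(n_rows, n_blocks)), axis=1)
    # Shift block j right by j * block so consecutive starts differ by >= block.
    starts = offsets + block * np.arange(n_blocks)
    # Column indices of every zero entry, shape (n_rows, n_blocks, block).
    cols = starts[:, :, None] + np.arange(block)
    rows = np.arange(n_rows)[:, None, None]
    mask = np.ones((n_rows, n_cols))
    mask[rows, cols] = 0
    return mask

mask = vectorized_mask(seed=0)
```

Because the blocks are disjoint by construction, every row ends up with exactly `block * n_blocks` zeros and no rejection loop is needed.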
Profiling:
Line #  Hits       Time    Per Hit  % Time  Line Contents
    42     1  8974014.0  8974014.0    76.2  mask = original()
    43     1   235235.0   235235.0     2.0  mask2 = faster_than_original()
    44     1  2565371.0  2565371.0    21.8  mask3 = shaido_method()
uj5u.com user reply:
Instead of using the 2 inner for loops, you can build a list of indices for each row and apply it to the mask directly. For example:
mask = np.ones((10000, 1000))
for i in range(len(mask)):
    start_indices = np.random.choice(900, 3)
    # All columns covered by the three blocks of this row.
    indices = [idx for start_idx in start_indices for idx in range(start_idx, start_idx + 100)]
    mask[i][indices] = 0
To make sure the blocks do not overlap, add that as a condition on the start indices, like this:
mask = np.ones((10000, 1000))
for i in range(len(mask)):
    cond = True
    # Redraw until consecutive blocks neither overlap nor touch.
    while cond:
        start_indices = sorted(np.random.choice(900, 3))
        cond = any(idx1 + 100 >= idx2 for idx1, idx2 in zip(start_indices, start_indices[1:]))
    indices = [idx for start_idx in start_indices for idx in range(start_idx, start_idx + 100)]
    mask[i][indices] = 0
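One simple way to confirm that a mask built from three blocks of 100 has no overlaps is to count the zeros per row: any overlap leaves fewer than 300. The helper below is a small sketch; the name `rows_have_disjoint_blocks` and its parameters are hypothetical and not part of either answer.

```python
import numpy as np

def rows_have_disjoint_blocks(mask, block=100, n_blocks=3):
    """Hypothetical check: True if every row has exactly block * n_blocks zeros."""
    zeros_per_row = (mask == 0).sum(axis=1)
    return bool(np.all(zeros_per_row == block * n_blocks))

# A row with three separate blocks passes the check.
good = np.ones((1, 1000))
for start in (0, 200, 500):
    good[0, start:start + 100] = 0

# A row where two blocks overlap by 50 columns fails it.
bad = np.ones((1, 1000))
for start in (0, 50, 500):
    bad[0, start:start + 100] = 0

print(rows_have_disjoint_blocks(good))  # True
print(rows_have_disjoint_blocks(bad))   # False
```

This check assumes the mask was produced by zeroing exactly `n_blocks` blocks of length `block` per row; it does not verify where the blocks are.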
Timings:
# original
3.42 s ± 153 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# overlaps allowed
1.41 s ± 108 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# no overlaps
2.25 s ± 199 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
