我已經撰寫了這個 python 函式,我相信它會移植到 numba。不幸的是,它沒有,我不確定我是否理解錯誤:
Invalid use of getiter with parameters (none).
它需要知道發電機的型別嗎?是因為它回傳可變長度的元組嗎?
from numba import njit
# @njit
def iterator(N, k):
r"""Numba implementation of an iterator over tuples of N integers,
such that sum(tuple) == k.
Args:
N (int): number of elements in the tuple
k (int): sum of the elements
Returns:
tuple(int): a tuple of N integers
"""
if N == 1:
yield (k,)
else:
for i in range(k 1):
for j in iterator(N-1, k-i):
yield (i,) j
編輯
感謝杰羅姆的提示。這是我最終寫的解決方案(我從左邊開始):
import numpy as np
from numba import njit
@njit
def next_lst(lst, i, reset=False):
r"""Computes the next list of indices given the current list
and the current index.
"""
if lst[i] == 0:
return next_lst(lst, i 1, reset=True)
else:
lst[i] -= 1
lst[i 1] = 1
if reset:
lst[0] = np.sum(lst[:i 1])
lst[1:i 1] = 0
i = 0
return lst, i
@njit
def generator(N, k):
r"""Goes through all the lists of indices recursively.
"""
lst = np.zeros(N, dtype=np.int64)
lst[0] = k
i = 0
yield lst
while lst[-1] < k:
lst, i = next_lst(lst, i)
yield lst
這給出了正確的結果,并且它是 jited !
for lst in generator(4,2):
print(lst)
[2 0 0 0]
[1 1 0 0]
[0 2 0 0]
[1 0 1 0]
[0 1 1 0]
[0 0 2 0]
[1 0 0 1]
[0 1 0 1]
[0 0 1 1]
[0 0 0 2]
uj5u.com熱心網友回復:
一個問題來自可變大小的元組輸出。實際上,元組就像 Numba 中不同型別的結構。它們與串列非常不同,而不是 Python(AFAIK,在 Python 中,元組大致就是不能變異的串列)。在 Numba 中,1 項和 2 項的元組是兩種不同的型別。它們不能統一為更通用的型別。問題是函式的回傳值必須是唯一型別。因此,Numba 拒絕在 nopython 模式下編譯該函式。在 Numba 中解決此問題的唯一方法是使用串列。
話雖如此,即使使用串列,也會報告錯誤。該檔案指出:
支持大多數遞回呼叫模式。唯一的限制是遞回被呼叫者必須有一個控制流路徑,該路徑在沒有遞回的情況下回傳。
我認為這個限制在這里沒有得到滿足,因為沒有回傳宣告。話雖如此,該函式應該隱式回傳一個生成器(其型別取決于......遞回函式本身)。還要注意,對生成器的支持是相當新的,遞回生成器沒有得到很好的支持似乎是合理的。我建議您在 Numba github 上打開一個問題,因為我不確定這是預期的行為。
請注意,在沒有遞回的情況下實作此功能可能更有效。順便說一句,如果從 Numba 函式而不是 CPython 呼叫此函式,它肯定會更快。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/455064.html
