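Let me start by saying that I'm completely new to this area of programming, but I think scipy's optimization tools might be the solution.
I need to find the arguments that make a function return the highest result, but only if the result satisfies a condition.
The function is very long and takes more than 40 arguments, so brute-forcing them is impossible and far too slow. It returns two arrays of the same length.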
constant = [1,2,3,4,4,3,5,6,7,8]

def fun(constant, length, period, multiplier, factor, ... ):
    # do long and complicated calculations
    return array1, array2
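Now, what I need is to find the arguments that return the highest array1[-1] value if max(array2) < 40 (for example), and then print them.
All of the arguments (length, period, multiplier, factor) range from 2 to 200. constant, on the other hand, obviously should not be affected by the optimization.
I tried looping over the range of every argument and running the calculation one combination at a time, but that is very inefficient and complicated, and I don't think it gives the best result.
How can I perform this type of parameter optimization?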
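For a sense of scale, the arithmetic behind "brute force is impossible" can be checked directly. The sketch below assumes exactly 40 integer parameters, each in the range 2 to 200 inclusive, and a hypothetical rate of one million evaluations per second; both figures are illustrative assumptions, not measurements of the real function.

```python
# Size of the full grid, assuming 40 integer parameters in [2, 200].
values_per_param = 200 - 2 + 1   # 199 possible values per parameter
n_params = 40                    # the question says "more than 40"; 40 is a lower bound
grid_size = values_per_param ** n_params

# Hypothetical evaluation rate, chosen only for illustration.
evals_per_second = 1_000_000
seconds_per_year = 60 * 60 * 24 * 365
years_needed = grid_size // (evals_per_second * seconds_per_year)

print(f"grid size has {len(str(grid_size))} digits")
print(f"years needed has {len(str(years_needed))} digits")
```

Even under these generous assumptions an exhaustive loop cannot finish, which is why a sampling or optimization approach is needed.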
Answer from a uj5u.com user:
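If you want to build it from scratch, a simple, completely random, "good enough" solver could look something like this.
The solver is the first function; the rest are your (user) functions.
You need
- your long and complicated target function
- a function that returns a score for a given result (or zero if the result is invalid)
- a dictionary of functions that return a value for each parameter to try.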
import random
import time


def find_solution(
    target_function,
    score_solution,
    param_generators,
    max_iterations=10_000_000,
    max_time=60,
):
    best_solution = None
    best_score = 0
    start_time = time.time()
    for i in range(max_iterations):
        # Draw one random value per parameter and evaluate the target.
        params = {param: gen() for param, gen in param_generators.items()}
        solution = target_function(**params)
        score = score_solution(solution)
        # Keep the candidate only if it beats the best score so far.
        if score > best_score:
            best_score = score
            best_solution = (params, solution)
            print(f"{i} / New best solution: {best_solution}")
        # Stop once the wall-clock budget is used up.
        if time.time() - start_time > max_time:
            print(f"{i} / Time limit reached")
            break
    return (best_solution, best_score)


def fun(constant, length, period, multiplier, factor):
    a = constant * length * period * multiplier * factor
    b = length * period
    return (a, [b])


def sol_scorer(sol):
    # In this toy example a result whose maximum is below 40 counts as invalid.
    # The question states the opposite condition (valid only if max(array2) < 40);
    # flip the comparison to match it.
    if max(sol[1]) < 40:  # Invalid; return 0
        return 0
    return sol[0]


def main():
    constants = [1, 2, 3, 4, 4, 3, 5, 6, 7, 8]
    param_generators = {
        "constant": lambda: random.choice(constants),
        "length": lambda: random.randint(1, 100),
        "period": lambda: random.randint(1, 100),
        "multiplier": lambda: random.randint(1, 100),
        "factor": lambda: random.randint(1, 100),
    }
    res = find_solution(
        fun,
        sol_scorer,
        param_generators,
        max_iterations=10_000_000,
        max_time=10,
    )
    print(res)


if __name__ == "__main__":
    main()
On my machine, this prints, for example:
227838 / New best solution: ({'constant': 8, 'length': 98, 'period': 93, 'multiplier': 96, 'factor': 99}, (692955648, [9114]))
1085159 / New best solution: ({'constant': 8, 'length': 98, 'period': 99, 'multiplier': 91, 'factor': 100}, (706305600, [9702]))
1447216 / New best solution: ({'constant': 8, 'length': 99, 'period': 97, 'multiplier': 97, 'factor': 97}, (722837016, [9603]))
2325989 / Time limit reached
(({'constant': 8, 'length': 99, 'period': 97, 'multiplier': 97, 'factor': 97}, (722837016, [9603])), 722837016)
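The toy scorer above rejects results whose maximum is below 40, whereas the question asks for the opposite: maximize `array1[-1]` subject to `max(array2) < 40`, with every tunable argument between 2 and 200 and `constant` left untouched. A minimal sketch of that setup might look as follows. The names `toy_fun`, `question_scorer`, `make_generators` and the `limit` parameter are hypothetical, and `toy_fun` is only a stand-in for the real 40-argument function. It also assumes that valid results have a positive `array1[-1]`, because the solver starts from a best score of 0.

```python
import random
from functools import partial

CONSTANT = [1, 2, 3, 4, 4, 3, 5, 6, 7, 8]  # fixed input, never optimized


def toy_fun(constant, length, period, multiplier, factor):
    # Stand-in for the real function: returns two arrays of equal length.
    array1 = [c * length * multiplier for c in constant]
    array2 = [c * period / factor for c in constant]
    return array1, array2


def question_scorer(result, limit=40):
    # Valid only if max(array2) < limit; the score is the last value of array1.
    array1, array2 = result
    if max(array2) >= limit:
        return 0  # invalid, same "zero means invalid" convention as the solver
    return array1[-1]


def make_generators(names, low=2, high=200):
    # One random-integer generator per tunable parameter (bounds inclusive).
    return {name: partial(random.randint, low, high) for name in names}


# Bind the constant so that only the remaining arguments are searched.
target = partial(toy_fun, CONSTANT)
generators = make_generators(["length", "period", "multiplier", "factor"])

params = {name: gen() for name, gen in generators.items()}
print(params, question_scorer(target(**params)))
```

`target`, `question_scorer` and `generators` can then be passed to `find_solution` in place of `fun`, `sol_scorer` and `param_generators`.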
Using sequential combinations
Adding an option to always try certain combinations isn't much more code; see below.
import random
import time
from itertools import product
from typing import Any, Callable, Optional, Iterable


def find_solution(
    target_function: Callable[..., Any],
    score_solution: Callable[[Any], float],
    param_generators: dict[str, Callable[[], Any]],
    sequential_combination_generator: Optional[Iterable[dict]] = None,
    max_iterations: int = 10_000_000,
    max_time: float = 60.0,
) -> tuple[Any, float]:
    best_solution = None
    best_score = 0.0
    start_time = time.time()

    # With no sequential combinations, run a single purely random pass.
    if sequential_combination_generator is None:
        sequential_combination_generator = [{}]

    try:
        for sequential_combination in sequential_combination_generator:
            print(f"Trying {max_iterations} w/: {sequential_combination}")
            for i in range(max_iterations):
                # Merge the sequential params with the randomly generated params
                params = {
                    **sequential_combination,
                    **{param: gen() for param, gen in param_generators.items()},
                }
                solution = target_function(**params)
                score = score_solution(solution)
                if score > best_score:
                    best_score = score
                    best_solution = (params, solution)
                    print(f"Iteration {i}: New best solution: {best_solution}")
                # The time limit covers the whole search, so break out of both loops.
                if time.time() - start_time > max_time:
                    raise TimeoutError("Time limit reached")
    except TimeoutError as e:
        print(e)
    return (best_solution, best_score)


def generate_parameter_combinations(sequential_params: dict[str, list]) -> Iterable[dict]:
    # Break the sequential_params dict into keys and values
    keys, values = zip(*sequential_params.items())
    # Yield each combination as a dict
    for combination in product(*values):
        yield dict(zip(keys, combination))


def fun(constant, length, period, multiplier, factor):
    a = constant * length * period * multiplier * factor
    b = length * period
    return (a, [b])


def sol_scorer(sol):
    if max(sol[1]) < 40:  # Invalid; return 0
        return 0
    return sol[0]


def main():
    constants = [1, 2, 3, 4, 4, 3, 5, 6, 7, 8]
    # All of these combinations will be exhaustively tried
    sequential_params = generate_parameter_combinations(
        {
            "length": [10, 20, 30, 40],
            "period": [40, 30, 20, 10],
        }
    )
    # You can also pass in just a list of dicts, á la
    # sequential_params = [
    #     {"length": 10, "period": 40},
    #     {"length": 20, "period": 30},
    # ]
    # These will be randomly generated
    param_generators = {
        "constant": lambda: random.choice(constants),
        "multiplier": lambda: random.randint(1, 100),
        "factor": lambda: random.randint(1, 100),
    }
    res = find_solution(
        fun,
        sol_scorer,
        param_generators=param_generators,
        sequential_combination_generator=sequential_params,
        max_iterations=10_000,  # Limit for each sequential combination
        max_time=10,  # Total time limit
    )
    print(res)


if __name__ == "__main__":
    main()
This prints, for example:
Trying 10000 w/: {'length': 10, 'period': 40}
Iteration 0: New best solution: ({'length': 10, 'period': 40, 'constant': 2, 'multiplier': 95, 'factor': 64}, (4864000, [400]))
Iteration 1: New best solution: ({'length': 10, 'period': 40, 'constant': 7, 'multiplier': 73, 'factor': 93}, (19009200, [400]))
Iteration 71: New best solution: ({'length': 10, 'period': 40, 'constant': 6, 'multiplier': 96, 'factor': 93}, (21427200, [400]))
Iteration 248: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 80, 'factor': 89}, (22784000, [400]))
Iteration 595: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 79, 'factor': 97}, (24521600, [400]))
Iteration 679: New best solution: ({'length': 10, 'period': 40, 'constant': 7, 'multiplier': 96, 'factor': 99}, (26611200, [400]))
Iteration 722: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 98, 'factor': 93}, (29164800, [400]))
Iteration 6065: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 98, 'factor': 94}, (29478400, [400]))
Trying 10000 w/: {'length': 10, 'period': 30}
Trying 10000 w/: {'length': 20, 'period': 40}
Iteration 0: New best solution: ({'length': 20, 'period': 40, 'constant': 7, 'multiplier': 70, 'factor': 79}, (30968000, [800]))
Iteration 26: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 96, 'factor': 63}, (38707200, [800]))
Iteration 54: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 81, 'factor': 78}, (40435200, [800]))
Iteration 80: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 80, 'factor': 97}, (49664000, [800]))
...
Iteration 4500: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 94, 'factor': 93}, (111897600, [1600]))
Iteration 5638: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 96, 'factor': 97}, (119193600, [1600]))
Iteration 6006: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 99, 'factor': 99}, (125452800, [1600]))
Trying 10000 w/: {'length': 40, 'period': 30}
(({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 99, 'factor': 99}, (125452800, [1600])), 125452800)
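To see what the sequential part iterates over, the helper can be run on its own. The sketch below repeats `generate_parameter_combinations` from the answer (without type hints) so that it is self-contained, and feeds it a smaller, made-up input.

```python
from itertools import product


def generate_parameter_combinations(sequential_params):
    # Same helper as in the answer: Cartesian product of the value lists,
    # yielded one dict at a time.
    keys, values = zip(*sequential_params.items())
    for combination in product(*values):
        yield dict(zip(keys, combination))


combos = list(
    generate_parameter_combinations({"length": [10, 20], "period": [40, 30]})
)
for combo in combos:
    print(combo)
# {'length': 10, 'period': 40}
# {'length': 10, 'period': 30}
# {'length': 20, 'period': 40}
# {'length': 20, 'period': 30}
```

The last key varies fastest, and the number of combinations is the product of the list lengths. With the 4-by-4 grid in the answer and `max_iterations=10_000`, that is up to 160,000 evaluations unless the time limit is reached first.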
