我開始學習機器學習,張量。在我開始顯示圖形之前,一切都正常。當我使用二次方程時,我輸出了它,在使用對數后,我開始輸出錯誤。現在,我是python的新手,很高興得到任何幫助。
import torch
import numpy as np
import matplotlib.pyplot as plt
x = torch.tensor([[5., 10.],
[1., 2.]], requires_grad=True)
var_history = []
fn_history = []
alpha =0.001
優化器 = torch.optimation.SGD([x], lr=alpha)
def function_parabola(variable)。
return (variable 7).log() .log() .prod()
def make_gradient_step(function, variable) 。
function_result = function(variable)
function_result.backward()
優化器.步驟()
優化器.zero_grad()
for i in range(500)。
var_history.append(x.data.numpy().copy())
fn_history.append(function_parabola(x).data.cpu().numpy().copy())
make_gradient_step(function_parabola, x)
print(x)
def show_contours(objective,
x_lims=[-10.0, 10.0] 。
y_lims=[-10.0, 10.0]。
x_ticks=100,
y_ticks=100)。)
x_step = (x_lims[1] - x_lims[0] ) / x_ticks
y_step = (y_lims[1] - y_lims[0] ) / y_ticks
X, Y = np.mgrid[x_lims[0]:x_lims[1] :x_step, y_lims[0]:y_lims[1] :y_step]
res = []
for x_index in range(X.shape[0)。
res.append([])
for y_index in range(X.shape[1]) 。
x_val = X[x_index, y_index] 。
y_val = Y[x_index, y_index].
res[-1].append(objective(np.array([x_val, y_val]).T))
res = np.array(res)
plt.figure(figsize=(7,7)
plt.contour(X, Y, res, 100)
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
show_contours(function_parabola)
plt.scatter(np.array(var_history)[:,0], np.array(var_history)[:, 1], s=10, c='r') 。
plt.show()
Traceback (most recent call last):
檔案 "C:UsersKPPycharmProjectspythonProjectHomeWorkClassWork.py", 行 25, in < module>
fn_history.append(function_parabola(x).data.cpu().detach().numpy().copy()
檔案 "C:UsersKPPycharmProjectspythonProjectHomeWorkClassWork.py", line 13, in function_parabola
return np.prod(np.log(np.log(變數 7))
檔案 "C:UsersKPPycharmProjectspythonProjectvenvlibsite-packages orch\_tensor.py", line 643, in __array__
return self.numpy()
RuntimeError: Can not call numpy() on Tensor that requires grad. 使用tensor.detach().numpy()代替。
uj5u.com熱心網友回復:
嘗試將function_parabola改為:
def function_parabola(variable)。
return np.prod(np.log(np.log(variable 7) )
prod和log函式在numpy模塊中。
轉載請註明出處,本文鏈接:https://www.uj5u.com/caozuo/309448.html
標籤:
上一篇:多個后續連接的Spark性能問題
