眾所周知,線性回歸(Linear Regression)是最常見的機器學習演算法之一,簡單但超級實用,線性回歸旨在用線性方程來擬合資料分布,在資料量小計算速度要求高的地方是神經網路的最佳替代品,
LR的一般表現形式為:
y
=
w
?
T
x
?
+
b
y = \vec{w}^T\vec{x} + b
y=w
Tx
+b
通常,LR優化方式可以通過構建均方誤差損失函式,得到一個凸函式,計算導數為0的位置來確定
w
?
\vec{w}
w
和
b
b
b,如周志華老師西瓜書里描述的那樣,
在工程上,我們可以把LR當做一個簡易的神經網路來對待,用梯度下降演算法就可以優化,本文提供一個梯度下降演算法優化LR的實驗例子,有助于加深大家對LR以及梯度下降的理解,
實驗意圖:
假設有一個絕對正確的函式
y
=
f
(
x
?
;
w
?
,
b
)
=
w
?
T
x
?
+
b
y=f(\vec{x};\vec{w},b)=\vec{w}^T\vec{x}+b
y=f(x
;w
,b)=w
Tx
+b,每輸入一個
x
?
\vec{x}
x
都可以得到一個準確的
y
y
y,那么,咱們只需要得到最真實的
w
?
\vec{w}
w
和
b
b
b即可,假設最真實的
w
?
=
[
3
,
1
,
4
,
1
,
5
,
9
,
2
,
6
]
\vec{w}=[3,1,4,1,5,9,2,6]
w
=[3,1,4,1,5,9,2,6],最真實的
b
=
3.7
b=3.7
b=3.7,
初始化
w
?
\vec{w}
w
和
b
b
b為亂數,通過大量樣本的梯度反傳來修正
w
?
\vec{w}
w
和
b
b
b到真實的值,

實驗環境:
python3.7
numpy >=1.15.1
先申明,以下代碼為本人原創,借用最好在評論中告知我,
########################################################
# @author: MuZhan
# @contact: levio.pku@gmail.com
# experiment: using GD to optimize Linear Regression
# To fit `y=w*x+b`, where x and w are multi-dim vectors.
########################################################
import numpy as np
# initial setting
np.random.seed(10)
epochs = 30
lr = .1 # learning rate
w_ = np.array([3, 1, 4, 1, 5, 9, 2, 6]) # the ground truth w
b_ = 3.7 # the ground truth b
SAMPLE_NUM = 100
x_dim = len(w_)
# preparing random (x, y) pairs
print('preparing data...')
x_list = []
y_list = []
for i in range(SAMPLE_NUM):
x = np.random.rand(x_dim)
y = w_.dot(x) + b_
x_list.append(x)
y_list.append(y)
# init w
np.random.seed(10)
w = np.random.rand(x_dim)
# init b
b = 1
# training
print('training...')
for e in range(epochs):
print('epoch: ', e, end='\t')
sum_loss = 0
for i in range(len(x_list)):
x = x_list[i]
y_ = y_list[i]
y = w.dot(x) + b
loss = (y - y_) ** 2
sum_loss += loss
# use Gradient Descent to update parameters
w = w - 2 * lr * (y - y_) * x
b = b - 2 * lr * (y - y_)
print('loss: ', sum_loss)
print('Ground Truth w: ', w_, end='\t')
print('learned w: ', w)
print('Ground Trueh b: ', b_, end='\t')
print('learned b: ', b)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/353245.html
標籤:AI
