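Note: this exercise comes from Andrew Ng's Machine Learning course.
The problem, translated:
You are the owner of a restaurant and want to open branches in other cities, so you have collected some data (listed at the bottom of this post). The data pair each city's population with the profit earned there: the first column is the city's population, and the second column is the profit from running a restaurant in that city.
Now, assuming θ0 and θ1 are both 0, compute the CostFunction, that is, the cost (loss) function.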
First, the linear regression formula for this problem is:
H(θ) = θ0 + θ1*X
Put simply, θ0 and θ1 are both 0 here, so we are computing the cost of the hypothesis H(θ) = 0.
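To make the hypothesis concrete, here is a minimal sketch of H(θ) = θ0 + θ1*X in Python. The function name `hypothesis` and the variable names are illustrative choices, not part of the exercise; the three population values are the first three rows of the data set at the end of the post.

```python
import numpy as np

def hypothesis(x, theta0, theta1):
    """Return H(θ) = θ0 + θ1 * x for every entry of x."""
    return theta0 + theta1 * np.asarray(x, dtype=float)

# First three population values from the data set (illustrative subset)
populations = [6.1101, 5.5277, 8.5186]

# With both parameters at 0, every prediction is 0 regardless of the input
predictions = hypothesis(populations, theta0=0.0, theta1=0.0)
print(predictions)
```

With θ0 = θ1 = 0 the model ignores the population entirely, so the cost computed later measures only how far the actual profits are from zero.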
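Next, the definition of loss:
Informally, the loss is the difference between your predicted value and the given (actual) value.
This leads to the definition of the cost function J(θ), half the mean of the squared differences:
J(θ) = 1/(2m) * Σ_{i=1..m} (H(θ)(x_i) - y_i)^2
Here m is the total number of data records (that is, how many rows of data there are), x_i is the population in row i, and y_i is the profit in row i.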
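As a sanity check on the definition, the cost can be written as an explicit loop, assuming the usual squared-error form for this exercise: sum the squared differences between prediction and actual value, then divide by 2m. The function name `cost_by_loop` and the toy data below are made up for illustration and are not taken from the data set.

```python
def cost_by_loop(xs, ys, theta0, theta1):
    """Squared-error cost: sum of (prediction - actual)^2 divided by 2m."""
    m = len(xs)
    total = 0.0
    for x, y in zip(xs, ys):
        prediction = theta0 + theta1 * x
        total += (prediction - y) ** 2
    return total / (2 * m)

# Toy data lying exactly on the line y = 2x (hypothetical values)
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]

cost_zero = cost_by_loop(xs, ys, 0.0, 0.0)     # (4 + 16 + 36) / (2 * 3)
cost_perfect = cost_by_loop(xs, ys, 0.0, 2.0)  # perfect fit, no error
print(cost_zero, cost_perfect)
```

The loop is slower than a vectorized version but makes each term of the sum visible, which is handy when checking a matrix implementation against it.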
Step 1: import the packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Step 2: read in the data, then plot it to take a look:
path = 'ex1data1.txt'
data = pd.read_csv(path, header=None, names=['Population', 'Profit'])
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12, 8))
plt.show()
The resulting plot is a scatter of Profit (y-axis) against Population (x-axis).

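Step 3: define the cost function
def computeCost(X, y, theta):
    # Squared difference between each prediction X * theta.T and the actual y
    inner = np.power(((X * theta.T) - y), 2)
    # Sum the squared errors and divide by 2m
    return np.sum(inner) / (2 * len(X))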
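The function in this post relies on matrix objects, where `*` means matrix multiplication. NumPy's documentation no longer recommends the matrix class, so the following is a sketch of the same calculation with plain arrays and the `@` operator. The name `compute_cost_ndarray` and the toy data are assumptions made for illustration.

```python
import numpy as np

def compute_cost_ndarray(X, y, theta):
    """Squared-error cost for X of shape (m, n), y of shape (m,), theta of shape (n,)."""
    residuals = X @ theta - y          # prediction minus actual value, per row
    return np.sum(residuals ** 2) / (2 * len(X))

# Toy design matrix: a column of ones plus one feature (hypothetical values)
X_toy = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y_toy = np.array([2.0, 4.0, 6.0])

cost_toy_zero = compute_cost_ndarray(X_toy, y_toy, np.zeros(2))
cost_toy_fit = compute_cost_ndarray(X_toy, y_toy, np.array([0.0, 2.0]))
print(cost_toy_zero, cost_toy_fit)
```

Here theta is a 1-D vector, so no transpose is needed; the matrix version in the post keeps θ as a 1×2 row and transposes it instead.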
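Step 4: add a column of ones to the left of the data, then split X and Y out of data.
After the split, X has 97 rows and 2 columns, Y has 97 rows and 1 column, and θ has 1 row and 2 columns.
The cost function multiplies matrix X by the transpose of matrix θ, compares the result with the actual Y values, and computes the cost.
data.insert(0, 'Ones', 1)  # column of ones for the intercept θ0
rows = data.shape[0]
cols = data.shape[1]
X = data.iloc[:, 0:cols - 1]  # columns Ones and Population
Y = data.iloc[:, cols - 1:cols]  # column Profit
theta = np.asmatrix('0,0')  # originally np.mat; that alias was removed in NumPy 2.0
X = np.asmatrix(X.values)
Y = np.asmatrix(Y.values)
cost = computeCost(X, Y, theta)
print(cost)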
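The data preparation can be tried without the file on disk. The sketch below feeds the first five rows of the data set through `io.StringIO` and uses plain arrays instead of matrices, so the shapes are (5, 2), (5,) and (2,) rather than the 97×2, 97×1 and 1×2 described in the post. The names `sample`, `df`, `X_arr`, `y_arr` and `theta_arr` are illustrative.

```python
import io
import numpy as np
import pandas as pd

# First five rows of ex1data1.txt, inlined so the example is self-contained
sample = io.StringIO(
    "6.1101,17.592\n5.5277,9.1302\n8.5186,13.662\n7.0032,11.854\n5.8598,6.8233\n"
)
df = pd.read_csv(sample, header=None, names=['Population', 'Profit'])
df.insert(0, 'Ones', 1)  # intercept column, so θ0 is multiplied by 1

X_arr = df[['Ones', 'Population']].to_numpy(dtype=float)
y_arr = df['Profit'].to_numpy(dtype=float)
theta_arr = np.zeros(2)
print(X_arr.shape, y_arr.shape, theta_arr.shape)

# Cost of the all-zero hypothesis on this five-row subset only
cost_sample = np.sum((X_arr @ theta_arr - y_arr) ** 2) / (2 * len(X_arr))
print(cost_sample)
```

The value printed for five rows will differ from the full-data answer, since the cost is an average over whichever rows are included.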
Reference answer:
32.072733877455676
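One way to cross-check this number: when θ0 and θ1 are both 0, every prediction is 0, so the cost reduces to the sum of the squared profits divided by 2m. The helper below is a sketch of that shortcut; `cost_at_zero` is a hypothetical name, and the toy values are not from the data set.

```python
import numpy as np

def cost_at_zero(profits):
    """Cost of the all-zero hypothesis: sum(y^2) / (2m)."""
    profits = np.asarray(profits, dtype=float)
    return np.sum(profits ** 2) / (2 * profits.size)

# Toy check with hypothetical values: (4 + 16 + 36) / 6
toy_cost = cost_at_zero([2.0, 4.0, 6.0])
print(toy_cost)

# With the real file available, the second column could be passed in directly:
# profits = np.loadtxt('ex1data1.txt', delimiter=',')[:, 1]
# print(cost_at_zero(profits))
```

If the file-based call agrees with the reference answer, the matrix implementation and the data loading are both behaving as expected.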
Appendix: the data set ex1data1.txt

6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705
Please credit the source when reposting. Link to this article: https://www.uj5u.com/qita/73672.html