Dataset: the breast cancer dataset (from sklearn.datasets import load_breast_cancer).
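(1) Split the samples into a 70% training set and a 30% test set. Build one model with logistic regression and one with KNN (for KNN the data must be standardized first), using default parameters. Print the confusion matrix of each model's test results and compute the accuracy, recall and false positive rate.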
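The three metrics in (1) all come from the four cells of the binary confusion matrix. The sketch below is a minimal illustration, assuming label 1 is the positive class as in the script below. The helper name `binary_rates` and the toy labels are hypothetical and exist only for illustration. `recall_score` is called at the end as a cross-check on the true positive rate.

```python
from sklearn.metrics import confusion_matrix, recall_score


def binary_rates(y_true, y_pred):
    """Accuracy, recall (true positive rate) and false positive rate; class 1 = positive."""
    # labels=[0, 1] fixes the matrix layout to [[TN, FP], [FN, TP]].
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    tpr = tp / (tp + fn)  # recall: share of actual positives that were detected
    fpr = fp / (fp + tn)  # share of actual negatives wrongly flagged as positive
    return accuracy, tpr, fpr


# Toy labels invented for illustration: TN=2, FP=1, FN=1, TP=3.
y_true = [0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 1, 0, 1]
acc, tpr, fpr = binary_rates(y_true, y_pred)
print(acc, tpr, fpr)
# recall_score computes the same true positive rate.
print(recall_score(y_true, y_pred))
```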
(2) Use grid search to determine the optimal parameters of the logistic regression model and of the KNN model.
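Hints on the main KNN parameters:
① n_neighbors (number of nearest neighbors)
Usually an odd number.
② algorithm (algorithm used to compute the nearest neighbors)
Possible values include 'auto', 'ball_tree', 'kd_tree', 'brute', etc.; the default is 'auto'. Note: the choice of algorithm does not change KNN's results (apart from how ties between equidistant neighbors are broken), only the model's performance, i.e. how fast it runs.
③ p (power parameter of the Minkowski distance)
The default is p=2, the Euclidean distance, while p=1 gives the Manhattan distance. To use a distance other than Minkowski, change the value of the metric parameter.
④ weights
The weight function used in prediction. Possible values: 'uniform': uniform weights, i.e. all points in each neighborhood are weighted equally; 'distance': points are weighted by the inverse of their distance, so closer neighbors of a query point have more influence than neighbors farther away.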
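The p and weights hints can be checked on a few hand-made points. This sketch uses invented toy points and a hypothetical helper `minkowski`. It computes the Minkowski distance for p=1 and p=2, then shows a case where weights='uniform' and weights='distance' predict different classes for the same query.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def minkowski(a, b, p):
    """Minkowski distance (sum |a_i - b_i|^p)^(1/p)."""
    return float(np.sum(np.abs(np.asarray(a) - np.asarray(b)) ** p) ** (1 / p))


# p=1 is the Manhattan distance, p=2 the Euclidean distance.
manhattan = minkowski([0, 0], [3, 4], p=1)  # |3| + |4| = 7
euclidean = minkowski([0, 0], [3, 4], p=2)  # sqrt(9 + 16) = 5

# Toy 1-D training set: one class-0 point very close to the query at x=0,
# two class-1 points farther away.
X_toy = np.array([[0.1], [1.0], [1.1]])
y_toy = np.array([0, 1, 1])
query = np.array([[0.0]])

# 'uniform': each of the 3 neighbors gets one vote, so class 1 wins 2 to 1.
uniform = KNeighborsClassifier(n_neighbors=3, weights="uniform").fit(X_toy, y_toy)
# 'distance': votes are weighted by 1/d, so class 0 gets 1/0.1 = 10 and
# class 1 gets 1/1.0 + 1/1.1 (about 1.91), and class 0 wins.
distance = KNeighborsClassifier(n_neighbors=3, weights="distance").fit(X_toy, y_toy)

print(manhattan, euclidean)  # 7.0 5.0
print(uniform.predict(query)[0], distance.predict(query)[0])  # 1 0
```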
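(3) Apply K-fold cross-validation to the whole dataset (k = 2, 3, 4, 5, 6, 7, 8, 9, 10), build logistic regression and KNN models with the optimal parameters found in the previous step, and plot the classification accuracy of the two models for each value of k for comparison.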
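The script below draws one figure per model. Step (3) asks for a comparison, so the sketch here puts both curves on one axes. It makes a few assumptions. Both models are wrapped in a StandardScaler pipeline, so the scaler is refit inside every fold; the original script fits logistic regression on raw features, and scaling mainly speeds up its convergence. The hyperparameters are placeholders; substitute the `best_params_` from your own grid search.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
folds = list(range(2, 11))  # k = 2, 3, ..., 10

# Placeholder hyperparameters: replace them with the best_params_ found in step (2).
models = {
    "Logistic regression": make_pipeline(
        StandardScaler(), LogisticRegression(max_iter=10000)
    ),
    "KNN": make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5)),
}

# Mean accuracy over the k folds, for every k and every model.
curves = {
    name: [cross_val_score(model, X, y, cv=k).mean() for k in folds]
    for name, model in models.items()
}

fig, ax = plt.subplots()
for (name, scores), style in zip(curves.items(), ["bs-", "ro-"]):
    ax.plot(folds, scores, style, label=name)
ax.set_xlabel("number of folds k")
ax.set_ylabel("mean cross-validation accuracy")
ax.set_title("k-fold cross-validation accuracy")
ax.legend()
```

Call `plt.show()` afterwards to display the figure (in a notebook it is shown automatically). For an integer `cv` and a classifier, `cross_val_score` uses stratified folds without shuffling, so the curves are reproducible.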

#!/usr/bin/env python
# coding: utf-8
from sklearn.datasets import load_breast_cancer
import numpy as np
from sklearn import model_selection
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
np.set_printoptions(suppress= True)
np.set_printoptions(precision=4)
# from pylab import mpl
# mpl.rcParams['font.sans-serif'] = ['SimHei']  # set a default font so that plots can display Chinese characters
# mpl.rcParams['axes.unicode_minus'] = False  # keep the minus sign '-' from being rendered as a box in saved figures
dataset = datasets.load_breast_cancer()
data = dataset.data
target = dataset.target
x_train, x_test, y_train, y_test = model_selection.train_test_split(data,target,
test_size=0.3,random_state=1)
model_logic = LogisticRegression(max_iter=10000).fit(x_train, y_train.ravel())
print(model_logic.score(x_test,y_test))
y_pred = model_logic.predict(x_test)
# Evaluate the logistic regression model on the test set
tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
print(tn, fp, fn, tp)
accuracy = (tn + tp) / (tn + tp + fn + fp)
tpr = tp / (tp + fn)
fpr = fp / (tn + fp)
print("Accuracy: {}%".format(accuracy * 100))
print("Recall (true positive rate): {}%".format(tpr * 100))
print("False positive rate: {}%".format(fpr * 100))
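# multi_class is left out of the grid: it has no effect on a binary target and is
# deprecated in recent scikit-learn releases
param = {'penalty': ['l2', 'l1'], 'C': [0.001, 0.01, 0.1, 1],
         'class_weight': ['balanced', None], 'solver': ['liblinear']}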
gc = GridSearchCV(model_logic, param_grid=param, cv=10)
gc.fit(x_train, y_train)
print("Accuracy (score) on the test set:", gc.score(x_test, y_test))
print("Best cross-validation score:", gc.best_score_)
print("Best parameter combination:", gc.best_params_)
k = [2, 3, 4, 5, 6, 7, 8, 9, 10]
scores = []
# Rebuild logistic regression with the best parameters found by the grid search above
model_logic = LogisticRegression(max_iter=10000, **gc.best_params_)
for n_folds in k:
    score = cross_val_score(model_logic, data, target, cv=n_folds)
    scores.append(score.mean())
print(scores)
plt.figure()
plt.title('k-fold cross-validation scores of logistic regression')
plt.plot(k, scores, 'bs-')
# Standardize the features: fit the scaler on the training set only,
# then apply the same transformation to the test set
standardizer = StandardScaler()
X_std = standardizer.fit_transform(x_train)
X_std_test = standardizer.transform(x_test)
knn = KNeighborsClassifier().fit(X_std, y_train)
print(knn.score(X_std_test,y_test))
y_pred = knn.predict(X_std_test)
print(y_pred)
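# Note: the positive class is label 1 (in this dataset 0 = malignant, 1 = benign)
# Evaluate the KNN model on the test set
tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
print(tn, fp, fn, tp)
accuracy = (tn + tp) / (tn + tp + fn + fp)
tpr = tp / (tp + fn)
fpr = fp / (tn + fp)
print("Accuracy: {}%".format(accuracy * 100))
print("Recall (true positive rate): {}%".format(tpr * 100))
print("False positive rate: {}%".format(fpr * 100))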
param = {'n_neighbors': [1, 3, 5],'algorithm': ['auto','ball_tree', 'kd_tree', 'brute'],'p':[1,2],'weights':['uniform','distance']}
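# KNN is distance-based, so search on the standardized training set
# and score on the standardized test set
gc = GridSearchCV(knn, param_grid=param, cv=5)
gc.fit(X_std, y_train)
print("Accuracy on the test set:", gc.score(X_std_test, y_test))
print("Best cross-validation score:", gc.best_score_)
print("Best parameter combination:", gc.best_params_)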
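k = [2, 3, 4, 5, 6, 7, 8, 9, 10]
scores = []
# Build KNN with the best parameters from the grid search; the pipeline refits
# the scaler on the training folds inside each cross-validation split
from sklearn.pipeline import make_pipeline
knn = make_pipeline(StandardScaler(), KNeighborsClassifier(**gc.best_params_))
for n_folds in k:
    score = cross_val_score(knn, data, target, cv=n_folds)
    scores.append(score.mean())
print(scores)
plt.figure()
plt.title('k-fold cross-validation scores of KNN')
plt.plot(k, scores, 'bs-')
plt.show()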
Note: if Chinese characters in the plots come out garbled, set a Chinese font:
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']
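SimHei is not installed on every system. A slightly more portable variant, as a sketch, lists only the CJK fonts Matplotlib can actually find. The candidate names other than SimHei are assumptions about commonly installed fonts. It also applies the `axes.unicode_minus` setting from the commented-out lines at the top of the script.

```python
import matplotlib as mpl
from matplotlib import font_manager

# Candidate CJK fonts (an assumption: which of them exist depends on the OS).
CANDIDATES = ["SimHei", "Microsoft YaHei", "Noto Sans CJK SC", "WenQuanYi Zen Hei"]

installed = {f.name for f in font_manager.fontManager.ttflist}
usable = [name for name in CANDIDATES if name in installed]

# Put any installed CJK font first; keep DejaVu Sans as a fallback for Latin text.
mpl.rcParams["font.sans-serif"] = usable + ["DejaVu Sans"]
# Draw the minus sign as ASCII '-' so it is not rendered as a box with CJK fonts.
mpl.rcParams["axes.unicode_minus"] = False
print("CJK fonts found:", usable or "none")
```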
For the KNN k-fold cross-validation and grid search, the data should also be standardized.
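One way to standardize inside the grid search without leaking information is to make the scaler a step of the estimator. The sketch below assumes a scikit-learn `Pipeline`; the step names `scaler` and `knn` and the grid values are chosen for illustration. `algorithm` is left out because, per the hint above, it only affects speed.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=1
)

# The scaler is a pipeline step, so GridSearchCV refits it on the training part
# of every fold and never sees the held-out fold or the test set.
pipe = Pipeline([("scaler", StandardScaler()), ("knn", KNeighborsClassifier())])

# Parameters of a pipeline step are addressed as "<step name>__<parameter>".
param_grid = {
    "knn__n_neighbors": [1, 3, 5, 7, 9],
    "knn__p": [1, 2],
    "knn__weights": ["uniform", "distance"],
}
search = GridSearchCV(pipe, param_grid=param_grid, cv=5)
search.fit(x_train, y_train)

print("best parameters:", search.best_params_)
print("best CV accuracy:", search.best_score_)
# The refitted pipeline standardizes x_test with the training-set statistics.
print("test accuracy:", search.score(x_test, y_test))
```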
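Because of differences between library versions, some parameters, in particular the algorithm-selection parameters, could not be set.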
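When a parameter is rejected, it helps to check what the installed version accepts. This sketch relies on `get_params()`, which returns the constructor parameters of an estimator. Using `multi_class` as the example is an assumption tied to its deprecation in recent releases.

```python
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

# Installed scikit-learn version, to compare against the documentation you read.
print("scikit-learn", sklearn.__version__)

# get_params() lists the constructor parameters this version accepts.
knn_params = sorted(KNeighborsClassifier().get_params())
lr_params = sorted(LogisticRegression().get_params())
print("KNN parameters:", knn_params)
print("LogisticRegression parameters:", lr_params)

# multi_class, for example, is deprecated in recent releases (and irrelevant for a
# binary target like this one), so check for it before putting it in a grid.
print("multi_class accepted:", "multi_class" in lr_params)
```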
Please credit the source when reposting. Link to this article: https://www.uj5u.com/qita/374538.html
