My array has length 520, but metrics.roc_curve only returns a few fpr, tpr and threshold values.
Here are some of the values in my score array:
[... 4.6719894 5.3444934 2.575739 3.5660675 3.4357991 4.195427
4.120169 5.021058 5.308503 5.3124313 4.8253884 4.7469654
5.0011086 5.170149 4.5555115 4.4109273 4.6183085 4.356304
4.413242 4.1186514 5.0573816 4.646429 5.063631 4.363433
5.431669 6.1605806 6.1510544 4.8733225 6.0209446 6.5198536
5.1457767 1.3887328 1.3165888 1.143339 1.717379 1.6670974
1.1816382 1.2497046 1.035109 1.4904765 1.195155 1.2590547
1.0998954 1.6484532 1.5722921 1.2841778 1.1058662 1.3368237
1.3262213 1.215088 1.4224783 1.046008 1.262415 1.2319984
1.2202312 1.1610713 1.2327379 1.1951761 1.8699458 0.98760885
1.6670336 1.5051543 1.2339936 1.5215651 1.534271 1.1805111
1.1587876 1.0894692 1.1936147 1.3278677 1.2409594 1.0499009... ]
These are the only results I get:
fpr [0. 0. 0. 0.00204499 0.00204499 1. ]
tpr [0. 0.03225806 0.96774194 0.96774194 1. 1. ]
threshold [7.5198536 6.5198536 3.4357991 2.5991373 2.575739 0.8769072]
Why is this?
Answer from a uj5u.com user:
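This is probably due to the default value of the drop_intermediate parameter of roc_curve(), which is True. With that setting, suboptimal thresholds (those that would not show up as corners on a plotted ROC curve) are dropped; see the roc_curve documentation. You can prevent this behaviour by passing drop_intermediate=False.
Here is an example: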
import numpy as np

try:
    from sklearn.datasets import fetch_openml
    # as_frame=False returns NumPy arrays, so the integer-array indexing below works
    mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
    mnist["target"] = mnist["target"].astype(np.int8)
except ImportError:
    # fallback for old scikit-learn releases that still provide fetch_mldata
    from sklearn.datasets import fetch_mldata
    mnist = fetch_mldata('MNIST original')

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict

X, y = mnist["data"], mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

# shuffle the training set
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

# binary task: "is this digit a 5?"
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

sgd_clf = SGDClassifier(random_state=42, verbose=0)
sgd_clf.fit(X_train, y_train_5)
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')

# ROC curves
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
print(len(thresholds), len(fpr), len(tpr))
# (3472, 3472, 3472)

# Unlike precision_recall_curve, roc_curve by default drops suboptimal
# thresholds (drop_intermediate=True), so its outputs are much shorter
# than the number of distinct scores.
fpr_, tpr_, thrs = roc_curve(y_train_5, y_scores, drop_intermediate=False)
print(len(fpr_), len(tpr_), len(thrs))
# (60001, 60001, 60001)
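To see what "dropping suboptimal thresholds" means without downloading MNIST, here is a minimal pure-Python sketch. It is not scikit-learn's code; the helper names roc_points and drop_intermediate_points and the toy data are made up for illustration, and the second-difference test is modelled on what roc_curve appears to do internally. It also omits the extra starting point (0, 0) that roc_curve prepends.

```python
def roc_points(y_true, scores):
    """Cumulative false/true positive counts at each distinct score, highest first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    fps, tps, thresholds = [], [], []
    fp = tp = 0
    for pos, i in enumerate(order):
        if y_true[i]:
            tp += 1
        else:
            fp += 1
        # Emit one point per distinct score (at the end of a run of ties).
        last_of_run = pos == len(order) - 1 or scores[order[pos + 1]] != scores[i]
        if last_of_run:
            fps.append(fp)
            tps.append(tp)
            thresholds.append(scores[i])
    return fps, tps, thresholds


def drop_intermediate_points(fps, tps, thresholds):
    """Keep both end points plus every point that is not the midpoint of its neighbours."""
    n = len(fps)
    keep = [0]
    for k in range(1, n - 1):
        # A zero second difference in both counts means point k is the exact
        # midpoint of points k-1 and k+1, so it lies on a straight segment.
        bend_fp = fps[k + 1] - 2 * fps[k] + fps[k - 1]
        bend_tp = tps[k + 1] - 2 * tps[k] + tps[k - 1]
        if bend_fp != 0 or bend_tp != 0:
            keep.append(k)
    if n > 1:
        keep.append(n - 1)
    return ([fps[k] for k in keep],
            [tps[k] for k in keep],
            [thresholds[k] for k in keep])


# Toy data (hypothetical): 4 positives and 4 negatives, mostly well separated.
y_true = [1, 1, 1, 0, 1, 0, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]

full = roc_points(y_true, scores)
reduced = drop_intermediate_points(*full)

print(len(full[2]), "thresholds before,", len(reduced[2]), "after")
print("kept thresholds:", reduced[2])
```

On this toy input, 8 candidate thresholds shrink to 5. The dropped points sit on straight segments of the curve, so the plotted ROC curve keeps the same shape. The better the scores separate the two classes, as they seem to in the question, the longer those straight segments are and the fewer thresholds survive.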
