For learning purposes, I'm writing a custom classifier that follows the scikit-learn estimator interface. This is the code I came up with:
import numpy as np
from sklearn.utils.estimator_checks import check_estimator
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, check_random_state

class TemplateEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, threshold=0.5, random_state=None):
        self.threshold = threshold
        self.random_state = random_state

    def fit(self, X, y):
        self.random_state_ = check_random_state(self.random_state)
        X, y = check_X_y(X, y)
        self.classes_ = np.unique(y)
        self.fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # Random guess: one label drawn from the classes seen in fit, per row
        y_hat = self.random_state_.choice(self.classes_, size=X.shape[0])
        return y_hat

check_estimator(TemplateEstimator())
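This classifier just makes random guesses. I did my best to follow the scikit-learn documentation and guidelines for developing custom estimators, but I get the following error: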
AssertionError:
Arrays are not equal
Classifier cant predict when only one class is present.
Mismatched elements: 10 / 10 (100%)
Max absolute difference: 1.
Max relative difference: 1.
x: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
y: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
I'm not sure, but I suspect the randomness (i.e. self.random_state_) is causing the error. I'm using scikit-learn version 1.0.2.
Answer:
The first thing to note is that you get much more informative output if you use parametrize_with_checks with pytest instead of check_estimator. It looks like this:
from sklearn.utils.estimator_checks import parametrize_with_checks

@parametrize_with_checks([TemplateEstimator()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
If you run this with pytest, the output contains the following failed tests:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_pipeline_consistency] - AssertionError:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train(readonly_memmap=True)] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train(readonly_memmap=True,X_dtype=float32)] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_regression_target] - AssertionError: Did not raise: [<class 'ValueErr...
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_methods_sample_order_invariance] - AssertionError:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_methods_subset_invariance] - AssertionError:
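Some of these tests check the consistency of the outputs, which doesn't apply in your case because you return random values. For that, you need to set the non_deterministic estimator tag. Other tests, such as check_classifiers_regression_target, check that you perform proper validation and raise the right errors, which your estimator doesn't do. So you either need to fix that or add the no_validation tag. Another problem is that check_classifiers_train checks whether your model gives reasonable outputs for a given problem, but since you return random values, those conditions aren't met. You can set the poor_score estimator tag to skip that.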
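To see why the output-consistency checks can't pass, here is a numpy-only sketch of what predict does. The classes_ and X values are made up for illustration, and random_predict is a hypothetical helper that mimics TemplateEstimator.predict. A fresh generator with the same seed reproduces the same labels, but a single shared generator (like self.random_state_) is consumed by every call, so repeated calls draw from a different part of the random stream.

```python
import numpy as np

# Made-up stand-ins for the fitted attributes of TemplateEstimator.
classes_ = np.array([0, 1])
X = np.zeros((8, 3))


def random_predict(rng, X):
    """Hypothetical helper mimicking TemplateEstimator.predict: one random label per row."""
    return rng.choice(classes_, size=X.shape[0])


# Two fresh generators with the same seed produce identical labels.
a = random_predict(np.random.RandomState(0), X)
b = random_predict(np.random.RandomState(0), X)
print("fresh seeded generators agree:", np.array_equal(a, b))

# A single shared generator is consumed by every call, so its internal
# state (key array and position) differs before and after predict.
rng = np.random.RandomState(0)
_, keys_before, pos_before, _, _ = rng.get_state()
first = random_predict(rng, X)
_, keys_after, pos_after, _, _ = rng.get_state()
state_advanced = pos_before != pos_after or not np.array_equal(keys_before, keys_after)
print("generator state advanced:", state_advanced)

# The next call uses fresh random numbers; whether its labels happen to
# match the first call is left to chance.
second = random_predict(rng, X)
print("second call matches first:", np.array_equal(first, second))
```

Any check that compares two predict calls on the same (or reordered, or subset) data therefore depends on luck, which is what the non_deterministic tag is meant to signal.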
You can add these tags by adding the following to your estimator:
class TemplateEstimator(BaseEstimator, ClassifierMixin):
    ...

    def _more_tags(self):
        return {
            "non_deterministic": True,
            "no_validation": True,
            "poor_score": True,
        }
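The other option mentioned above, fixing the validation instead of setting no_validation, could look roughly like the following sketch. ValidatingRandomClassifier is a hypothetical name, and the assumption is that calling sklearn.utils.multiclass.check_classification_targets in fit is enough to raise the "Unknown label type" ValueError that check_classifiers_regression_target looks for. The mixin is listed before BaseEstimator, as the scikit-learn developer guide recommends.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import (
    check_array,
    check_is_fitted,
    check_random_state,
    check_X_y,
)


class ValidatingRandomClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical variant of TemplateEstimator that rejects regression targets."""

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # Raises ValueError("Unknown label type: ...") for continuous targets.
        check_classification_targets(y)
        self.random_state_ = check_random_state(self.random_state)
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # Still a random guess among the classes seen in fit.
        return self.random_state_.choice(self.classes_, size=X.shape[0])
```

With a fit like this, the no_validation tag should no longer be needed, while non_deterministic and poor_score presumably still are, because the predictions remain random.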
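But even then, two tests still fail if you use the main branch of scikit-learn or a nightly build. I believe this needs to be fixed, and I've opened an issue for it. In the meantime, you can avoid these failures by marking those tests as expected failures in the tags. In the end, your estimator would look like this: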
import numpy as np
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, check_random_state

class TemplateEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, threshold=0.5, random_state=None):
        self.threshold = threshold
        self.random_state = random_state

    def fit(self, X, y):
        self.random_state_ = check_random_state(self.random_state)
        X, y = check_X_y(X, y)
        self.classes_ = np.unique(y)
        self.fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # Random guess: one label drawn from the classes seen in fit, per row
        y_hat = self.random_state_.choice(self.classes_, size=X.shape[0])
        return y_hat

    def _more_tags(self):
        return {
            "non_deterministic": True,  # outputs are random, so skip consistency checks
            "no_validation": True,      # skip checks that expect validation errors
            "poor_score": True,         # random guesses can't meet the accuracy check
            # Checks that still fail on main/nightly, marked as expected failures
            "_xfail_checks": {
                "check_methods_sample_order_invariance": "This test shouldn't be running at all!",
                "check_methods_subset_invariance": "This test shouldn't be running at all!",
            },
        }

@parametrize_with_checks([TemplateEstimator()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)