Tuning Algorithm Parameters with GridSearchCV

Date: 2021-11-17 13:00:24

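Many machine learning algorithms have parameters that must be set by hand. For example, linear_model.LogisticRegression() has the parameter C, the inverse of the regularization strength.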
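To see what C controls, here is a minimal sketch. It assumes a synthetic dataset from make_classification in place of real data, and the C values and max_iter=1000 are arbitrary choices for illustration. Because C is the inverse of the regularization strength, a smaller C should give a stronger L2 penalty and therefore smaller fitted coefficients.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic binary classification data (an assumption, not the author's dataset)
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# Smaller C -> stronger L2 penalty -> smaller coefficient vector
norms = {}
for C in [0.01, 1.0, 100.0]:
    clf = LogisticRegression(C=C, max_iter=1000).fit(X, y)
    norms[C] = np.linalg.norm(clf.coef_)
    print(f"C={C:>6}: ||coef|| = {norms[C]:.3f}")

# Default value of C when it is not set explicitly
default_C = LogisticRegression().get_params()["C"]
print(default_C)  # 1.0
```

Because no single value of C is best for every dataset, it is worth searching over several candidates.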

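GridSearchCV in sklearn makes this tuning easy: it fits the model for every candidate parameter value, scores each candidate by cross-validation, and keeps the best one. For example: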

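import numpy as np
from sklearn import linear_model
# sklearn.grid_search was removed in scikit-learn 0.20; use sklearn.model_selection
from sklearn.model_selection import GridSearchCV

# Read my data (load_data is my own helper, not part of sklearn, and not shown here)
X_train, X_test, y_train, y_test = load_data(file='total_data.csv', X_start=2, X_end=37, y_position=64, classification=False)

# Search space for parameter C: 100 values evenly spaced on a log scale from 0.1 to 10
Cs = np.logspace(-1, 1, num=100)
model = linear_model.LogisticRegression()
grid = GridSearchCV(estimator=model, param_grid=dict(C=Cs))
grid.fit(X_train, y_train)
print(grid)
print(grid.best_score_)
print(grid.best_estimator_)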

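Output (from an older scikit-learn release; newer versions print only the non-default parameters, so the repr is shorter):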

GridSearchCV(cv=None, error_score='raise',
estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
verbose=0, warm_start=False),
fit_params={}, iid=True, n_jobs=1,
param_grid={'C': array([ 0.1 , 0.10476, ..., 9.54548, 10. ])},
pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
0.694444444444
LogisticRegression(C=0.23101297000831597, class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
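Since load_data and total_data.csv are not available, here is a self-contained sketch of the same search that can be run directly. It assumes synthetic data from make_classification (35 features, an arbitrary choice), an explicit cv=5, and max_iter=1000 to avoid convergence warnings; none of these come from the original run, so the numbers will differ from the output above. It also shows how best_score_ relates to cv_results_ and uses the otherwise unused test split.

```python
import numpy as np
from sklearn import linear_model
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic data in place of load_data()/total_data.csv (assumption)
X, y = make_classification(n_samples=300, n_features=35, n_informative=10,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=0)

# Same search space: 100 values evenly spaced on a log scale from 0.1 to 10
Cs = np.logspace(-1, 1, num=100)
model = linear_model.LogisticRegression(max_iter=1000)

# cv=5: each candidate C is scored by 5-fold cross-validation on X_train;
# refit=True (the default) then retrains the best C on all of X_train
grid = GridSearchCV(estimator=model, param_grid={"C": Cs}, cv=5)
grid.fit(X_train, y_train)

# best_score_ is the mean cross-validated accuracy of the best candidate,
# i.e. cv_results_["mean_test_score"] at index best_index_
best = grid.best_index_
print(grid.best_params_)
print(grid.best_score_, grid.cv_results_["mean_test_score"][best])

# X_test is never seen during the search, so it gives an independent estimate;
# grid.score evaluates the refitted best_estimator_
test_acc = grid.score(X_test, y_test)
print(f"test accuracy: {test_acc:.3f}")
```

Note that best_score_ is a cross-validation score on the training data, so the held-out test accuracy is the fairer number to report.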

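Note that in grid = GridSearchCV(estimator=model, param_grid=dict(C=Cs)), the key C must be the name of a parameter of LogisticRegression (the valid names are listed by model.get_params().keys()); otherwise an error is raised.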
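A minimal sketch of this check, assuming a small synthetic dataset. The misspelled key "c" is a deliberately hypothetical mistake; GridSearchCV passes each key to the estimator's set_params, which rejects unknown names with a ValueError during fit.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=100, n_features=5, random_state=0)
model = LogisticRegression()

# Keys in param_grid must be among the names returned by get_params()
valid_names = sorted(model.get_params().keys())
print("C" in valid_names)  # True

# Hypothetical mistake: lowercase 'c' is not a LogisticRegression parameter
try:
    GridSearchCV(model, param_grid={"c": [0.1, 1.0]}, cv=3).fit(X, y)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", str(err)[:60])
```

Checking the key against get_params() before starting a long search avoids discovering the typo only after fit has begun.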