GridSearch ?
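GridSearchCV exists to automate hyperparameter tuning: give it the candidate values for each parameter, and it returns the best score together with the parameters that produced it.
However, this approach suits small datasets. Every parameter combination is fitted once per cross-validation fold, so once the data volume grows the search can take too long to finish.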
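When the dataset is large, a faster alternative is coordinate descent. It is essentially a greedy algorithm: tune the parameter that currently affects the model most until it is optimal, then tune the next most influential parameter, and so on until every parameter has been adjusted. The drawback is that it may settle on a local optimum rather than the global one, but it saves so much time and effort that it is worth trying. The result can be improved further with bagging afterwards.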
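The greedy procedure described above can be sketched as follows. This is a minimal illustration, not a standard scikit-learn feature: the helper name `coordinate_descent_tune`, the synthetic data, the candidate values, and the assumed order of parameter importance are all hypothetical.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Synthetic data, used only for illustration.
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# Hypothetical search space, listed in the order we assume matters most.
search_space = {
    "n_estimators": [10, 30, 50],
    "max_depth": [3, 5, 8],
    "min_samples_leaf": [1, 5, 20],
}


def coordinate_descent_tune(X, y, search_space, cv=3):
    """Tune one parameter at a time, freezing each at its best value."""
    fixed = {}      # parameters already tuned and frozen
    score = None    # cross-validated score of the latest frozen setting
    for name, candidates in search_space.items():
        scores = {}
        for value in candidates:
            model = RandomForestClassifier(
                random_state=10, **fixed, **{name: value}
            )
            scores[value] = cross_val_score(model, X, y, cv=cv).mean()
        # Greedy step: keep the best value for this parameter and move on.
        fixed[name] = max(scores, key=scores.get)
        score = scores[fixed[name]]
    return fixed, score


best_params, best_score = coordinate_descent_tune(X, y, search_space)
print(best_params, best_score)
```

With three candidates per parameter, this evaluates 3 + 3 + 3 = 9 settings, whereas a full grid over the same values would evaluate 3 × 3 × 3 = 27. The price is that interactions between parameters are never explored jointly.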
Back to GridSearchCV in sklearn: it systematically iterates over the parameter combinations and uses cross-validation to determine which combination performs best.
Parameters
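1. estimator: the model to tune (a classifier in this example).
For example: estimator=RandomForestClassifier(min_samples_split=100, min_samples_leaf=20, max_depth=8, max_features='sqrt', random_state=10). Pass every parameter here except the ones whose best values are being searched for.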
2. param_grid: a dictionary, or a list of dictionaries, giving the candidate values of the parameters to optimize.
For example: param_grid=param_test1, where param_test1={'n_estimators': range(10, 71, 10)}.
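Putting the two parameters above together, a minimal sketch might look like this. The estimator settings and `param_test1` come from the examples above; the synthetic dataset, the variable name `gsearch1`, and the choice of `cv=3` are assumptions made only so the snippet runs on its own.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic data, used only for illustration.
X, y = make_classification(n_samples=500, n_features=20, random_state=0)

# Fixed parameters go on the estimator; n_estimators is left to the search.
estimator = RandomForestClassifier(
    min_samples_split=100,
    min_samples_leaf=20,
    max_depth=8,
    max_features="sqrt",
    random_state=10,
)

# Candidate values for the parameter being tuned: 10, 20, ..., 70.
param_test1 = {"n_estimators": list(range(10, 71, 10))}

gsearch1 = GridSearchCV(estimator=estimator, param_grid=param_test1, cv=3)
gsearch1.fit(X, y)

print(gsearch1.best_params_)
```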
3. scoring: the evaluation metric. The default is None, in which case the estimator's own score method is used.
The accepted values are:
Scoring | Function | Comment |
---|---|---|
Classification | ||
accuracy | metrics.accuracy_score | |
balanced_accuracy | metrics.balanced_accuracy_score | for binary targets |
average_precision | metrics.average_precision_score | |
brier_score_loss | metrics.brier_score_loss | |
f1 | metrics.f1_score | for binary targets |
f1_micro | metrics.f1_score | micro-averaged |
f1_macro | metrics.f1_score | macro-averaged |
f1_weighted | metrics.f1_score | weighted average |
f1_samples | metrics.f1_score | by multilabel sample |
neg_log_loss | metrics.log_loss | requires predict_proba support |
precision etc. | metrics.precision_score | suffixes apply as with f1 |
recall etc. | metrics.recall_score | suffixes apply as with f1 |
roc_auc | metrics.roc_auc_score | |
Clustering | ||
adjusted_mutual_info_score | metrics.adjusted_mutual_info_score | |
adjusted_rand_score | metrics.adjusted_rand_score | |
completeness_score | metrics.completeness_score | |
fowlkes_mallows_score | metrics.fowlkes_mallows_score | |
homogeneity_score | metrics.homogeneity_score | |
mutual_info_score | metrics.mutual_info_score | |
normalized_mutual_info_score | metrics.normalized_mutual_info_score | |
v_measure_score | metrics.v_measure_score | |
Regression | ||
explained_variance | metrics.explained_variance_score | |
neg_mean_absolute_error | metrics.mean_absolute_error | |
neg_mean_squared_error | metrics.mean_squared_error | |
neg_mean_squared_log_error | metrics.mean_squared_log_error | |
neg_median_absolute_error | metrics.median_absolute_error | |
r2 | metrics.r2_score | |
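As a small sketch of how a name from the table is passed in, the snippet below uses neg_mean_squared_error on a regression problem. The Ridge model, the alpha candidates, and the synthetic data are assumptions chosen only to keep the example short. Scorers prefixed with neg_ return the negated loss, so that a higher score is always better.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Synthetic regression data, used only for illustration.
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)

search = GridSearchCV(
    Ridge(),
    param_grid={"alpha": [0.1, 1.0, 10.0]},
    scoring="neg_mean_squared_error",  # a scoring name from the table
    cv=3,
)
search.fit(X, y)

# best_score_ is the negated MSE; flip the sign to recover the MSE itself.
best_mse = -search.best_score_
print(search.best_params_, best_mse)
```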
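4. cv: the cross-validation setting. The default is None, which means 3-fold cross-validation in older scikit-learn releases and 5-fold from version 0.22 onward. It can be an integer giving the number of folds, or an iterable that yields (train, test) index splits.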
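5. n_jobs: the number of parallel jobs. An int gives the number of jobs, and -1 uses as many as there are CPU cores. The default is 1 (recent versions use None, which behaves as 1 in the usual case).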
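The two forms of cv and the n_jobs setting can be sketched as below. The LogisticRegression model, the C candidates, the 4-fold splitter, and the `N_JOBS` constant are assumptions for illustration only.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Synthetic data, used only for illustration.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

N_JOBS = 1  # set to -1 to use all CPU cores
param_grid = {"C": [0.1, 1.0, 10.0]}

# Option 1: pass a splitter object (an int such as cv=4 also works).
splitter = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
search = GridSearchCV(
    LogisticRegression(max_iter=1000), param_grid, cv=splitter, n_jobs=N_JOBS
)
search.fit(X, y)

# Option 2: pass an explicit iterable of (train_indices, test_indices) pairs.
splits = list(splitter.split(X, y))
search_from_splits = GridSearchCV(
    LogisticRegression(max_iter=1000), param_grid, cv=splits, n_jobs=N_JOBS
)
search_from_splits.fit(X, y)

print(search.n_splits_, search.best_params_)
```

Materializing the splits into a list keeps them reusable; a one-shot generator would be exhausted after its first pass.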
Outputs
fit(): runs the grid search.
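grid_scores_: the evaluation results for each parameter combination. This attribute was deprecated in scikit-learn 0.18 and removed in 0.20; in current versions cv_results_ holds the same information.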
best_params_: the parameter combination that achieved the best result.
best_score_: the best score observed during the search.
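A short sketch of reading these outputs after a search, using cv_results_ in place of the removed grid_scores_. The decision tree, the max_depth candidates, and the synthetic data are assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, used only for illustration.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": [2, 4, 6]},
    cv=3,
)
search.fit(X, y)  # runs the grid search

# Per-combination results: one entry for each parameter setting tried.
for params, mean_score in zip(
    search.cv_results_["params"], search.cv_results_["mean_test_score"]
):
    print(params, round(float(mean_score), 4))

# The winning combination and its mean cross-validated score.
print("best params:", search.best_params_)
print("best score:", search.best_score_)
```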