话不多说,直接上代码
def stacking_first(train, train_y, test):
savepath = './stack_op{}_dt{}_tfidf{}/'.format(args.option, args.data_type, args.tfidf)
os.makedirs(savepath, exist_ok=True) count_kflod = 0
num_folds = 6
kf = KFold(n_splits=num_folds, shuffle=True, random_state=10)
# 测试集上的预测结果
predict = np.zeros((test.shape[0], config.n_class))
# k折交叉验证集的预测结果
oof_predict = np.zeros((train.shape[0], config.n_class))
scores = []
f1s = [] for train_index, test_index in kf.split(train):
# 训练集划分为6折,每一折都要走一遍。那么第一个是5份的训练集索引,第二个是1份的测试集,此处为验证集是索引 kfold_X_train = {}
kfold_X_valid = {} # 取数据的标签
y_train, y_test = train_y[train_index], train_y[test_index]
# 取数据
kfold_X_train, kfold_X_valid = train[train_index], train[test_index] # 模型的前缀
model_prefix = savepath + 'DNN' + str(count_kflod)
if not os.path.exists(model_prefix):
os.mkdir(model_prefix) M = 4 # number of snapshots
alpha_zero = 1e-3 # initial learning rate
snap_epoch = 16
snapshot = SnapshotCallbackBuilder(snap_epoch, M, alpha_zero) # 使用训练集的size设定维度,fit一个模型出来
res_model = get_model(train)
res_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# res_model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=1, class_weight=class_weight)
res_model.fit(kfold_X_train, y_train, batch_size=BATCH_SIZE, epochs=snap_epoch, verbose=1,
validation_data=(kfold_X_valid, y_test),
callbacks=snapshot.get_callbacks(model_save_place=model_prefix)) # 找到这个目录下所有已经训练好的深度学习模型,通过".h5"
evaluations = []
for i in os.listdir(model_prefix):
if '.h5' in i:
evaluations.append(i) # 给测试集和当前的验证集开辟空间,就是当前折的数据预测结果构建出这么多的数据集[数据个数,类别]
preds1 = np.zeros((test.shape[0], config.n_class))
preds2 = np.zeros((len(kfold_X_valid), config.n_class))
# 遍历每一个模型,用他们分别预测当前折数的验证集和测试集,N个模型的结果求平均
for run, i in enumerate(evaluations):
res_model.load_weights(os.path.join(model_prefix, i))
preds1 += res_model.predict(test, verbose=1) / len(evaluations)
preds2 += res_model.predict(kfold_X_valid, batch_size=128) / len(evaluations) # 测试集上预测结果的加权平均
predict += preds1 / num_folds
# 每一折的预测结果放到对应折上的测试集中,用来最后构建训练集
oof_predict[test_index] = preds2 # 计算精度和F1
accuracy = mb.cal_acc(oof_predict[test_index], np.argmax(y_test, axis=1))
f1 = mb.cal_f_alpha(oof_predict[test_index], np.argmax(y_test, axis=1), n_out=config.n_class)
print('the kflod cv is : ', str(accuracy))
print('the kflod f1 is : ', str(f1))
count_kflod += 1 # 模型融合的预测结果,存起来,用以以后求平均值
scores.append(accuracy)
f1s.append(f1)
# 指标均值,最为最后的预测结果
print('total scores is ', np.mean(scores))
print('total f1 is ', np.mean(f1s))
return predict
深度学习模型stacking模型融合python代码,看了你就会使的更多相关文章
-
时间序列深度学习:seq2seq 模型预测太阳黑子
目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...
-
【转】[caffe]深度学习之图像分类模型AlexNet解读
[caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097 本文章已收录于: ...
-
[caffe]深度学习之图像分类模型VGG解读
一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...
-
深度学习 vs. 概率图模型 vs. 逻辑学
深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...
-
深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大
from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...
-
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...
-
深入浅出深度学习:原理剖析与python实践_黄安埠(著) pdf
深入浅出深度学习:原理剖析与python实践 目录: 第1 部分 概要 1 1 绪论 2 1.1 人工智能.机器学习与深度学习的关系 3 1.1.1 人工智能——机器推理 4 1.1.2 机器学习—— ...
-
一文看懂Stacking!(含Python代码)
一文看懂Stacking!(含Python代码) https://mp.weixin.qq.com/s/faQNTGgBZdZyyZscdhjwUQ
-
风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施
风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施 Python 语言可能发生的命令执行漏洞 内置危险函数 eval和exec函数 eval eval是一个python内置函数, ...
随机推荐
-
ASP.NET MVC移动M站建设-使用51Degree 移动设备的识别
上一篇,介绍了移动M站的建设.说的很简单.觉得好像也没把M站给讲清楚.估计是对移动M站 认识还不够深刻吧.这里,在讲一讲51Degree 这个组件. 51degrees 号称是目前最快.最准确的设备检 ...
-
JAVA使用JDBC技术操作SqlServer数据库执行存储过程
Java使用JDBC技术操作SqlServer数据库执行存储过程: 1.新建SQLSERVER数据库:java_conn_test 2.新建表:tb_User 3.分别新建三个存储过程: 1>带 ...
-
rpc框架之 thrift 学习 1 - 安装 及 hello world
thrift是一个facebook开源的高效RPC框架,其主要特点是跨语言及二进制高效传输(当然,除了二进制,也支持json等常用序列化机制),官网地址:http://thrift.apache.or ...
-
noip2013 火柴排序
涵涵有两盒火柴,每盒装有 n 根火柴,每根火柴都有一个高度.现在将每盒中的火柴各自排成一列,同一列火柴的高度互不相同,两列火柴之间的距离定义为: ,其中 ai 表示第一列火柴中第 i 个火柴的高度,b ...
-
安装ipython notebook
从http://cs231n.github.io/assignments2016/assignment1/开始说起,因为要学习cs231n课程,需要安装ipython notebook,原本电脑中安装 ...
-
.NET ORM框架之NHibernate
这段时间一直使用NHibernate,今天抽空总结一下. 1.什么是NHibernate? NHibernate是一个面向.NET环境的对象/关系数据库映射工具.对象/关系数据库映射(object/r ...
-
unbuntu中如何像Windows一样顺畅的切换中英文输入法
1.首先在unbuntu安装搜狗拼音输入法(这个不用教了) 2.点击右上角的搜狗拼音的图标点击设置进入设置页面 3.选择高级 4.选择Fcitx设置 5.添加输入法英语(美国) 6.在设置中选择按键, ...
-
H5多媒体(用面向对象的方法控制视频、音频播放、暂停、延时暂停)
视频,音频播放器会是我们在工作中用到的一些h5新标签,它自带一些属性,比如暂停播放,快进快退,但是,我们经常不用原生的样式或者方法,我们需要自定义这些按钮来达到我们需要的样式,也需要我们自定义来实现一 ...
-
Enigma Virtual Box:生成可执行文件。
Enigma Virtual Box Enigma Virtual Box[1] 是软件虚拟化工具,它可以将多个文件封装到应用程序主文件,从而制作成为单执行文件的绿色软件.它支持所有类型的文件格式, ...
-
failed to load response data
当需要根据后台传回地址跳转页面时 即使使用preserve log 可以查看上一个页面获取地址请求,但是此时请求返回值为failed to load response data 当关闭页面跳转可以查看 ...