引言
上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的。考虑到数据中可能存在噪音,还引入了松弛变量。
理论是抽象的,问题是具体的。站在岸上学不会游泳,光看着梨子不可能知道梨子的滋味。本篇博客就是用SVM分类算法解决一个经典的机器学习问题--手写数字识别。体会一下SVM算法的具体过程,理理它的一般性的思路。
问题的提出
人类视觉系统是世界上众多的奇迹之一。看看下面的手写数字序列:
大多数人毫不费力就能够认出这些数字为504192。如果尝试让计算机程序来识别诸如上面的数字,就会明显感受到视觉模式识别的困难。关于我们识别形状--------“9顶上有一个圈,右下方则是一条竖线”这样的简单直觉,实际上算法很难轻易表达出来。
SVM分类算法以另一个角度来考虑问题。其思路是获取大量的手写数字,常称作训练样本,然后开发出一个可以从这些训练样本中进行学习的系统。换言之,SVM使用样本来自动推断出识别手写数字的规则。随着样本数量的增加,算法可以学到更多关于手写数字的知识,这样就能够提升自身的准确性。
本文采用的数据集就是著名的“MNIST数据集”。这个数据集有60000个训练样本数据集和10000个测试用例。直接调用scikit-learn库中的SVM,使用默认的参数,1000张手写数字图片,判断准确的图片就高达9435张。
SVM的算法过程
通常,对于分类问题。我们会将数据集分成三部分,训练集、测试集、交叉验证集。用训练集训练生成模型,用测试集和交叉验证集进行验证模型的准确性。
加载数据的代码如下:
"""
mnist_loader
~~~~~~~~~~~~
一个加载模式识别图片数据的库。
"""
#### Libraries
# Standard library
import cPickle
import gzip
# Third-party libraries
import numpy as np
def load_data():
"""
返回包含训练数据、验证数据、测试数据的元组的模式识别数据
训练数据包含50,000张图片,测试数据和验证数据都只包含10,000张图片
"""
f = gzip.open('../data/mnist.pkl.gz', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data)
SVM算法进行训练和预测的代码如下:
"""
mnist_svm
~~~~~~~~~
使用SVM分类器,从MNIST数据集中进行手写数字识别的分类程序
"""
#### Libraries
# My libraries
import mnist_loader
# Third-party libraries
from sklearn import svm
import time
def svm_baseline():
print time.strftime('%Y-%m-%d %H:%M:%S')
training_data, validation_data, test_data = mnist_loader.load_data()
# 传递训练模型的参数,这里用默认的参数
clf = svm.SVC()
# clf = svm.SVC(C=8.0, kernel='rbf', gamma=0.00,cache_size=8000,probability=False)
# 进行模型训练
clf.fit(training_data[0], training_data[1])
# test
# 测试集测试预测结果
predictions = [int(a) for a in clf.predict(test_data[0])]
num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))
print "%s of %s test values correct." % (num_correct, len(test_data[1]))
print time.strftime('%Y-%m-%d %H:%M:%S')
if __name__ == "__main__":
svm_baseline()
以上代码没有用验证集进行验证。这是因为本例中,用测试集和验证集要判断的是一个东西,没有必要刻意用验证集再来验证一遍。事实上,我的确用验证集也试了一下,和测试集的结果基本一样。呵呵
直接运行代码,结果如下:
2016-01-02 14:01:46
9435 of 10000 test values correct.
2016-01-02 14:12:37
在我的ubuntu上,运行11分钟左右就可以完成训练,并预测测试集的结果。
需要说明的是,svm.SVC()函数的几个重要参数。直接用help命令查看一下文档,这里我稍微翻译了一下:
C : 浮点型,可选 (默认=1.0)。误差项的惩罚参数C
kernel : 字符型, 可选 (默认='rbf')。指定核函数类型。只能是'linear', 'poly', 'rbf', 'sigmoid', 'precomputed' 或者自定义的。如果没有指定,默认使用'rbf'。如果使用自定义的核函数,需要预先计算核矩阵。
degree : 整形, 可选 (默认=3)。用多项式核函数('poly')时,多项式核函数的参数d,用其他核函数,这个参数可忽略
gamma : 浮点型, 可选 (默认=0.0)。'rbf', 'poly' and 'sigmoid'核函数的系数。如果gamma是0,实际将使用特征维度的倒数值进行运算。也就是说,如果特征是100个维度,实际的gamma是1/100。
coef0 : 浮点型, 可选 (默认=0.0)。核函数的独立项,'poly' 和'sigmoid'核时才有意义。
可以适当调整一下SVM分类算法,看看不同参数的结果。当我的参数选择为C=100.0, kernel='rbf', gamma=0.03时,预测的准确度就已经高达98.5%了。
SVM参数的调优初探
SVM分类算法需要调整的参数就只有几个。那么这些参数如何选取,有没有一些经验性的规律呢?
- 核函数选择
如上图,线性核函数的分类边界是线性的,非线性核函数分类边界是很复杂的非线性边界。所以当能直观地观察数据时,大致可以判断分类边界,从而有倾向性地选择核函数。
- 参数gamma和C的选择
机器学习大牛Andrew Ng说,关于SVM分类算法,他一直用的是高斯核函数,其它核函数他基本就没用过。可见,这个核函数应用最广。
gamma参数,当使用高斯核进行映射时,如果选得很小的话,高次特征上的权重实际上衰减得非常快,所以实际上(数值上近似一下)相当于一个低维的子空间;反过来,如果gamma选得很大,则可以将任意的数据映射为线性可分——这样容易导致非常严重的过拟合问题。
C参数是寻找 margin 最大的超平面”和“保证数据点偏差量最小”)之间的权重。C越大,模型允许的偏差越小。
下图是一个简单的二分类情况下,不同的gamma和C对分类结果的影响。
相同的C,gamma越大,分类边界离样本越近。相同的gamma,C越大,分类越严格。
下图是不同C和gamma下分类器交叉验证准确率的热力图
由图可知,模型对gamma参数是很敏感的。如果gamma太大,无论C取多大都不能阻止过拟合。当gamma很小,分类边界很像线性的。取中间值时,好的模型的gamma和C大致分布在对角线位置。还应该注意到,当gamma取中间值时,C取值可以是很大的。
在实际项目中,这几个参数按一定的步长,多试几次,一般就能得到比较好的分类效果了。
小结
回顾一下整个问题。我们进行了如下操作。对数据集分成了三部分,训练集、测试集和交叉验证集。用SVM分类模型进行训练,依据测试集和验证集的预测结果来优化参数。依靠sklearn这个强大的机器学习库,我们也能解决手写识别这么高大上的问题了。事实上,我们只用了几行简单代码,就让测试集的预测准确率高达98.5%。
SVM算法也没有想象的那么高不可攀嘛,呵呵!
事实上,就算是一般性的机器学习问题,我们也是有一些一般性的思路的,如下:
SVM学习笔记(二)----手写数字识别的更多相关文章
-
深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
-
【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
-
机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
-
机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
-
学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
-
基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
-
实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu
对手写数据集50000张图片实现阿拉伯数字0~9识别,并且对结果进行分析准确率, 手写数字数据集下载:http://yann.lecun.com/exdb/mnist/ 首先,利用图片本身的属性,图片 ...
-
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
-
MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
随机推荐
-
MySQL 5.7:非结构化数据存储的新选择
本文转载自:http://www.innomysql.net/article/23959.html (只作转载, 不代表本站和博主同意文中观点或证实文中信息) 工作10余年,没有一个版本能像MySQL ...
-
【夯实Mysql基础】记一次mysql语句的优化过程!
1. [事件起因] 今天在做项目的时候,发现提供给客户端的接口时间很慢,达到了2秒多,我第一时间,抓了接口,看了运行的sql,发现就是 2个sql慢,分别占了1秒多. 一个sql是 链接了5个表同 ...
-
对javascript this的理解
对于this的理解,大部分时间都比较模糊,最近几天做了一些研究,记录一下 首先应该明白,this是执行上下文的一个属性,它的值取决于执行上下文,执行上下文和函数调用方式相关,定义一个function的 ...
-
oracle exp imp
oracle exp/imp
-
JVMTI 中间JNI系列功能,线程安全和故障排除技巧
JVMTI 中间JNI系列功能,线程安全和故障排除技巧 jni functions 在使用 JVMTI 的过程中,有一大系列的函数是在 JVMTI 的文档中 没有提及的,但在实际使用却是很实用的. 这 ...
-
hibernate动态切换数据源
起因: 公司的当前产品,主要是两个项目集成的,一个是java项目,还有一个是php项目,两个项目用的是不同的数据源,但都是mysql数据库,因为java这边的开发工作已经基本完成了,而php那边任务还 ...
-
Uva 11609 Teams (组合数学)
题意:有n个人,选不少于一个人参加比赛,其中一人当队长,有多少种选择方案. 思路:我们首先C(n,1)选出一人当队长,然后剩下的 n-1 人组合的总数为2^(n-1),这里用快速幂解决 代码: #in ...
-
验证整数 Double 转 int 两种写法
Double 转int 1)之前一直是使用强转 Double num = Double.parseDouble(object.toString()); int n = (int)num; i ...
-
使用docker加载已有镜像安装Hyperledger Fabric v1.1.0
背景 每次在新的服务器上安装Hyperledger Fabric网络时,通过fabric官方提供的脚本安装时,需要从网络上down下近10G的fabric相关镜像,这个过程是漫长及痛苦的,有时因网络问 ...
-
mininet下建立拓扑时关于远程控制器的一个小问题
最近重装了系统和mininet后,使用mininet时遇到了一点小问题,一开始忽视了细节,使得自己被这个问题困扰了好一会儿,好在后来还是发现了问题所在,故记录下来. $ sudo mn --topo ...