机器学习——KNN算法及手写数字的识别(一)
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
搬出一张最常见的图,来直观的展示什么是KNN算法:
kNN比较好理解,其一般过程如下:
对未知类别属性的数据集中的每个点依次执行以下操作:
1、计算已知类别数据集中的点与当前点之前的距离
2、按照距离递增次序排序
3、选取与当前点距离最小的k个点
4、确定前k个点所在类别的出现概率
5、返回前k个点出现频率最高的类别作为当前点的预测分类
下面我们给出一个利用kNN算法实现手写数字识别的例子,这个例子在Machine Learning in Action一书中是用Python描述的,这里为了加深理解我用C++进行重写。
1、读取指定文件夹下的所有训练样本文件名:
// path: 路径, files 文件名, format 文件格式 [3/7/2015 pan]
void GetAllFormatFiles(string path, vector<string>& files, string format)
{
//文件句柄
long hFile = 0;
//文件信息
struct _finddata_t fileinfo;
string p;
if ((hFile = _findfirst(p.assign(path).append("\\*" + format).c_str(), &fileinfo)) != -1)
{
do
{
if ((fileinfo.attrib & _A_SUBDIR))
{
if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
{
//files.push_back(p.assign(path).append("\\").append(fileinfo.name) );
GetAllFormatFiles(p.assign(path).append("\\").append(fileinfo.name), files, format);
}
}
else
{
files.push_back(p.assign(path).append("\\").append(fileinfo.name));
}
} while (_findnext(hFile, &fileinfo) == 0);
_findclose(hFile);
}
}
2、返回原始队列索引的排序算法
vector<int> insertSort(vector<int> nums)
{
vector<int> sortedIndx;
for (int i = 0; i<nums.size(); i++){
sortedIndx.push_back(i);
}
for (int j = 1; j < nums.size(); j++){
int key = nums[j];
int indx = sortedIndx[j];
int i = j - 1;
while (i>=0&&nums[i]>key)
{
nums[i + 1] = nums[i];
sortedIndx[i + 1] = sortedIndx[i];
i--;
}
nums[i + 1] = key;
sortedIndx[i + 1] = indx;
}
return sortedIndx;
}
3、读取文本文件内容,存入一维向量
vector<int> imageToVector(string fileName)
{
vector<int> returnVector;
fstream infile;
infile.open(fileName, ios::in);
while (!infile.eof())
{
char buffer[256];
infile.getline(buffer, 256);
for (int i = 0; i < 32; i++){
returnVector.push_back(buffer[i] - 48);// 字符转int [3/4/2015 pan]
}
}
return returnVector;
}
4、KNN算法的实现
// inX 待分类向量,dataSet 训练数据集,labels 训练数据的类别(0,1,2,3,4,5,6,7,8,9),k [3/4/2015 pan]
int classify(vector<int> inX, vector<vector<int>> dataSet, vector<int> labels, int k)
{
int dataSetSize = dataSet.size();
int labelsum = 0;
vector<int> distances;
for (int i = 0; i < dataSetSize; i++){
int sum = 0;
for (int j = 0; j < inX.size(); j++){
int tmp = inX[j] - dataSet[i][j];
tmp *= tmp;
sum += tmp;
}
sum=sqrt(sum);
distances.push_back(sum);
}
vector<int> sortedDistIndix;
sortedDistIndix = insertSort(distances);
for (int i = 0; i < k; i++){
labelsum += labels[sortedDistIndix[i]];
}
return labelsum / k + 0.5;
}
5、最终测试
int _tmain(int argc, _TCHAR* argv[])
{
string path = "G:\\A编程练习\\机器学习&Python\\handWritingTest\\digits\\trainingDigits";
string format = "txt";
vector<string> files;
GetAllFormatFiles(path, files, format);
fstream infile;
vector<vector<int>> traingMat;
vector<int> labels;
for (int i = 0; i < files.size(); i++){
//cout << path.size();
string str;
str.assign(files[i], path.size()+1,1);
const char *c = str.c_str();
labels.push_back(*c - 48);
traingMat.push_back(imageToVector(files[i]));
}
string testFileName = "G:\\A编程练习\\机器学习&Python\\handWritingTest\\digits\\testDigits\\8_73.txt";
vector<int> inX = imageToVector(testFileName);
int result = classify(inX, traingMat, labels, 8);
return 0;
}
手写数字样本:
样本下载链接:
http://download.csdn.net/detail/panan160/8480055