本文实例为大家分享了python实现knn算法的具体代码,供大家参考,具体内容如下
knn算法描述
对需要分类的点依次执行以下操作:
1.计算已知类别数据集中每个点与该点之间的距离
2.按照距离递增顺序排序
3.选取与该点距离最近的k个点
4.确定前k个点所在类别出现的频率
5.返回前k个点出现频率最高的类别作为该点的预测分类
knn算法实现
数据处理
1
2
3
4
5
6
7
8
9
10
|
#从文件中读取数据,返回的数据和分类均为二维数组
def loadDataSet(filename):
dataSet = []
labels = []
fr = open (filename)
for line in fr.readlines():
lineArr = line.strip().split( "," )
dataSet.append([ float (lineArr[ 0 ]), float (lineArr[ 1 ])])
labels.append([ float (lineArr[ 2 ])])
return dataSet , labels
|
knn算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
#计算两个向量之间的欧氏距离
def calDist(X1 , X2):
sum = 0
for x1 , x2 in zip (X1 , X2):
sum + = (x1 - x2) * * 2
return sum * * 0.5
def knn(data , dataSet , labels , k):
n = shape(dataSet)[ 0 ]
for i in range (n):
dist = calDist(data , dataSet[i])
#只记录两点之间的距离和已知点的类别
labels[i].append(dist)
#按照距离递增排序
labels.sort(key = lambda x:x[ 1 ])
count = {}
#统计每个类别出现的频率
for i in range (k):
key = labels[i][ 0 ]
if count.has_key(key):
count[key] + = 1
else : count[key] = 1
#按频率递减排序
sortCount = sorted (count.items(),key = lambda item:item[ 1 ],reverse = True )
return sortCount[ 0 ][ 0 ] #返回频率最高的key,即label
|
结果测试
已知类别数据(来源于西瓜书+虚构)
0.697,0.460,1
0.774,0.376,1
0.720,0.330,1
0.634,0.264,1
0.608,0.318,1
0.556,0.215,1
0.403,0.237,1
0.481,0.149,1
0.437,0.211,1
0.525,0.186,1
0.666,0.091,0
0.639,0.161,0
0.657,0.198,0
0.593,0.042,0
0.719,0.103,0
0.671,0.196,0
0.703,0.121,0
0.614,0.116,0
绘图方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
def drawPoints(data , dataSet, labels):
xcord1 = [];
ycord1 = [];
xcord2 = [];
ycord2 = [];
for i in range (shape(dataSet)[ 0 ]):
if labels[i][ 0 ] = = 0 :
xcord1.append(dataSet[i][ 0 ])
ycord1.append(dataSet[i][ 1 ])
if labels[i][ 0 ] = = 1 :
xcord2.append(dataSet[i][ 0 ])
ycord2.append(dataSet[i][ 1 ])
fig = plt.figure()
ax = fig.add_subplot( 111 )
ax.scatter(xcord1, ycord1, s = 30 , c = 'blue' , marker = 's' ,label = 0 )
ax.scatter(xcord2, ycord2, s = 30 , c = 'green' ,label = 1 )
ax.scatter(data[ 0 ], data[ 1 ], s = 30 , c = 'red' ,label = "testdata" )
plt.legend(loc = 'upper right' )
plt.show()
|
测试代码
1
2
3
4
5
|
dataSet , labels = loadDataSet( 'dataSet.txt' )
data = [ 0.6767 , 0.2122 ]
drawPoints(data , dataSet, labels)
newlabels = knn(data, dataSet , labels , 5 )
print newlabels
|
运行结果
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:http://blog.csdn.net/chenge_j/article/details/72110652