K近邻算法(或简称kNN)是易于理解和实现的算法,而且是你解决问题的强大工具。
什么是kNN
kNN算法的模型就是整个训练数据集。当需要对一个未知数据实例进行预测时,kNN算法会在训练数据集中搜寻k个最相似实例。对k个最相似实例的属性进行归纳,将其作为对未知实例的预测。
相似性度量依赖于数据类型。对于实数,可以使用欧式距离来计算。其他类型的数据,如分类数据或二进制数据,可以用汉明距离。
对于回归问题,会返回k个最相似实例属性的平均值。对于分类问题,会返回k个最相似实例属性出现最多的属性。
kNN如何工作
kNN属于基于实例算法簇的竞争学习和懒惰学习算法。
基于实例的算法运用数据实例(或数据行)对问题进行建模,进而做出预测决策。kNN算法算是基于实例方法的一种极端形式,因为其保留所有的训练集数据作为模型的一部分。
kNN是一个竞争学习算法,因为为了做出决策,模型内部元素(数据实例)需要互相竞争。 数据实例之间客观相似度的计算,促使每个数据实例都希望在竞争中“获胜”或者尽可能地与给定的未知数据实例相似,继而在预测中做出贡献。
懒惰学习是指直到需要预测时算法才建立模型。它很懒,因为它只在最后一刻才开始工作。优点是只包含了与未知数据相关的数据,称之为局部模型。缺点是,在大型训练数据集中会重复相同或相似的搜索过程,带来昂贵的计算开销。
最后,kNN的强大之处在于它对数据不进行任何假设,除了任意两个数据实例之间距离的一致计算。因此,它被称为成为无参数或者非线性的,因为它没有预设的函数模型。
用python写程序真的好舒服。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
import numpy as np
def read_data(filename):
'''读取文本数据,格式:特征1 特征2 …… 类别'''
f = open (filename, 'rt' )
row_list = f.readlines() #以每行作为列表
f.close()
data_array = []
labels_vector = []
while True :
if not row_list:
break
row = row_list.pop( 0 ).strip().split( '\t' ) #去除换行号,分割制表符
temp_data_row = [ float (a) for a in row[: - 1 ]] #将字符型转换为浮点型
data_array.append(temp_data_row) #取特征值
labels_vector.append(row[ - 1 ]) #取最后一个作为类别标签
return np.array(data_array),np.array(labels_vector)
def classify(test_data,dataset,labels,k):
'''分类'''
diff_dis_array = test_data - dataset #使用numpy的broadcasting
dis_array = (np.add. reduce (diff_dis_array * * 2 ,axis = - 1 )) * * 0.5 #求距离
dis_array_index = np.argsort(dis_array) #升序距离的索引
class_count = {}
for i in range (k):
temp_label = labels[dis_array_index[i]]
class_count[temp_label] = class_count.get(temp_label, 0 ) + 1 #获取类别及其次数的字典
sorted_class_count = sorted (class_count.items(), key = lambda item:item[ 1 ],reverse = True ) #字典的值按降序排列
return sorted_class_count[ 0 ][ 0 ] #返回元组列表的[0][0]
def normalize(dataset):
'''数据归一化'''
return (dataset - dataset. min ( 0 )) / (dataset. max ( 0 ) - dataset. min ( 0 ))
k = 3 #近邻数
test_data = [ 0 , 0 ] #待分类数据
data,labels = read_data( 'testdata.txt' )
print ( '数据集:\n' ,data)
print ( '标签集:\n' ,labels)
result = classify(test_data,normalize(data),labels,k)
print ( '分类结果:' ,result)
|
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:https://www.cnblogs.com/LCcnblogs/p/6262034.html