【OpenCV-Python】教程:7-1 理解 kNN (k-Nearest Neighbour)

时间:2022-12-22 11:51:56

OpenCV Python 理解kNN (k-Nearest Neighbour)

【目标】

  • 理解 kNN 算法的基本概念

【理论】

kNN是监督学习中最简单的分类算法之一。其思想是在特征空间中搜索与测试数据最接近的匹配。我们将用下图来研究它。

【OpenCV-Python】教程:7-1 理解 kNN (k-Nearest Neighbour)

在图像中,有两个"家族":蓝色正方形和红色三角形。我们把每个"家族"称为一个类。他们的房子显示在他们的城镇地图上,我们称之为特征空间。您可以将特征空间视为所有数据投影的空间。例如,考虑一个2D坐标空间。每个基准有两个特征,一个x坐标和一个y坐标。你可以在二维坐标空间中表示这个数据,对吧? 现在想象有三个特征,你将需要3D空间。现在考虑N个特征:你需要N维空间,对吧? 这个N维空间是它的特征空间。在我们的图像中,您可以将其视为具有两个特征的二维情况。

现在考虑一下,如果一个新成员来到镇上,并创建了一个新家,会发生什么,如绿色圆圈所示。他应该加入这些蓝色或红色家族(或类)之一。我们把这个过程叫做分类。这个新成员究竟应该如何归类? 由于我们处理的是kNN,让我们应用算法。

一个简单的方法是看看谁是他最近的邻居。从图像上看,很明显它是红三角家族的一员。所以他被归为红三角。这种方法被称为最近邻分类,因为分类只依赖于最近邻。

但是这种方法有一个问题! 红三角可能是最近的邻居,但如果附近也有很多蓝色方块呢?那么蓝色方块在该区域比红色三角形有更多的力量,所以仅仅检查最近的一个是不够的。相反,我们可能想要检查一些k个最近的家族。然后,无论哪个家族在他们中占多数,新来的人都应该属于这个家族。在我们的图像中,让我们取k=3,即考虑3个最近的邻居。新成员有两个红色邻居和一个蓝色邻居(有两个蓝色邻居等距,但由于k=3,我们只能取其中一个),所以他应该再次加入红色家族。但如果取k=7呢?然后他有5个蓝色邻居和2个红色邻居,应该加入蓝色家族。结果将随k的选择值而变化。请注意,如果k不是奇数,我们可以得到一个平局,就像上面k=4的情况一样。我们会看到我们的新成员有2个红色和2个蓝色邻居作为他的四个最近的邻居,我们需要选择一种打破平局的方法来进行分类。重申一下,这种方法被称为k近邻,因为分类取决于k个近邻

同样,在kNN中,我们确实考虑了k个邻居,但我们对所有邻居都给予了同等的重视,对吧?这合理吗?以k=4为例。我们可以看到,2个红色邻居实际上比其他2个蓝色邻居更接近新成员,所以他更有资格加入红色家庭。我们如何从数学上解释呢?我们根据每个邻居与新来的人的距离给他们一些权重:离他近的人得到更高的权重,而离他远的人得到更低的权重。然后,我们将每个家庭的总权重分别相加,并将新来者归类为总权重较高的家庭的一部分。这被称为修正kNN或加权kNN。

你在这里看到了什么重要的东西?

  • 因为我们必须检查从新来者到所有现有房屋的距离,以找到最近的邻居,所以你需要镇上所有房屋的信息,对吗?如果有很多房子和家庭,需要大量的内存,也需要更多的时间来计算。
  • 几乎没有时间进行任何形式的“训练”或准备。我们的“学习”只包括在测试和分类之前记忆(存储)数据。

【代码】

【OpenCV-Python】教程:7-1 理解 kNN (k-Nearest Neighbour)

import cv2 
import numpy as np 
import matplotlib.pyplot as plt 

# 生成 25 个特征数据,
trainData = np.random.randint(0, 100, (25, 2)).astype(np.float32)

# 生成标签 0 或 1
responses = np.random.randint(0, 2, (25, 1)).astype(np.float32)

plt.figure()
plt.title("kNN demo")
plt.xlabel("x")
plt.ylabel("y")

# 画红色的标签
red = trainData[responses.ravel() == 0]
plt.scatter(red[:, 0], red[:, 1], 80, 'r', '^')

# 画蓝色的标签
blue = trainData[responses.ravel() == 1]
plt.scatter(blue[:, 0], blue[:, 1], 80, 'b', 's')

# 新来的数据
newcomer = np.random.randint(0, 100, (1, 2)).astype(np.float32)
plt.scatter(newcomer[:, 0], newcomer[:, 1], 80, 'g', 'o')

# 创建 kNN
knn = cv2.ml.KNearest_create()
knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)
ret, results, neighbours, dist = knn.findNearest(newcomer, 3)
print("result:  {}\n".format(results))
print("neighbours:  {}\n".format(neighbours))
print("distance:  {}\n".format(dist))


# plt.show()
plt.savefig('result.png', bbox_inches='tight')
  • 输出
result:  [[1.]]

neighbours:  [[1. 1. 1.]]

distance:  [[100. 221. 377.]]

【接口】

  • KNearest_create
cv2.ml.KNearest_create(		) ->	retval

创建一个空的kNN模型
然后需要用 StatModel::train 来训练。

  • findNearest
cv2.ml_KNearest.findNearest(	samples, k[, results[, neighborResponses[, dist]]]	) ->	retval, results, neighborResponses, dist

找到最近邻的类别标签,以及对应的距离。

  • samples: 输入样本,按行存储,单精度浮点矩阵
  • k: 最近邻的数量,必须大于1
  • results: 每个输入样本的预测结果(回归或分类)向量。
  • neighborResponses: 对应邻居的标签类别
  • dist: 对应邻居的距离

其他见 OpenCV: cv::ml::KNearest Class Reference

训练见 OpenCV: cv::ml::StatModel Class Reference

【参考】

  1. OpenCV: Understanding k-Nearest Neighbour
  2. NPTEL notes on Pattern Recognition, Chapter 11
  3. Wikipedia article on Nearest neighbor search
  4. Wikipedia article on k-d tree
  5. OpenCV: cv::ml::KNearest Class Reference