贝叶斯分类是利用概率统计知识进行分类的算法,其分类原理是贝叶斯定理。贝叶斯定理的公式如下:
贝叶斯公式表明,我们可以从先验概率P(A)、条件概率P(B|A)和证据P(B)来计算出后验概率。
朴素贝叶斯分类器就是假设证据之间各个条件相互独立的基础上,根据计算的后验概率选择各类别后验概率最大的类别作为目标证据的类别。
构建朴素贝叶斯分类器的步骤如下:
1、根据训练样例分别计算每个类别出现的概率P(Ai),
2、对每个特征属性计算所有划分的条件概率P(Bi|Ai),
3、对每个类别计算P(B|Ai)*P(Ai),
4、选择3步骤中数值最大项作为B的类别Ak。
在实际编码中,并没有计算各个概率,而是构建了各个属性在各个类别中出现的频次数,根据目标特征计算相应的概率,这样的好处是容易存储和读取,便于使用,具体代码如下:
def bayesian(inX,tranSet,labels): '''
贝叶斯分类器
:param tranSet:特征矩阵
:param labels: 类别
:return:
'''
labelsTree = {}
m,n = tranSet.shape
labelsCount = {}
xCount = zeros((n,1))
for i in arange(m):
if labels[i] not in labelsTree:
labelsTree[labels[i]] = {}
labelsCount[labels[i]] = {}
for j in arange(n):
if j not in labelsTree[labels[i]]:
labelsTree[labels[i]][j] = {}
#labelsTree[labels[i]][tranSet[i][j]] = labelsTree[labels[i]][tranSet[i][j]].get(labels[i][tranSet[i][j]],0) + 1
labelsTree[labels[i]][j][tranSet[i,j]] = labelsTree[labels[i]][j].get(tranSet[i,j],0) + 1
labelsCount[labels[i]][j] = labelsCount[labels[i]].get(j,0) + 1
if inX[j] == tranSet[i,j]:
xCount[j] = xCount[j] + 1
pVector = {}
xProp = (xCount/sum(xCount)).cumprod()[-1]
for key in labelsTree.keys():
for i in arange(n):
pVector[key] = pVector.get(key,1) * labelsTree[key][i].get(inX[i],1)/labelsCount[key].get(i,1)
pVector[key] = pVector[key] * sum(array([x for x in labelsCount[key].values()]))/m
return pVector,array([x for x in pVector.values()],dtype = 'float')/xProp
测试代码如下:
from numpy import *import mldata = [['<=30','high','no','fair'], ['<=30','high','no','excellent'], ['31...40','high','no','fair'], ['>40','medium','no','fair'], ['>40','low','yes','fair'], ['>40','low','yes','excellent'], ['31...40','low','yes','excellent'], ['<=30','medium','no','fair'], ['<=30','low','yes','fair'], ['>40','medium','yes','fair'], ['<=30','medium','yes','excellent'], ['31...40','medium','no','excellent'], ['31...40','high','yes','fair'], ['>40','medium','no','excellent']]label = ['no','no','yes','yes','yes','no','yes','no','yes','yes','yes','yes','yes','no']inX = ['<=30','medium','yes','fair']pV = ml.bayesian(array(inX),array(data),array(label))print(pV)
本文出自 “走一停二回头看三” 博客,请务必保留此出处http://janwool.blog.51cto.com/5694960/1895088