朴素贝叶斯算法java实现(多项式模型)

时间:2021-01-07 03:11:23

网上有很多对朴素贝叶斯算法的说明的文章,在对算法实现前,参考了一下几篇文章:

NLP系列(2)_用朴素贝叶斯进行文本分类(上)

NLP系列(3)_用朴素贝叶斯进行文本分类(下)

带你搞懂朴素贝叶斯分类算法

其中“带你搞懂朴素贝叶斯算法”在我看来比较容易理解,上面两篇比较详细,更深入。

算法java实现

第一步对训练集进行预处理,分词并计算词频,得到存储训练集的特征集合

/**
* 所有训练集分词特征集合
* 第一个String代表分类标签,也就是存储该类别训练集的文件名
* 第二个String代表某条训练集的路径,这里存储的是该条语料的绝对路径
* Map<String, Integer>存储的是该条训练集的特征词和词频
*
*/
private static Map<String, Map<String, Map<String, Integer>>> allTrainFileSegsMap = new HashMap<String, Map<String, Map<String, Integer>>>();
/**
* 放大因子
* 在计算中,因各个词的先验概率都比较小,我们乘以固定的值放大,便于计算
*/
private static BigDecimal zoomFactor = new BigDecimal(10); /**
* 对传入的训练集进行分词,获取训练集分词后的词和词频集合
* @param trainFilePath 训练集路径
*/
public static void getFeatureClassForTrainText(String trainFilePath){
//通过将训练集路径字符串转变成抽象路径,创建一个File对象
File trainFileDirs = new File(trainFilePath);
//获取该路径下的所有分类路径
File[] trainFileDirList = trainFileDirs.listFiles();
if (trainFileDirList == null){
System.out.println("训练数据集不存在");
}
for (File trainFileDir : trainFileDirList){
//读取该分类下的所有训练文件
List<String> fileList = null;
try {
fileList = FileOptionUtil.readDirs(trainFileDir.getAbsolutePath());
if (fileList.size() != 0){
//遍历训练集目录数据,进行分词和类别标签处理
for(String filePath : fileList){
System.out.println("开始对此训练集进行分词处理:" + filePath);
//分词处理,获取每条训练集文本的词和词频
//若知道文件编码的话,不要用下述的判断编码格式了,效率太低
// Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, FileOptionUtil.getCodeString(filePath)));
Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, "gbk"));
if (allTrainFileSegsMap.containsKey(trainFileDir.getName())){
Map<String, Map<String, Integer>> allSegsMap = allTrainFileSegsMap.get(trainFileDir.getName());
allSegsMap.put(filePath, contentSegs);
allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap);
} else {
Map<String, Map<String, Integer>> allSegsMap = new HashMap<String, Map<String, Integer>>();
allSegsMap.put(filePath, contentSegs);
allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap);
}
}
} else {
System.out.println("该分类下没有待训练语料");
}
} catch (IOException e) {
e.printStackTrace();
}
}
}

第二步计算类别的先验概率

/**
* 计算类别C的先验概率
* 先验概率P(c)= 类c下单词总数/整个训练样本的单词总数
* @param category
* @return 类C的先验概率
*/
public static BigDecimal prioriProbability(String category){
BigDecimal categoryWordsCount = new BigDecimal(categoryWordCount(category));
BigDecimal allTrainFileWordCount = new BigDecimal(getAllTrainCategoryWordsCount());
return categoryWordsCount.divide(allTrainFileWordCount, 10, BigDecimal.ROUND_CEILING);
}

第三步计算特征词的条件概率

/**
* 多项式朴素贝叶斯类条件概率
* 类条件概率P(IK|c)=(类c下单词IK在各个文档中出现过的次数之和+1)/(类c下单词总数+|V|)
* V是训练样本的单词表(即抽取单词,单词出现多次,只算一个),
* |V|则表示训练样本包含多少种单词。 P(IK|c)可以看作是单词tk在证明d属于类c上提供了多大的证据,
* 而P(c)则可以认为是类别c在整体上占多大比例(有多大可能性)
* @param category
* @param word
* @return
*/
public static BigDecimal categoryConditionalProbability(String category, String word){
BigDecimal wordCount = new BigDecimal(wordInCategoryCount(word, category) + 1);
BigDecimal categoryTrainFileWordCount = new BigDecimal(categoryWordCount(category) + getAllTrainCategoryWordCount());
return wordCount.divide(categoryTrainFileWordCount, 10, BigDecimal.ROUND_CEILING);
}

第四步计算给定文本的分类结果

/**
* 多项式朴素贝叶斯分类结果
* P(C_i|w_1,w_2...w_n) = P(w_1,w_2...w_n|C_i) * P(C_i) / P(w_1,w_2...w_n)
* = P(w_1|C_i) * P(w_2|C_i)...P(w_n|C_i) * P(C_i) / (P(w_1) * P(w_2) ...P(w_n))
* @param words
* @return
*/
public static Map<String, BigDecimal> classifyResult(Set<String> words){
Map<String, BigDecimal> resultMap = new HashMap<String, BigDecimal>();
//获取训练语料集所有的分类集合
Set<String> categorySet = allTrainFileSegsMap.keySet();
//循环计算每个类别的概率
for (String categorySetLabel : categorySet){
BigDecimal probability = new BigDecimal(1.0);
for (String word : words){
probability = probability.multiply(categoryConditionalProbability(categorySetLabel, word)).multiply(zoomFactor);
}
resultMap.put(categorySetLabel, probability.multiply(prioriProbability(categorySetLabel)));
}
return resultMap;
}

辅助函数

/**
* 对分类结果进行比较,得出概率最大的类
* @param classifyResult
* @return
*/
public static String getClassifyResultName(Map<String, BigDecimal> classifyResult){
String classifyName = "";
if (classifyResult.isEmpty()){
return classifyName;
}
BigDecimal result = new BigDecimal(0);
Set<String> classifyResultSet = classifyResult.keySet();
for (String classifyResultSetString : classifyResultSet){
if (classifyResult.get(classifyResultSetString).compareTo(result) >= 1){
result = classifyResult.get(classifyResultSetString);
classifyName = classifyResultSetString;
}
}
return classifyName;
} /**
* 统计给定类别下的单词总数(带词频计算)
* @param categoryLabel 指定类别参数
* @return
*/
public static Long categoryWordCount(String categoryLabel){
Long sum = 0L;
Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel);
if (categoryWordMap == null){
return sum;
}
Set<String> categoryWordMapKeySet = categoryWordMap.keySet();
for (String categoryLabelString : categoryWordMapKeySet){
Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryLabelString);
List<Map.Entry<String, Integer>> dataWordMapList = new ArrayList<Map.Entry<String, Integer>>(categoryWordMapDataMap.entrySet());
for (int i=0; i<dataWordMapList.size(); i++){
sum += dataWordMapList.get(i).getValue();
}
}
return sum;
} /**
* 获取训练样本所有词的总数(词总数计算是带上词频的,也就是可以重复算数)
* @return
*/
public static Long getAllTrainCategoryWordsCount(){
Long sum = 0L;
//获取所有分类
Set<String> categoryLabels = allTrainFileSegsMap.keySet();
//循环相加每个类下的词总数
for (String categoryLabel : categoryLabels){
sum += categoryWordCount(categoryLabel);
}
return sum;
} /**
* 获取训练样本下各个类别不重复词的总词数,区别于getAllTrainCategoryWordsCount()方法,此处计算不计算词频
* 备注:此处并不是严格意义上的进行全量词表生成后的计算,也就是加入类别1有"中国=6"、类别2有"中国=2",总词数算中国两次,
* 也就是说,我们在计算的时候并没有生成全局词表(将所有词都作为出现一次)
* @return
*/
public static Long getAllTrainCategoryWordCount(){
Long sum = 0L;
//获取所有分类
Set<String> categoryLabels = allTrainFileSegsMap.keySet();
for (String cateGoryLabelsLabel : categoryLabels){
Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(cateGoryLabelsLabel);
List<Map.Entry<String, Map<String, Integer>>> categoryWordMapList = new ArrayList<Map.Entry<String, Map<String, Integer>>>(categoryWordMap.entrySet());
for (int i=0; i<categoryWordMapList.size(); i++){
sum += categoryWordMapList.get(i).getValue().size();
}
}
return sum;
} /**
* 计算测试数据的每个单词在每个类下出现的总数
* @param word
* @param categoryLabel
* @return
*/
public static Long wordInCategoryCount(String word, String categoryLabel){
Long sum = 0L;
Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel);
Set<String> categoryWordMapKeySet = categoryWordMap.keySet();
for (String categoryWordMapKeySetFile : categoryWordMapKeySet){
Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryWordMapKeySetFile);
Integer value = categoryWordMapDataMap.get(word);
if (value!=null && value>0){
sum += value;
}
}
return sum;
} /**
* 获取所有分类类别
* @return
*/
public Set<String> getAllCategory(){
return allTrainFileSegsMap.keySet();
}

main函数测试

//main方法
public static void main(String[] args){
BayesNB.getFeatureClassForTrainText("/Users/zhouyh/work/yanfa/xunlianji/train/");
String s = "全国假日旅游部际协调会议的各成员单位和*各有关部门围绕一个目标,积极配合,主动工作,抓得深入,抓得扎实。主要有以下几个特点:一是安全工作有部署有检查有跟踪。国务院安委会办公室节前深入部署全面检查,节中及时总结,下发关于黄金周后期安全工作的紧急通知;铁路、民航、交通等部门针对黄金周前后期旅客集中返程交通压力较大情况,及时调遣应急运力;质检总局进一步强化节日期间质量安全监管工作;旅游部门每日及时发布旅游信息通报,有效引导游客。二是各方面主动协调密切配合。各省区市加强了在安全事故问题上的协调与沟通,化解了一些跨省区矛盾和问题;铁道、民航部门准时准确报送信息;中宣部和*文明办以黄金周旅游为载体," +
"部署精神文明建设和践行*荣辱观的宣传活动;中国气象局及时将黄金周每日气象分析送交各有关部门;*部专门部署警力,为协调游客流动大的城市及景区做了大量工作;旅游部门密切配合有关部门做好各类事故处理和投诉调解工作。三是*各部门的社会服务意识大为增强。外交部及其驻外领事馆及时提供*安全信息为旅游者服务;*电视台、地方电视台和各大媒体及各地方媒体提供的旅游信息十分丰富;气象信息服务充分具体;中消协提出多项旅游警示。各部门的密切配合和主动服务配合,确保了本次黄金周的顺利平稳运行。";
Set<String> words = IKWordSegmentation.segString(s).keySet(); Map<String, BigDecimal> resultMap = BayesNB.classifyResult(words);
String category = BayesNB.getClassifyResultName(resultMap);
System.out.println(category);
}

经过上述步骤即可实现简单的多项式模型算法,有部分代码参考了网上的算法代码。