机器学习实战朴素贝叶斯的java实现

时间:2022-12-14 14:01:09
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * Container for the feature matrix: one inner list of word tokens per
 * document.
 *
 * @author haolidong
 */
public class Matrix {

	/** Feature matrix; each row is the token list of one document. */
	public ArrayList<ArrayList<String>> data;

	/** Creates an empty feature matrix. */
	public Matrix() {
		this.data = new ArrayList<ArrayList<String>>();
	}
}
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * Training data set: the feature matrix (inherited {@code data}) plus one
 * class label ("0" or "1") per document.
 *
 * @author haolidong
 */
public class CreateDataSet extends Matrix {

	/** Class label per document, parallel to {@code data}. */
	public ArrayList<String> labels;

	public CreateDataSet() {
		super();
		labels = new ArrayList<String>();
	}

	/**
	 * Loads the toy posting data of the naive-Bayes chapter of
	 * "Machine Learning in Action": six documents, label "1" for abusive
	 * posts, "0" otherwise.
	 * (The original comment said "decision tree"; this is the Bayes data.)
	 */
	public void initTest() {
		addDocument("0", "my", "dog", "has", "flea", "problems", "help", "please");
		addDocument("1", "maybe", "not", "take", "him", "to", "dog", "park", "stupid");
		addDocument("0", "my", "dalmation", "is", "so", "cute", "I", "love", "him");
		addDocument("1", "stop", "posting", "stupid", "worthless", "garbage");
		addDocument("0", "mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him");
		addDocument("1", "quit", "buying", "worthless", "dog", "food", "stupid");
	}

	/** Appends one document (its word list) and its class label. */
	private void addDocument(String label, String... words) {
		ArrayList<String> doc = new ArrayList<String>();
		for (int i = 0; i < words.length; i++) {
			doc.add(words[i]);
		}
		data.add(doc);
		labels.add(label);
	}
}
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * Result of naive-Bayes training ({@code trainNB0}).
 *
 * p0Vect   - per-word log-probability vector of class 0
 * p1Vect   - per-word log-probability vector of class 1
 * pAbusive - fraction of training documents labelled "1"
 *
 * @author haolidong
 */
public class TrainNB0DataSet {

	public ArrayList<Double> p0Vect;
	public ArrayList<Double> p1Vect;
	public double pAbusive;

	/** Creates an empty, untrained result holder. */
	public TrainNB0DataSet() {
		this.p0Vect = new ArrayList<Double>();
		this.p1Vect = new ArrayList<Double>();
		this.pAbusive = 0.0;
	}
}

package com.haolidong.Bayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;

/**
 * Naive-Bayes text classifier: the spam-filter example of
 * "Machine Learning in Action", chapter 4.
 *
 * @author haolidong
 */
public class Bayes {

	public static void main(String[] args) {
		spamTest();
	}

	/**
	 * Draws {@code num} distinct random integers from the range [0, end).
	 *
	 * @param end exclusive upper bound of the range
	 * @param num how many distinct values to draw (must be <= end or this loops forever)
	 * @return set of {@code num} distinct random integers
	 */
	public static HashSet<Integer> randomdif(int end, int num) {
		HashSet<Integer> rndint = new HashSet<Integer>();
		// A HashSet ignores duplicates, so keep drawing until we have num
		// distinct values. (A dead "rndint.size();" statement was removed.)
		while (rndint.size() < num) {
			rndint.add((int) (Math.random() * end));
		}
		return rndint;
	}

	/**
	 * Spam-filter test: loads 25 spam and 25 ham mails from disk, holds out
	 * 10 random documents as a test set, trains naive Bayes on the remaining
	 * 40 and prints predictions plus the error rate.
	 * Label "1" = spam, "0" = ham.
	 */
	public static void spamTest() {
		CreateDataSet dataSet = new CreateDataSet();
		for (int i = 1; i < 26; i++) {
			String hamPath = "I:\\machinelearninginaction\\Ch04\\email\\ham\\" + i + ".txt";
			String spamPath = "I:\\machinelearninginaction\\Ch04\\email\\spam\\" + i + ".txt";
			// NOTE(review): the original swapped the ham/spam variable names;
			// the labels themselves were correct, so the behavior is unchanged.
			ArrayList<String> spamWordList = textParse(spamPath, 2);
			dataSet.data.add(spamWordList);
			dataSet.labels.add("1");
			ArrayList<String> hamWordList = textParse(hamPath, 2);
			dataSet.data.add(hamWordList);
			dataSet.labels.add("0");
		}
		// Build the vocabulary over all 50 documents.
		HashSet<String> vocabList = createVocabList(dataSet);
		// Hold out 10 random documents as the test set; train on the rest.
		HashSet<Integer> rndint = randomdif(50, 10);
		Matrix testMatrix = new Matrix();
		Matrix trainMatrix = new Matrix();
		ArrayList<String> trainLabels = new ArrayList<String>();
		ArrayList<String> testLabels = new ArrayList<String>();
		for (Integer i : rndint) {
			testMatrix.data.add(dataSet.data.get(i));
			testLabels.add(dataSet.labels.get(i));
		}
		for (int i = 0; i < dataSet.data.size(); i++) {
			if (!rndint.contains(i)) {
				trainMatrix.data.add(dataSet.data.get(i));
				trainLabels.add(dataSet.labels.get(i));
			}
		}
		// Convert the word lists into 0/1 vocabulary vectors.
		Matrix trainMatrixTrans = new Matrix();
		Matrix testMatrixTrans = new Matrix();
		for (int i = 0; i < trainMatrix.data.size(); i++) {
			trainMatrixTrans.data.add(setOfWords2Vec(vocabList, trainMatrix.data.get(i)));
		}
		for (int i = 0; i < testMatrix.data.size(); i++) {
			testMatrixTrans.data.add(setOfWords2Vec(vocabList, testMatrix.data.get(i)));
		}
		// Train on the training split, then score the held-out documents.
		TrainNB0DataSet td = trainNB0(trainMatrixTrans, trainLabels);
		int errorCount = 0;
		for (int i = 0; i < testMatrixTrans.data.size(); i++) {
			int num = classifyNB(testMatrixTrans.data.get(i), td.p0Vect, td.p1Vect, td.pAbusive);
			System.out.println("the predict:" + num + " , the real:" + testLabels.get(i));
			if (num != Integer.parseInt(testLabels.get(i))) {
				errorCount++;
			}
		}
		System.out.println("the errorRate is:" + 1.0 * errorCount / testMatrixTrans.data.size());
	}

	/**
	 * Reads a file and tokenizes it into lowercase words longer than
	 * {@code moreThan} characters.
	 *
	 * @param fileName full path of the file to read
	 * @param moreThan minimum (exclusive) word length to keep
	 * @return the token list
	 */
	public static ArrayList<String> textParse(String fileName, int moreThan) {
		return extractStrlist(readFile(fileName), moreThan);
	}

	/**
	 * Reads a file line by line and concatenates the lines, separated by a
	 * single space (so words at line breaks do not run together).
	 *
	 * @param fileName full path of the file to read
	 * @return the whole file content as one space-joined string; empty on I/O error
	 */
	public static String readFile(String fileName) {
		File file = new File(fileName);
		BufferedReader reader = null;
		StringBuilder s = new StringBuilder();
		try {
			reader = new BufferedReader(new FileReader(file));
			String tempString;
			while ((tempString = reader.readLine()) != null) {
				s.append(tempString).append(" ");
			}
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException ignored) {
					// best-effort close; nothing useful to do here
				}
			}
		}
		return s.toString();
	}

	/**
	 * Splits a string on runs of non-word characters, keeping only tokens
	 * longer than {@code moreThan} characters, lowercased.
	 * (Replaces the original char-by-char "\\W*" split-and-rejoin, which is
	 * exactly equivalent to a single split on "\\W+".)
	 *
	 * @param inputString the text to tokenize
	 * @param moreThan    minimum (exclusive) token length to keep
	 * @return the filtered, lowercased token list
	 */
	public static ArrayList<String> extractStrlist(String inputString, int moreThan) {
		ArrayList<String> strSplitList = new ArrayList<String>();
		String[] tokens = inputString.split("\\W+");
		for (int i = 0; i < tokens.length; i++) {
			if (tokens[i].length() > moreThan) {
				strSplitList.add(tokens[i].toLowerCase());
			}
		}
		return strSplitList;
	}

	/**
	 * Classifies one 0/1 vocabulary vector using the trained log-probability
	 * vectors. Work is done in log space, so per-word products become sums.
	 *
	 * @param vec2Classify 0/1 vector (as strings) to classify
	 * @param p0Vec        log-probability vector of class 0
	 * @param p1Vec        log-probability vector of class 1
	 * @param pClass1      prior probability of class 1
	 * @return 1 if class 1 is more likely, otherwise 0
	 */
	public static int classifyNB(ArrayList<String> vec2Classify, ArrayList<Double> p0Vec, ArrayList<Double> p1Vec,
			double pClass1) {
		double p1 = 0.0;
		double p0 = 0.0;
		for (int i = 0; i < vec2Classify.size(); i++) {
			double x = Double.parseDouble(vec2Classify.get(i));
			p1 += x * p1Vec.get(i);
			p0 += x * p0Vec.get(i);
		}
		// Add the log priors.
		p1 += Math.log(pClass1);
		p0 += Math.log(1 - pClass1);
		return (p1 > p0) ? 1 : 0;
	}

	/**
	 * Trains the naive-Bayes model: computes, for each class, the smoothed
	 * per-word log probabilities and the class-1 prior.
	 *
	 * @param trainMatrix   0/1 document-word matrix (as strings)
	 * @param trainCategory per-document labels, "0" or "1"
	 * @return the trained log-probability vectors and class-1 prior
	 */
	public static TrainNB0DataSet trainNB0(Matrix trainMatrix, ArrayList<String> trainCategory) {
		int numTrainDocs = trainMatrix.data.size();
		int numWords = trainMatrix.data.get(0).size();
		TrainNB0DataSet resultSet = new TrainNB0DataSet();
		// Prior: fraction of documents labelled "1".
		double trainCategorySum = 0.0;
		for (int i = 0; i < trainCategory.size(); i++) {
			trainCategorySum += Double.parseDouble(trainCategory.get(i));
		}
		resultSet.pAbusive = trainCategorySum / numTrainDocs;
		// Laplace smoothing: every word count starts at 1.
		ArrayList<Double> p0Num = new ArrayList<Double>();
		ArrayList<Double> p1Num = new ArrayList<Double>();
		for (int i = 0; i < numWords; i++) {
			p0Num.add(1.0);
			p1Num.add(1.0);
		}
		// Accumulate word counts per class.
		for (int i = 0; i < numTrainDocs; i++) {
			ArrayList<Double> target = trainCategory.get(i).equals("1") ? p1Num : p0Num;
			for (int j = 0; j < numWords; j++) {
				target.set(j, target.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));
			}
		}
		// Denominator ends up as 2 + total word count of the class: the
		// smoothed "+1" per word is summed in and then subtracted back out,
		// exactly as in the original.
		double p0Denom = 2.0;
		double p1Denom = 2.0;
		for (int i = 0; i < numWords; i++) {
			p0Denom += p0Num.get(i);
			p1Denom += p1Num.get(i);
		}
		p0Denom -= numWords;
		p1Denom -= numWords;
		for (int i = 0; i < numWords; i++) {
			resultSet.p0Vect.add(Math.log(p0Num.get(i) / p0Denom));
			resultSet.p1Vect.add(Math.log(p1Num.get(i) / p1Denom));
		}
		return resultSet;
	}

	/**
	 * Maps a document onto the vocabulary: for each dictionary word (in the
	 * set's iteration order) emits "1" if the document contains it, else "0".
	 *
	 * @param vocabSet the vocabulary
	 * @param inputSet the document's word list
	 * @return 0/1 vector parallel to the vocabulary's iteration order
	 */
	public static ArrayList<String> setOfWords2Vec(HashSet<String> vocabSet, ArrayList<String> inputSet) {
		ArrayList<String> returnVec = new ArrayList<String>();
		for (String value : vocabSet) {
			returnVec.add(inputSet.contains(value) ? "1" : "0");
		}
		return returnVec;
	}

	/**
	 * Builds the vocabulary: the set of all distinct words occurring in the
	 * data set.
	 *
	 * @param dataSet the document matrix
	 * @return set of all distinct words
	 */
	public static HashSet<String> createVocabList(Matrix dataSet) {
		HashSet<String> vocabSet = new HashSet<String>();
		for (int i = 0; i < dataSet.data.size(); i++) {
			for (int j = 0; j < dataSet.data.get(i).size(); j++) {
				vocabSet.add(dataSet.data.get(i).get(j));
			}
		}
		return vocabSet;
	}

	/** Smoke test: builds and prints the vocabulary of the toy data set. */
	public static void testVocabList() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		System.out.println(vocabSet);
	}

	/** Smoke test: converts the first toy document into a 0/1 vector. */
	public static void testWord2Vec() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		ArrayList<String> returnVec = setOfWords2Vec(vocabSet, dataSet.data.get(0));
		System.out.println(returnVec);
	}

	/** Smoke test: trains the model on the toy data set. */
	public static void testTrain() {
		CreateDataSet dataSet = new CreateDataSet();
		Matrix trainMatrix = new Matrix();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		for (int i = 0; i < dataSet.data.size(); i++) {
			trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
		}
		trainNB0(trainMatrix, dataSet.labels);
	}

	/** Smoke test: trains on the toy data set and classifies two samples. */
	public static void testingNB() {
		CreateDataSet dataSet = new CreateDataSet();
		Matrix trainMatrix = new Matrix();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		for (int i = 0; i < dataSet.data.size(); i++) {
			trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
		}
		TrainNB0DataSet td = trainNB0(trainMatrix, dataSet.labels);
		ArrayList<String> testEntry = new ArrayList<String>();
		testEntry.add("love");
		testEntry.add("my");
		testEntry.add("dalmation");
		testEntry = setOfWords2Vec(vocabSet, testEntry);
		System.out.println("classified as:" + classifyNB(testEntry, td.p0Vect, td.p1Vect, td.pAbusive));
		testEntry.clear();
		testEntry.add("stupid");
		testEntry.add("garbage");
		testEntry = setOfWords2Vec(vocabSet, testEntry);
		System.out.println("classified as:" + classifyNB(testEntry, td.p0Vect, td.p1Vect, td.pAbusive));
	}
}