package com.haolidong.Bayes;
import java.util.ArrayList;
/**
 *
 * @author haolidong
 * @Description: [This class stores the feature matrix]
 * @parameter data: [the feature matrix, one row of tokens per document]
 */
public class Matrix {
    public ArrayList<ArrayList<String>> data;

    public Matrix() {
        data = new ArrayList<ArrayList<String>>();
    }
}
package com.haolidong.Bayes;

import java.util.ArrayList;
import java.util.Arrays;

/**
 *
 * @author haolidong
 * @Description: [This class stores the feature matrix together with the label values]
 * @parameter labels: [the label value of each document]
 */
public class CreateDataSet extends Matrix {
    public ArrayList<String> labels;

    public CreateDataSet() {
        super();
        labels = new ArrayList<String>();
    }

    /**
     * @author haolidong
     * @Description: [toy posting data from the first Naive Bayes example in Machine Learning in Action]
     */
    public void initTest() {
        data.add(new ArrayList<String>(Arrays.asList("my", "dog", "has", "flea", "problems", "help", "please")));
        data.add(new ArrayList<String>(Arrays.asList("maybe", "not", "take", "him", "to", "dog", "park", "stupid")));
        data.add(new ArrayList<String>(Arrays.asList("my", "dalmation", "is", "so", "cute", "I", "love", "him")));
        data.add(new ArrayList<String>(Arrays.asList("stop", "posting", "stupid", "worthless", "garbage")));
        data.add(new ArrayList<String>(Arrays.asList("mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him")));
        data.add(new ArrayList<String>(Arrays.asList("quit", "buying", "worthless", "dog", "food", "stupid")));
        // "1" = abusive posting, "0" = normal posting
        labels.addAll(Arrays.asList("0", "1", "0", "1", "0", "1"));
    }
}
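As a quick sanity check, the toy data above can be inspected on its own. The following is a minimal sketch (the class name ToyDataPeek and the printed wording are mine, not part of the original project); it only prints the shape of the data set produced by initTest().

package com.haolidong.Bayes;

// Hypothetical helper class, used only to peek at the toy data set.
public class ToyDataPeek {
    public static void main(String[] args) {
        CreateDataSet ds = new CreateDataSet();
        ds.initTest();
        // expected output: 6 postings, 6 labels; the first posting carries label "0"
        System.out.println(ds.data.size() + " postings, " + ds.labels.size() + " labels");
        System.out.println("first posting: " + ds.data.get(0) + " -> label " + ds.labels.get(0));
    }
}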
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 *
 * @author haolidong
 * @Description: [This class stores the result of training: the per-class probability vectors and the class prior]
 * @parameter p0Vect: [probability vector of class 0]
 * @parameter p1Vect: [probability vector of class 1]
 * @parameter pAbusive: [proportion of positive samples (label "1")]
 */
public class TrainNB0DataSet {
    public ArrayList<Double> p0Vect;
    public ArrayList<Double> p1Vect;
    public double pAbusive;

    public TrainNB0DataSet() {
        p0Vect = new ArrayList<Double>();
        p1Vect = new ArrayList<Double>();
        pAbusive = 0.0;
    }
}
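For reference, these fields hold exactly what the training step below produces: p0Vect and p1Vect store log conditional word probabilities (with Laplace smoothing, counts initialised to 1 and denominators to 2), and pAbusive is the prior probability of class 1. Classification then compares, for c = 0 and c = 1,

\log p(c) + \sum_{i} x_i \, \log p(w_i \mid c)

where x_i is the 0/1 entry of the word vector; this is the log form of p(c) \prod_i p(w_i \mid c)^{x_i}, taken to avoid floating-point underflow.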
package com.haolidong.Bayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;

public class Bayes {

    public static void main(String[] args) {
        spamTest();
    }

    /**
     * @param end the random numbers are drawn from the range [0, end)
     * @param num how many distinct random numbers to generate
     * @return the set of generated random numbers
     * @author haolidong
     * @Description: [generate num distinct random integers in the range [0, end)]
     */
    public static HashSet<Integer> randomdif(int end, int num) {
        HashSet<Integer> rndint = new HashSet<Integer>();
        while (rndint.size() < num) {
            rndint.add((int) (Math.random() * end));
        }
        return rndint;
    }

    /**
     * @author haolidong
     * @Description: [spam e-mail classification test]
     */
    public static void spamTest() {
        ArrayList<String> fullText = new ArrayList<String>();
        CreateDataSet dataSet = new CreateDataSet();
        for (int i = 1; i < 26; i++) {
            String hamPath = "I:\\machinelearninginaction\\Ch04\\email\\ham\\" + i + ".txt";
            String spamPath = "I:\\machinelearninginaction\\Ch04\\email\\spam\\" + i + ".txt";
            // spam documents are labelled "1"
            ArrayList<String> spamWordList = textParse(spamPath, 2);
            dataSet.data.add(spamWordList);
            dataSet.labels.add("1");
            for (int j = 0; j < spamWordList.size(); j++) {
                fullText.add(spamWordList.get(j));
            }
            // ham documents are labelled "0"
            ArrayList<String> hamWordList = textParse(hamPath, 2);
            dataSet.data.add(hamWordList);
            dataSet.labels.add("0");
            for (int j = 0; j < hamWordList.size(); j++) {
                fullText.add(hamWordList.get(j));
            }
        }
        // build the vocabulary (fullText is collected above but not used afterwards)
        HashSet<String> vocabList = createVocabList(dataSet);
        // randomly pick 10 of the 50 documents as the test set; the rest form the training set
        HashSet<Integer> rndint = randomdif(50, 10);
        Matrix testMatrix = new Matrix();
        Matrix trainMatrix = new Matrix();
        ArrayList<String> trainLabels = new ArrayList<String>();
        ArrayList<String> testLabels = new ArrayList<String>();
        Matrix testMatrixTrans = new Matrix();
        Matrix trainMatrixTrans = new Matrix();
        for (Integer i : rndint) {
            testMatrix.data.add(dataSet.data.get(i));
            testLabels.add(dataSet.labels.get(i));
        }
        for (int i = 0; i < dataSet.data.size(); i++) {
            if (!rndint.contains(i)) {
                trainMatrix.data.add(dataSet.data.get(i));
                trainLabels.add(dataSet.labels.get(i));
            }
        }
        // convert the word lists into 0/1 vectors over the vocabulary
        for (int i = 0; i < trainMatrix.data.size(); i++) {
            trainMatrixTrans.data.add(setOfWords2Vec(vocabList, trainMatrix.data.get(i)));
        }
        for (int i = 0; i < testMatrix.data.size(); i++) {
            testMatrixTrans.data.add(setOfWords2Vec(vocabList, testMatrix.data.get(i)));
        }
        // train on the training set
        TrainNB0DataSet td = trainNB0(trainMatrixTrans, trainLabels);
        // evaluate on the test set
        int errorCount = 0;
        for (int i = 0; i < testMatrixTrans.data.size(); i++) {
            int num = classifyNB(testMatrixTrans.data.get(i), td.p0Vect, td.p1Vect, td.pAbusive);
            System.out.println("the predict:" + num + " , the real:" + testLabels.get(i));
            if (num != Integer.parseInt(testLabels.get(i))) {
                errorCount++;
            }
        }
        System.out.println("the errorRate is:" + 1.0 * errorCount / testMatrixTrans.data.size());
    }

    /**
     * @param fileName full path of the input file
     * @param moreThan only tokens longer than moreThan characters are kept
     * @return the list of parsed tokens
     * @author haolidong
     * @Description: [read a file and split it into lower-case word tokens]
     */
    public static ArrayList<String> textParse(String fileName, int moreThan) {
        String s = readFile(fileName);
        return extractStrlist(s, moreThan);
    }

    /**
     * @param fileName full path of the input file
     * @return the whole file content as a single string
     * @author haolidong
     * @Description: [read the file line by line and concatenate all lines into one string, with a space between lines]
     */
    public static String readFile(String fileName) {
        File file = new File(fileName);
        BufferedReader reader = null;
        String s = "";
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = null;
            // read line by line until readLine() returns null (end of file)
            while ((tempString = reader.readLine()) != null) {
                // the trailing " " keeps adjacent lines from running together
                s = s + tempString + " ";
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return s;
    }

    /**
     * @param inputString the input string
     * @param moreThan only tokens longer than moreThan characters are kept
     * @return the list of split tokens
     * @author haolidong
     * @Description: [split a string on non-word characters, drop everything that is not a letter or digit, and convert all tokens to lower case]
     */
    public static ArrayList<String> extractStrlist(String inputString, int moreThan) {
        ArrayList<String> strSplitList = new ArrayList<String>();
        String regEx = "\\W*";
        String sentence = "";
        // splitting on "\\W*" yields the word characters one by one, with an empty string
        // wherever non-word characters were consumed; rebuild the text with those gaps
        // replaced by single spaces
        String[] predel = inputString.split(regEx);
        for (int i = 0; i < predel.length; i++) {
            if (predel[i].equals(""))
                sentence += " ";
            else
                sentence += predel[i];
        }
        String[] strSplit = sentence.split(" ");
        for (int i = 0; i < strSplit.length; i++) {
            if (strSplit[i].length() > moreThan) {
                strSplitList.add(strSplit[i].toLowerCase());
            }
        }
        return strSplitList;
    }

    /**
     * @param vec2Classify the 0/1 word vector to classify
     * @param p0Vec log-probability (weight) vector of class 0
     * @param p1Vec log-probability (weight) vector of class 1
     * @param pClass1 proportion of class-1 documents
     * @return the predicted class label
     * @author haolidong
     * @Description: [compute the log posterior for each class and return the label with the larger value]
     */
    public static int classifyNB(ArrayList<String> vec2Classify, ArrayList<Double> p0Vec,
            ArrayList<Double> p1Vec, double pClass1) {
        double p1 = 0.0;
        double p0 = 0.0;
        for (int i = 0; i < vec2Classify.size(); i++) {
            p1 = p1 + Double.parseDouble(vec2Classify.get(i)) * p1Vec.get(i);
            p0 = p0 + Double.parseDouble(vec2Classify.get(i)) * p0Vec.get(i);
        }
        p1 = p1 + Math.log(pClass1);
        p0 = p0 + Math.log(1 - pClass1);
        if (p1 > p0)
            return 1;
        else
            return 0;
    }

    /**
     * @param trainMatrix training matrix of 0/1 word vectors
     * @param trainCategory training labels
     * @return the training result: the per-class probability vectors and the class prior
     * @author haolidong
     * @Description: [core routine of the Bayes classifier: train on the data set and return the probability vectors and the prior]
     */
    public static TrainNB0DataSet trainNB0(Matrix trainMatrix, ArrayList<String> trainCategory) {
        int numTrainDocs = trainMatrix.data.size();
        int numWords = trainMatrix.data.get(0).size();
        TrainNB0DataSet resultSet = new TrainNB0DataSet();
        ArrayList<Double> p0Num = new ArrayList<Double>();
        ArrayList<Double> p1Num = new ArrayList<Double>();
        double trainCategorySum = 0.0;
        for (int i = 0; i < trainCategory.size(); i++) {
            trainCategorySum = trainCategorySum + Double.parseDouble(trainCategory.get(i));
        }
        resultSet.pAbusive = trainCategorySum / numTrainDocs;
        // Laplace smoothing: start every word count at 1 and every denominator at 2
        for (int i = 0; i < numWords; i++) {
            p0Num.add(1.0);
            p1Num.add(1.0);
        }
        double p0Denom = 2.0;
        double p1Denom = 2.0;
        for (int i = 0; i < numTrainDocs; i++) {
            if (trainCategory.get(i).equals("1")) {
                for (int j = 0; j < numWords; j++) {
                    p1Num.set(j, p1Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));
                }
            } else {
                for (int j = 0; j < numWords; j++) {
                    p0Num.set(j, p0Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));
                }
            }
        }
        // the denominators become the total word counts of each class plus the smoothing term
        for (int i = 0; i < numWords; i++) {
            p0Denom += p0Num.get(i);
            p1Denom += p1Num.get(i);
        }
        p0Denom = p0Denom - numWords;
        p1Denom = p1Denom - numWords;
        // store log probabilities to avoid underflow when many small factors are combined
        for (int i = 0; i < numWords; i++) {
            resultSet.p0Vect.add(Math.log(p0Num.get(i) / p0Denom));
            resultSet.p1Vect.add(Math.log(p1Num.get(i) / p1Denom));
        }
        return resultSet;
    }

    /**
     * @param vocabSet the vocabulary
     * @param inputSet the input document as a list of words
     * @return a 0/1 vector aligned with the vocabulary
     * @author haolidong
     * @Description: [for each word of the vocabulary, output "1" if it occurs in the input document and "0" otherwise]
     */
    public static ArrayList<String> setOfWords2Vec(HashSet<String> vocabSet, ArrayList<String> inputSet) {
        ArrayList<String> returnVec = new ArrayList<String>();
        boolean flag;
        for (String value : vocabSet) {
            flag = false;
            for (int i = 0; i < inputSet.size(); i++) {
                if (inputSet.get(i).equals(value)) {
                    returnVec.add("1");
                    flag = true;
                    break;
                }
            }
            if (flag == false) {
                returnVec.add("0");
            }
        }
        return returnVec;
    }

    /**
     * @param dataSet the input data set
     * @return the vocabulary
     * @author haolidong
     * @Description: [collect every word of the data set into a set, which removes the duplicates and yields the vocabulary]
     */
    public static HashSet<String> createVocabList(Matrix dataSet) {
        HashSet<String> vocabSet = new HashSet<String>();
        for (int i = 0; i < dataSet.data.size(); i++) {
            for (int j = 0; j < dataSet.data.get(i).size(); j++) {
                vocabSet.add(dataSet.data.get(i).get(j));
            }
        }
        return vocabSet;
    }

    /**
     * @author haolidong
     * @Description: [test of the vocabulary-building function]
     */
    public static void testVocabList() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        HashSet<String> vocabSet = createVocabList(dataSet);
        System.out.println(vocabSet);
    }

    /**
     * @author haolidong
     * @Description: [test of converting a word list into a vocabulary-aligned 0/1 vector]
     */
    public static void testWord2Vec() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        HashSet<String> vocabSet = createVocabList(dataSet);
        ArrayList<String> returnVec = setOfWords2Vec(vocabSet, dataSet.data.get(0));
        System.out.println(returnVec);
    }

    /**
     * @author haolidong
     * @Description: [test of training on the toy data]
     */
    public static void testTrain() {
        CreateDataSet dataSet = new CreateDataSet();
        Matrix trainMatrix = new Matrix();
        dataSet.initTest();
        HashSet<String> vocabSet = createVocabList(dataSet);
        for (int i = 0; i < dataSet.data.size(); i++) {
            trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
        }
        trainNB0(trainMatrix, dataSet.labels);
    }

    /**
     * @author haolidong
     * @Description: [test of classifying new samples with the trained model]
     */
    public static void testingNB() {
        CreateDataSet dataSet = new CreateDataSet();
        ArrayList<String> testEntry = new ArrayList<String>();
        Matrix trainMatrix = new Matrix();
        dataSet.initTest();
        HashSet<String> vocabSet = createVocabList(dataSet);
        for (int i = 0; i < dataSet.data.size(); i++) {
            trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
        }
        TrainNB0DataSet td = trainNB0(trainMatrix, dataSet.labels);
        testEntry.add("love");
        testEntry.add("my");
        testEntry.add("dalmation");
        testEntry = setOfWords2Vec(vocabSet, testEntry);
        System.out.println("classified as:" + classifyNB(testEntry, td.p0Vect, td.p1Vect, td.pAbusive));
        testEntry.clear();
        testEntry.add("stupid");
        testEntry.add("garbage");
        testEntry = setOfWords2Vec(vocabSet, testEntry);
        System.out.println("classified as:" + classifyNB(testEntry, td.p0Vect, td.p1Vect, td.pAbusive));
    }
}
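The spamTest() entry point above assumes the Machine Learning in Action e-mail corpus under I:\machinelearninginaction\Ch04\email. To try the classifier without that corpus, the toy-data helpers can be called instead; the sketch below (the class name BayesDemo is mine, everything else is the existing API) runs them from the same package.

package com.haolidong.Bayes;

// Hypothetical demo class; it only calls the helpers defined in Bayes above.
public class BayesDemo {
    public static void main(String[] args) {
        Bayes.testVocabList(); // print the vocabulary built from the toy postings
        Bayes.testWord2Vec();  // print the 0/1 vector of the first toy posting
        Bayes.testingNB();     // classify "love my dalmation" and "stupid garbage"
    }
}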