1. Overview
Mahout 0.9 only provides the trainnb and testnb commands for its naive Bayes model, so you can train a model and evaluate how well it performs, but there is no built-in way to predict the class of new data. After reading the Mahout source code, I implemented a prediction function for the Mahout Bayes model myself. For how the Mahout 0.9 Bayes classifier is normally used, see http://blog.csdn.net/mach_learn/article/details/39667713
2. Why Mahout has no predict step
Mahout 0.9 serializes and vectorizes the training and test data together, and only afterwards splits the vectorized files into a training set and a test set. Vectorization produces the following files.
Among these files:
- dictionary.file-0 maps each term (word, punctuation mark, etc.) to an integer id; the key is the term and the value is its id.
- frequency.file-0 has a term id as key and, as value, the number of documents in which that term appears.
- df-count holds the document-frequency data.
- tf-vectors holds the term-frequency vector of each document.
- tfidf-vectors holds, for each document, the term ids and their tf-idf weights.
- tokenized-documents holds the documents after tokenization.
- wordcount holds the count of each term over the whole document collection.
After vectorization, Mahout splits the files in tfidf-vectors into a training set and a test set (typically an 80/20 split), trains on the training set with trainnb to produce the naiveBayesModel.bin model, and then evaluates that model on the test set with testnb, as sketched below.
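For orientation, a rough sketch of that command sequence (directory names and the exact split percentage are illustrative, not taken from this post; the article linked above has the full walkthrough):

mahout split -i tfidf-vectors --trainingOutput train-vectors --testOutput test-vectors --randomSelectionPct 20 --overwrite --sequenceFiles -xm sequential
mahout trainnb -i train-vectors -o model -li labelindex -el -ow -c
mahout testnb -i test-vectors -m model -l labelindex -o testnb-results -ow -c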
Because Mahout vectorizes everything in a single pass, there is one shared dictionary file. As a consequence, data that is vectorized separately by its own seq2sparse run cannot be scored with a naiveBayesModel.bin trained on other data: the two vectorizations use different dictionaries, so the same term id means different things.
3. How the prediction function works
To make predictions with naiveBayesModel.bin, the data to be predicted must be processed with the same vectorization vocabulary the model was trained with, i.e. it has to be mapped through the same dictionary and related files. First, map each term of the input to its id in the dictionary. Then look up the df-count value for that id and compute the term's tf-idf weight; this computation only needs df-count, numDocs, and the term frequency in the input, where numDocs is the value stored under key -1 in df-count. Finally, feed the tf-idf vector into naiveBayesModel.bin to obtain a likelihood score for each class; the class with the highest score is the predicted class.
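For reference, a minimal sketch of the tf-idf computation that the predictor below performs with Mahout's org.apache.mahout.vectorizer.TFIDF class. The numbers are made-up example values; to my understanding this class delegates to Lucene's DefaultSimilarity, i.e. roughly sqrt(tf) * (log(numDocs / (df + 1)) + 1), but treat that formula as an assumption rather than a guarantee.

import org.apache.mahout.vectorizer.TFIDF;

public class TfIdfSketch {
    public static void main(String[] args) {
        TFIDF tfidf = new TFIDF();
        int tf = 3;          // how often the term occurs in the document to classify (example value)
        int df = 120;        // document frequency of the term, read from df-count (example value)
        int numDocs = 5000;  // total number of documents, stored under key -1 in df-count (example value)
        // The third argument (document length) is passed as 0, just as the predictor below does.
        double weight = tfidf.calculate(tf, df, 0, numDocs);
        System.out.println("tf-idf weight: " + weight);
    }
}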
Development environment: Mahout 0.9; the jar packages that need to be on the classpath are shown in the figure below.
The program needs the model trained by Mahout and the files produced by seq2sparse vectorization. The vectorization output is in SequenceFile format, so it must first be converted to text with mahout seqdumper -i inputfile -o outputfile. The resulting file structure is shown below.
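A hedged example of those dump commands; the input paths depend on where seq2sparse and trainnb wrote their output (placeholders below), while the local file names under model/ are the ones the code reads:

mahout seqdumper -i <seq2sparse output>/dictionary.file-0 -o model/dictionary.txt
mahout seqdumper -i <seq2sparse output>/df-count -o model/df-count.txt
mahout seqdumper -i <seq2sparse output>/wordcount -o model/wordcount.txt
mahout seqdumper -i <labelindex path> -o model/labelindex.txt

The program then expects roughly this local layout:

model/dictionary.txt, model/df-count.txt, model/wordcount.txt, model/labelindex.txt  -- the text dumps above
model/model/  -- the directory written by trainnb (copied locally if the program runs outside HDFS), containing naiveBayesModel.bin
model/test/   -- one plain-text file per document to classify, with the document's tokens on a single space-separated line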
The code is as follows:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.vectorizer.TFIDF;
public class BayesPredict extends AbstractJob
{
// Text dumps (produced with "mahout seqdumper") of the seq2sparse side files and of labelindex:
// term -> dictionary id, dictionary id -> document frequency,
// term -> collection-wide count, and label name -> label index.
public static HashMap<String, String> dictionaryHashMap = new HashMap<>();
public static HashMap<String, String> dfcountHashMap = new HashMap<>();
public static HashMap<String, String> wordcountHashMap = new HashMap<>();
public static HashMap<String, String> labelindexHashMap = new HashMap<>();
// Paths are relative to the working directory; see the file layout above.
public BayesPredict()
{
readDfCount("model/df-count.txt");
readDictionary("model/dictionary.txt");
readLabelIndex("model/labelindex.txt");
readWordCount("model/wordcount.txt");
}
// Reads one prediction file; each file is expected to contain a single line
// of space-separated tokens (the same form as tokenized-documents).
public static String[] readFile(String filename)
{
File file = new File(filename);
String tempstring = null;
try
{
BufferedReader reader = new BufferedReader(new FileReader(file));
tempstring = reader.readLine();
reader.close();
}
catch (IOException e)
{
e.printStackTrace();
}
// Empty or unreadable file: nothing to predict.
if(tempstring==null)
return null;
return tempstring.trim().split(" ");
}
// Parses a seqdumper text dump whose lines look like "Key: <term>: Value: <id>"
// and fills dictionaryHashMap with term -> dictionary id.
public static void readDictionary(String fileName)
{
File file = new File(fileName);
BufferedReader reader;
String tempstring = null;
try
{
reader = new BufferedReader(new FileReader(file));
while((tempstring = reader.readLine())!=null)
{
if(tempstring.startsWith("Key:"))
{
String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
dictionaryHashMap.put(key.trim(), value.trim());
}
}
reader.close();
}
catch (IOException e)
{
e.printStackTrace();
}
}
// Parses the df-count dump ("Key: <id>: Value: <document frequency>");
// the entry with key -1 holds the total number of documents.
public static void readDfCount(String fileName)
{
File file = new File(fileName);
BufferedReader reader;
String tempstring = null;
try
{
reader = new BufferedReader(new FileReader(file));
while((tempstring = reader.readLine())!=null)
{
if(tempstring.startsWith("Key:"))
{
String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
dfcountHashMap.put(key.trim(), value.trim());
}
}
reader.close();
}
catch (IOException e)
{
e.printStackTrace();
}
}
// Parses the wordcount dump ("Key: <term>: Value: <count over all documents>");
// loaded here but not actually needed by the tf-idf computation below.
public static void readWordCount(String fileName)
{
File file = new File(fileName);
BufferedReader reader;
String tempstring = null;
try
{
reader = new BufferedReader(new FileReader(file));
while((tempstring = reader.readLine())!=null)
{
if(tempstring.startsWith("Key:"))
{
String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
wordcountHashMap.put(key.trim(), value.trim());
}
}
reader.close();
}
catch (IOException e)
{
e.printStackTrace();
}
}
// Parses the labelindex dump ("Key: <label name>: Value: <label index>").
public static void readLabelIndex(String fileName)
{
File file = new File(fileName);
BufferedReader reader;
String tempstring = null;
try
{
reader = new BufferedReader(new FileReader(file));
while((tempstring = reader.readLine())!=null)
{
if(tempstring.startsWith("Key:"))
{
String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
labelindexHashMap.put(key.trim(), value.trim());
}
}
reader.close();
}
catch (IOException e)
{
e.printStackTrace();
}
}
// Computes, for every term of the file that is present in the dictionary,
// its tf-idf weight, keyed by the term's dictionary id.
public static HashMap<Integer, Double> calcTfIdf(String filename)
{
String[] words = readFile(filename);
if(words==null)
return null;
HashMap<Integer, Double> tfidfHashMap = new HashMap<Integer, Double>();
HashMap<String, Integer> wordHashMap = new HashMap<String, Integer>();
for(int k=0; k<words.length; k++)
{
if(wordHashMap.get(words[k])==null)
{
wordHashMap.put(words[k], 1);
}
else
{
wordHashMap.put(words[k], wordHashMap.get(words[k])+1);
}
}
//System.out.println("wordcount:"+wordHashMap.size());
/*
System.out.println("dfcount:"+dfcountHashMap.size());
System.out.println("dictionary:"+dictionaryHashMap.size());
System.out.println("labelindex:"+labelindexHashMap.size());
System.out.println("wordcount:"+wordcountHashMap.size());
*/
Iterator iterator = wordHashMap.entrySet().iterator();
int numDocs = Integer.parseInt(dfcountHashMap.get("-1")); // total number of documents (key -1 in df-count)
while(iterator.hasNext())
{
Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)iterator.next();
String key = entry.getKey();
int value = entry.getValue();
int tf = value;
//System.out.println(key+":"+value);
if(dictionaryHashMap.get(key)!=null)
{
String idString = dictionaryHashMap.get(key);
int df = Integer.parseInt(dfcountHashMap.get(idString));
TFIDF tfidf = new TFIDF();
double tfidf_value = tfidf.calculate(tf, df, 0, numDocs); // tf-idf weight of this term
tfidfHashMap.put(Integer.parseInt(idString), tfidf_value);
//System.out.println(idString+":"+tfidf_value);
}
}
return tfidfHashMap;
}
// Classifies one file: builds its tf-idf vector and scores it against the naive Bayes model.
public String predict(String filename) throws IOException
{
HashMap<Integer, Double> tfidfHashMap = calcTfIdf(filename);
if(tfidfHashMap==null)
return "file is empty, unknown class";
// model/model/ is the output directory written by trainnb (it contains naiveBayesModel.bin).
NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model/model/"), getConf());
ComplementaryNaiveBayesClassifier classifier = new ComplementaryNaiveBayesClassifier(model);
// Accumulated scores for the two labels (indices 0 and 1 in labelindex.txt).
double label_1=0;
double label_2=0;
Iterator iterator = tfidfHashMap.entrySet().iterator();
while(iterator.hasNext())
{
Map.Entry<Integer, Double> entry = (Map.Entry<Integer, Double>)iterator.next();
int key = entry.getKey();
double value = entry.getValue();
// Weight each feature's per-label score by the document's tf-idf value for that feature.
label_1 += value*classifier.getScoreForLabelFeature(0, key);
label_2 += value*classifier.getScoreForLabelFeature(1, key);
}
//System.out.println("label_1:"+label_1);
//System.out.println("label_2:"+label_2);
// The returned label names are specific to the author's two-class data set.
if(label_1>label_2)
return "fraud-female";
else
return "norm-female";
}
@Override
public int run(String[] arg0) throws Exception {
// Not used; the class extends AbstractJob only to reuse its Hadoop configuration (getConf()).
return 0;
}
public static void main(String[] args)
{
//dictionary test
/*
readDictionary("model/dictionary.txt");
Iterator iterator = dictionaryHashMap.entrySet().iterator();
while(iterator.hasNext())
{
Map.Entry<String, String> entry = (Map.Entry<String, String>)iterator.next();
System.out.println(entry.getKey()+"--"+entry.getValue());
}
System.out.println(dictionaryHashMap.size());
System.out.println(System.getProperty("user.dir"));
*/
long startTime=System.currentTimeMillis();
BayesPredict bPredict = new BayesPredict();
try {
// model/test/ holds one plain-text file per document to classify.
File file = new File("model/test/");
String[] filenames = file.list();
int count1 = 0;
int count2 = 0;
int count = 0;
for(int i=0;i<filenames.length;i++)
{
String result = bPredict.predict("model/test/"+filenames[i]);
count++;
if(result.equals("fraud-female"))
count1++;
else if(result.equals("norm-female"))
count2++;
System.out.println(filenames[i]+":"+result);
}
System.out.println("count:"+count);
System.out.println("count1:"+count1);
System.out.println("count2:"+count2);
System.out.println("time:"+(System.currentTimeMillis()-startTime)/1000.0);
} catch (IOException e) {
e.printStackTrace();
}
}
}
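One design note: predict() hardcodes two label indices (0 and 1) and returns label names that are specific to the author's data set, while labelindexHashMap (label name -> index, loaded from labelindex.txt) is never consulted. A hedged sketch of how the scoring loop inside predict() could be generalized to any number of classes, as a drop-in replacement for the hardcoded label_1/label_2 comparison (classifier and tfidfHashMap are the local variables already defined there):

// Hypothetical replacement for the two-label scoring in predict().
String bestLabel = null;
double bestScore = Double.NEGATIVE_INFINITY;
for (Map.Entry<String, String> labelEntry : labelindexHashMap.entrySet()) {
	int labelIndex = Integer.parseInt(labelEntry.getValue());
	double score = 0.0;
	for (Map.Entry<Integer, Double> term : tfidfHashMap.entrySet()) {
		// tf-idf weight of the term times the model's per-label feature score
		score += term.getValue() * classifier.getScoreForLabelFeature(labelIndex, term.getKey());
	}
	if (score > bestScore) {
		bestScore = score;
		bestLabel = labelEntry.getKey();
	}
}
return bestLabel;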