A Simple Application of the Trie Tree

Time: 2021-01-12 20:32:47

Simple grammar matching over part-of-speech (POS) tag sequences.
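The idea, in outline: each grammar pattern is a sequence of POS tags joined by "+" (for example "a+b+c"), the trie stores one tag per edge, and the node where a pattern ends is flagged as a complete grammar. Matching walks an input tag sequence down the trie and collects every stored grammar that occurs in the input as an in-order subsequence (gaps allowed). A minimal usage sketch against the SemanticTrie class defined below; the demo class name is made up here, and the expected result is my reading of the code rather than captured output:

import java.util.Set;

public class SemanticTrieDemo { // hypothetical demo class, not part of the original listing
    public static void main(String[] args) {
        // the SemanticTrie singleton registers a+b+c, k+b+c+a, a+nh+c+o and nz+p+wp+a+c+d in its static block
        Set<String> hits = SemanticTrie.semanticTrie.getMatchGrammar("a+b+nh+c+h+o");
        System.out.println(hits); // expected (by my reading): a+b+c and a+nh+c+o, in some order
    }
}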



import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;

public class SemanticTrie {
    /**
     * Part-of-speech tag set (with example words):
     * a   adjective            美丽        ni  organization name  保险公司
     * b   other noun-modifier  大型, 西式   nl  location noun      城郊
     * c   conjunction          和, 虽然     ns  geographical name  北京
     * d   adverb               很          nt  temporal noun      近日, 明代
     * e   exclamation          哎          nz  other proper noun  诺贝尔奖
     * g   morpheme             茨, 甥       o   onomatopoeia       哗啦
     * h   prefix               阿, 伪       p   preposition        在, 把
     * i   idiom                百花齐放     q   quantity           个
     * j   abbreviation         公检法       r   pronoun            我们
     * k   suffix               界, 率       u   auxiliary          的, 地
     * m   number               一, 第一     v   verb               跑, 学习
     * n   general noun         苹果        wp  punctuation        ,。!
     * nd  direction noun       右侧        ws  foreign words      CPU
     * nh  person name          杜甫, 汤姆   x   non-lexeme         萄, 翱
     */
    public static final String[] tags = {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "m", "n", "nd", "nh", "ni", "nl", "ns", "nt", "nz", "o", "p", "q", "r", "u", "v", "wp", "ws", "x"};

    public static SemanticTrie semanticTrie = new SemanticTrie();

    static {
        System.out.println("Initializing grammars");
        semanticTrie.addVertex("a+b+c");
        semanticTrie.addVertex("k+b+c+a");
        semanticTrie.addVertex("a+nh+c+o");
        semanticTrie.addVertex("nz+p+wp+a+c+d");
        System.out.println("Initialization finished");
    }

    /**
     * Returns the index of a single POS tag in {@link #tags}.
     * @param tag POS tag
     * @return index of the tag, or -1 if the tag is unknown
     */
    public int getTagIndex(String tag) {
        if (tag == null || tag.isEmpty()) {
            return -1;
        }
        for (int i = 0; i < tags.length; i++) {
            if (tag.toLowerCase().equals(tags[i])) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Each trie has a single root node.
     */
    public Vertex root;

    /**
     * Trie node.
     */
    protected class Vertex {
        /**
         * POS tag on the edge leading to this node.
         */
        protected String tag;
        /**
         * Whether the path down to this node is a complete grammar.
         */
        protected boolean isFull;
        /**
         * Number of grammars passing through this prefix.
         */
        protected int count;
        /**
         * One child slot per POS tag (29 in total).
         */
        protected Vertex[] edges;

        Vertex() {
            tag = "";
            isFull = false;
            count = 0;
            edges = new Vertex[tags.length]; // slots default to null
        }
    }

    public SemanticTrie() {
        root = new Vertex();
    }

    /**
     * Adds a grammar to the trie.
     * @param grammar grammar pattern, POS tags joined by "+", e.g. "a+b+c"
     */
    public void addVertex(String grammar) {
        if (grammar == null || grammar.trim().isEmpty()) {
            return;
        }
        String[] grammars = grammar.split("\\+");
        Queue<String> grammarQueue = new LinkedList<String>();
        for (String str : grammars) {
            if (!str.trim().isEmpty()) {
                grammarQueue.add(str.trim()); // trim so "a + b" is treated like "a+b"
            }
        }
        this.addVertex(root, grammarQueue);
    }

    /**
     * Recursively inserts the remaining tags of a grammar under the given node.
     * @param vertex       current node
     * @param grammarQueue remaining POS tags of the grammar
     */
    public void addVertex(Vertex vertex, Queue<String> grammarQueue) {
        if (grammarQueue.isEmpty()) {
            // no tags left: mark this node as the end of a complete grammar
            vertex.isFull = true;
        } else {
            String tag = grammarQueue.poll();
            int index = this.getTagIndex(tag);
            if (index < 0) {
                // unknown tag: skip it and keep inserting the rest
                addVertex(vertex, grammarQueue);
            } else {
                if (vertex.edges[index] == null) {
                    vertex.edges[index] = new Vertex();
                    vertex.edges[index].tag = tag;
                }
                // count how many grammars pass through this prefix
                vertex.count++;
                // descend into the child with the remaining tags
                addVertex(vertex.edges[index], grammarQueue);
            }
        }
    }

    /**
     * Depth-first traversal collecting every complete grammar in the trie.
     * @param grammarList result list
     * @param vertex      current node
     * @param segment     grammar prefix accumulated so far
     */
    private void depthFirstAll(List<String> grammarList, Vertex vertex, String segment) {
        Vertex[] edges = vertex.edges;
        for (int i = 0; i < edges.length; i++) {
            if (edges[i] != null) {
                String temSeg = segment + edges[i].tag + "+";
                if (edges[i].isFull) {
                    // drop the trailing "+" before recording the grammar
                    grammarList.add(temSeg.substring(0, temSeg.length() - 1));
                }
                depthFirstAll(grammarList, edges[i], temSeg);
            }
        }
    }

    /**
     * Returns all grammars stored in the trie.
     */
    public List<String> getAllGrammar() {
        List<String> grammarList = new ArrayList<String>();
        depthFirstAll(grammarList, root, "");
        return grammarList;
    }

    /**
     * Returns every stored grammar that matches the given POS tag sequence.
     * @param grammar POS tags joined by "+", e.g. "a+b+nh+c"
     * @return matching grammars, or null for an empty input
     */
    public Set<String> getMatchGrammar(String grammar) {
        if (grammar == null || grammar.trim().isEmpty()) {
            return null;
        }
        String[] grammars = grammar.split("\\+");
        Queue<String> grammarQueue = new LinkedList<String>();
        for (String str : grammars) {
            if (!str.trim().isEmpty()) {
                grammarQueue.add(str.trim());
            }
        }
        Set<String> grammarList = new HashSet<String>();
        selectMatchGrammar(grammarList, root, grammarQueue, "");
        return grammarList;
    }

    /**
     * Searches the trie for grammars that occur as in-order subsequences of the tag queue.
     * @param grammarList result set
     * @param vertex      node to start from
     * @param tempQuery   remaining POS tags
     * @param segment     grammar prefix matched so far
     */
    public void selectMatchGrammar(Set<String> grammarList, Vertex vertex, Queue<String> tempQuery, String segment) {
        if (tempQuery == null || tempQuery.isEmpty()) {
            return;
        }
        // System.out.println(tempQuery); // debug: remaining tag queue
        while (!tempQuery.isEmpty()) {
            String tag = tempQuery.poll();
            int index = this.getTagIndex(tag);
            if (index < 0) {
                continue; // unknown tag: skip it
            }
            if (vertex.edges[index] == null) {
                continue; // no stored grammar continues with this tag here
            }
            Vertex temV = vertex.edges[index];
            // extend the matched prefix
            String temSeg = segment + tag + "+";
            if (temV.isFull) {
                grammarList.add(temSeg.substring(0, temSeg.length() - 1));
            }
            // branch: try to extend this match with the remaining tags,
            // while the outer loop keeps looking for other starting points
            Queue<String> temQ = copy(tempQuery);
            selectMatchGrammar(grammarList, temV, temQ, temSeg);
        }
    }

    /**
     * Copies a queue of tags so one match branch does not consume another's input.
     * @param fromQueue source queue
     * @return a new queue with the same elements, or null if the source is null
     */
    public Queue<String> copy(Queue<String> fromQueue) {
        if (fromQueue == null) {
            return null;
        }
        return new LinkedList<String>(fromQueue);
    }

    public static void main(String[] args) {
        // list every grammar stored in the trie
        List<String> grammarLista = SemanticTrie.semanticTrie.getAllGrammar();
        for (String string : grammarLista) {
            System.out.println(string);
        }
        System.out.println("-----------------------");
        // match a POS tag sequence against the stored grammars
        Set<String> grammarList = SemanticTrie.semanticTrie.getMatchGrammar("a+b+nh+c+h+o+a+wp+d+b");
        for (String string : grammarList) {
            System.out.println(string);
        }
    }
}
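By my reading of the code (not captured output), running main() should first print the four registered grammars found by the depth-first traversal, in tag-index order: a+b+c, a+nh+c+o, k+b+c+a, nz+p+wp+a+c+d. After the separator line it should print the grammars matching the query "a+b+nh+c+h+o+a+wp+d+b", namely a+b+c and a+nh+c+o, since both occur in the query as in-order subsequences; the match results come from a HashSet, so their print order is not guaranteed.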