package com.dianping.text.classify.util;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
public class TersorflowUtils {
-
private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();
public static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
public static byte[] readAllBytes(String path) {
try {
InputStream in = new FileInputStream(path);
byte[] bytes = new byte[in.available()];
in.read(bytes);
in.close();
return bytes;
} catch (Exception e) {
return null;
}
}
/*
* 序列默人长度为300
*/
public static int[][] gettexttoid(String text) {
int[][] xpad = new int[1][300];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (word_to_id.containsKey(element)) {
list.add(word_to_id.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= 300) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
public static int[][] gettexttoid(String text, Map<String, Integer> map) {
int[][] xpad = new int[1][300];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (map.containsKey(element)) {
list.add(map.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= 300) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
/*
* 自定义长度
*/
public static int[][] gettexttoid(String text, int maxlen) {
if (maxlen < 1) {
throw new IllegalArgumentException("maxlen长度必须大于等于1");
}
int[][] xpad = new int[1][maxlen];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (word_to_id.containsKey(element)) {
list.add(word_to_id.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= maxlen) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
public static int[][] gettexttoid(String text, int maxlen, Map<String, Integer> map) {
if (maxlen < 1) {
throw new IllegalArgumentException("maxlen长度必须大于等于1");
}
int[][] xpad = new int[1][maxlen];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (map.containsKey(element)) {
list.add(map.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= maxlen) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
private static int[] Intetoint(Integer[] arr) {
int[] result = new int[arr.length];
for (int i = 0; i < arr.length; i++) {
result[i] = arr[i].intValue();
}
return result;
}
public static double getClassifyByBiLSTM(String text, Session sess, Map<String, Integer> map, Tensor keep_prob) {
if (StringUtils.isBlank(text)) {
return 0.0;
}
int[][] arr = gettexttoid(text, map);
Tensor input = Tensor.create(arr);
Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", keep_prob).fetch("score/pred_y").run()
.get(0);
long[] rshape = result.shape();
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);
if (nlabels > 1 && batchSize > 0) {
return logits[0][1];
}
return 0.0;
}
}