Java 加载 TensorFlow 训练好的模型并部署成 service

时间:2022-09-16 22:10:08

上一章介绍了如何在 Java 中调用 TensorFlow 训练好的模型,本篇主要给出把模型部署成 service 的代码。另外,官方要求使用 JDK 1.8,不过我修改了部分方法(例如用普通循环代替 JDK 1.8 的 Stream 转换),所以 JDK 1.7 也可以运行。代码如下:


首先是工具类 utils,其中包含用到的一些方法,主要用于把一段文本转换为模型输入所需的 tensor:

    
    
   
   
  1. package com.dianping.text.classify.util;
  2. import java.io.BufferedReader;
  3. import java.io.FileInputStream;
  4. import java.io.IOException;
  5. import java.io.InputStream;
  6. import java.io.InputStreamReader;
  7. import java.nio.file.Files;
  8. import java.nio.file.Paths;
  9. import java.nio.file.Path;
  10. import java.util.ArrayList;
  11. import java.util.Arrays;
  12. import java.util.HashMap;
  13. import java.util.List;
  14. import java.util.Map;
  15. import org.apache.commons.lang.StringUtils;
  16. import org.tensorflow.Graph;
  17. import org.tensorflow.Session;
  18. import org.tensorflow.Tensor;
  19. public class TersorflowUtils {
  20. private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();
  21. public static byte[] readAllBytesOrExit(Path path) {
  22. try {
  23. return Files.readAllBytes(path);
  24. } catch (IOException e) {
  25. System.err.println("Failed to read [" + path + "]: " + e.getMessage());
  26. System.exit(1);
  27. }
  28. return null;
  29. }
  30. public static byte[] readAllBytes(String path) {
  31. try {
  32. InputStream in = new FileInputStream(path);
  33. byte[] bytes = new byte[in.available()];
  34. in.read(bytes);
  35. in.close();
  36. return bytes;
  37. } catch (Exception e) {
  38. return null;
  39. }
  40. }
  41. /*
  42. * 序列默人长度为300
  43. */
  44. public static int[][] gettexttoid(String text) {
  45. int[][] xpad = new int[1][300];
  46. if (StringUtils.isBlank(text)) {
  47. return xpad;
  48. }
  49. char[] chs = text.trim().toLowerCase().toCharArray();
  50. List<Integer> list = new ArrayList<Integer>();
  51. for (int i = 0; i < chs.length; i++) {
  52. String element = Character.toString(chs[i]);
  53. if (word_to_id.containsKey(element)) {
  54. list.add(word_to_id.get(element));
  55. }
  56. }
  57. if (list.size() == 0) {
  58. return xpad;
  59. }
  60. int size = list.size();
  61. Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
  62. /*
  63. * 用于jdk1.8转换
  64. */
  65. // int[] target=
  66. // Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  67. int[] target = Intetoint(targetInter);
  68. if (size <= 300) {
  69. System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
  70. } else {
  71. System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
  72. }
  73. return xpad;
  74. }
  75. public static int[][] gettexttoid(String text, Map<String, Integer> map) {
  76. int[][] xpad = new int[1][300];
  77. if (StringUtils.isBlank(text)) {
  78. return xpad;
  79. }
  80. char[] chs = text.trim().toLowerCase().toCharArray();
  81. List<Integer> list = new ArrayList<Integer>();
  82. for (int i = 0; i < chs.length; i++) {
  83. String element = Character.toString(chs[i]);
  84. if (map.containsKey(element)) {
  85. list.add(map.get(element));
  86. }
  87. }
  88. if (list.size() == 0) {
  89. return xpad;
  90. }
  91. int size = list.size();
  92. Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
  93. /*
  94. * 用于jdk1.8转换
  95. */
  96. // int[] target=
  97. // Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  98. int[] target = Intetoint(targetInter);
  99. if (size <= 300) {
  100. System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
  101. } else {
  102. System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
  103. }
  104. return xpad;
  105. }
  106. /*
  107. * 自定义长度
  108. */
  109. public static int[][] gettexttoid(String text, int maxlen) {
  110. if (maxlen < 1) {
  111. throw new IllegalArgumentException("maxlen长度必须大于等于1");
  112. }
  113. int[][] xpad = new int[1][maxlen];
  114. if (StringUtils.isBlank(text)) {
  115. return xpad;
  116. }
  117. char[] chs = text.trim().toLowerCase().toCharArray();
  118. List<Integer> list = new ArrayList<Integer>();
  119. for (int i = 0; i < chs.length; i++) {
  120. String element = Character.toString(chs[i]);
  121. if (word_to_id.containsKey(element)) {
  122. list.add(word_to_id.get(element));
  123. }
  124. }
  125. if (list.size() == 0) {
  126. return xpad;
  127. }
  128. int size = list.size();
  129. Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
  130. /*
  131. * 用于jdk1.8转换
  132. */
  133. // int[] target=
  134. // Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  135. int[] target = Intetoint(targetInter);
  136. if (size <= maxlen) {
  137. System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
  138. } else {
  139. System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
  140. }
  141. return xpad;
  142. }
  143. public static int[][] gettexttoid(String text, int maxlen, Map<String, Integer> map) {
  144. if (maxlen < 1) {
  145. throw new IllegalArgumentException("maxlen长度必须大于等于1");
  146. }
  147. int[][] xpad = new int[1][maxlen];
  148. if (StringUtils.isBlank(text)) {
  149. return xpad;
  150. }
  151. char[] chs = text.trim().toLowerCase().toCharArray();
  152. List<Integer> list = new ArrayList<Integer>();
  153. for (int i = 0; i < chs.length; i++) {
  154. String element = Character.toString(chs[i]);
  155. if (map.containsKey(element)) {
  156. list.add(map.get(element));
  157. }
  158. }
  159. if (list.size() == 0) {
  160. return xpad;
  161. }
  162. int size = list.size();
  163. Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
  164. /*
  165. * 用于jdk1.8转换
  166. */
  167. // int[] target=
  168. // Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
  169. int[] target = Intetoint(targetInter);
  170. if (size <= maxlen) {
  171. System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
  172. } else {
  173. System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
  174. }
  175. return xpad;
  176. }
  177. private static int[] Intetoint(Integer[] arr) {
  178. int[] result = new int[arr.length];
  179. for (int i = 0; i < arr.length; i++) {
  180. result[i] = arr[i].intValue();
  181. }
  182. return result;
  183. }
  184. public static double getClassifyByBiLSTM(String text, Session sess, Map<String, Integer> map, Tensor keep_prob) {
  185. if (StringUtils.isBlank(text)) {
  186. return 0.0;
  187. }
  188. int[][] arr = gettexttoid(text, map);
  189. Tensor input = Tensor.create(arr);
  190. Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", keep_prob).fetch("score/pred_y").run()
  191. .get(0);
  192. long[] rshape = result.shape();
  193. int nlabels = (int) rshape[1];
  194. int batchSize = (int) rshape[0];
  195. float[][] logits = result.copyTo(new float[batchSize][nlabels]);
  196. if (nlabels > 1 && batchSize > 0) {
  197. return logits[0][1];
  198. }
  199. return 0.0;
  200. }
  201. }


其次是 service 的启动类:

    
    
   
   
  1. package com.dianping.text.classify.base;
  2. import java.io.BufferedReader;
  3. import java.io.FileInputStream;
  4. import java.io.InputStreamReader;
  5. import java.nio.file.Paths;
  6. import java.util.HashMap;
  7. import java.util.Map;
  8. import org.slf4j.Logger;
  9. import org.slf4j.LoggerFactory;
  10. import org.tensorflow.Graph;
  11. import org.tensorflow.Session;
  12. import org.tensorflow.Tensor;
  13. import com.dianping.text.classify.util.TersorflowUtils;
  14. import com.dianping.text.classifybydl.api.service.Category;
  15. public class Abuse implements Category {
  16. private static final Logger logger = LoggerFactory.getLogger(Abuse.class);
  17. private Graph g;
  18. private Session sess;
  19. private Tensor keep_prob;
  20. private Map<String, Integer> map;
  21. private void init() {
  22. g = new Graph();
  23. keep_prob = Tensor.create(1.0f);
  24. try {
  25. updataMap();
  26. byte[] graphDef = TersorflowUtils
  27. .readAllBytesOrExit(Paths.get(this.getClass().getResource("/").getPath(), "modelabuse/graph.model"));
  28. g.importGraphDef(graphDef);
  29. sess = new Session(g);
  30. } catch (Exception e) {
  31. logger.error(" model load:", e);
  32. }
  33. }
  34. public void updataMap() {
  35. map = new HashMap<>();
  36. int i = 0;
  37. try {
  38. BufferedReader buffer = null;
  39. String path = this.getClass().getResource("/").getPath() + "modelabuse/vocab_cnews.txt";
  40. buffer = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
  41. String line = buffer.readLine().trim();
  42. while (line != null) {
  43. map.put(line, i++);
  44. line = buffer.readLine().trim();
  45. }
  46. buffer.close();
  47. } catch (Exception e) {
  48. }
  49. System.out.println("map.size is:" + map.size());
  50. }
  51. @Override
  52. public double getClassify(String text) {
  53. return TersorflowUtils.getClassifyByBiLSTM(text, sess, map, keep_prob);
  54. }
  55. public static void main(String[] args) {
  56. Abuse abuse = new Abuse();
  57. abuse.init();
  58. System.out.println(abuse.getClassify("我操你妈个逼"));
  59. }
  60. }



结果:

java加载tensorflow训练好的模型部署成service

java加载tensorflow训练好的模型部署成service