Springboot整合文心一言----非流式响应与流式响应(前后端)

时间:2024-05-01 16:42:43

        所谓非流式响应,就是等待百度把答案完整生成之后再一次性返回给你;而流式响应则是以流的形式,百度一边生成答案,一边把已生成的部分返回,这也是我们使用ChatGPT时最常见的表现:它回答问题的时候总是一个字一个字地蹦出来。这两种回答方式各有适用范围。我认为,如果需要生成的答案不是很多(可以通过编写对应的prompt进行限制),或者能够接受较长的等待,那么非流式响应是没有问题的。

        但是如果你对网络请求的超时有要求,比如前端使用Uniapp编码时,uni.uploadFile默认的超时时间是10s,而且好像还不能修改(至少我没改成功)。不过这不是关键:建立网络连接之后,如果客户端超过超时时间还没有收到服务端的响应,就会拒绝接收;即使服务端只比超时时间多用了零点几秒就生成出了答案,客户端依然会拒绝接收。所以在这种情况下,流式响应就成了必然的选择。

        本文在Java端就对流式回答做了处理(按行读取、过滤空行后再转发),其实把原始流直接交给前端处理会更好。市面上大多使用SSE技术来维护整个对话,但Uniapp不支持SSE,所以我改用了WebSocket,原理大致相同。

依赖引入:

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.9.3</version>
        </dependency>

前端部分:

            //断线重连
            reconnect() {
				if(this.ohHideFlag)
				if (!this.is_open_socket) {
					this.reconnectTimeOut = setTimeout(() => {
						this.connectSocketInit();
					}, 3000)
				}
			},
			connectSocketInit() {
				let token = getToken()
				this.socketTask = uni.connectSocket({
                    //如果是http则使用ws,如果是https则使用wss,小程序需要去公众平台进行记录
					url: 'wss://' + this.socketUrl + '/websocket/' + token,
					success: () => {
						console.log("正准备建立websocket中...");
						// 返回实例
						return this.socketTask
					},
				});
				this.socketTask.onOpen((res) => {
					console.log("WebSocket连接正常!");
					this.is_open_socket = true;
					this.socketTask.onMessage((res) => {
						if (result == "") {
							return;
							console.log("回答完毕")
						}
						let jsonString = res.data
						const dataPrefix = "data: ";
						if (jsonString.startsWith(dataPrefix)) {
							jsonString = jsonString.substring(dataPrefix.length);
						}
						
						// 解析JSON字符串
						const jsonObject = JSON.parse(jsonString);
						
						// 获取result属性
						const result = jsonObject.result;
						console.log(result);
						this.tempItem.content += result
						this.scrollToBottom();
					});
				})
				this.socketTask.onClose(() => {
					console.log("已经被关闭了")
					this.is_open_socket = false;
					this.reconnect();
				})
			},

后端代码:

package com.farm.controller;

import com.farm.chat.StreamChat;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


@Slf4j
@Component
@ServerEndpoint("/websocket/{target}") //创建ws的请求路径。
public class WebsocketServerEndpoint {
    private Session session;
    private String target;
    //支持持续流推送
    private InputStream inputStream;
    private final static CopyOnWriteArraySet<WebsocketServerEndpoint> websockets = new CopyOnWriteArraySet<>();

    @OnOpen
    public void onOpen(Session session , @PathParam("target") String target){
        this.session = session;
        this.target = target;
        websockets.add(this);
        log.info("websocket connect server success , target is {},total is {}",target,websockets.size());
    }

//当客户端主动联系就会触发这个方法
    @OnMessage
    public void onMessage(String message) throws IOException, JSONException {
        log.info("message is {}",message);
        JSONObject jsonObject = new JSONObject(message);
        String user = (String)jsonObject.get("user");
        String question = (String)jsonObject.get("message");

        StreamChat streamChat = new StreamChat();
        ResponseBody body = streamChat.getAnswerStream(question);
        InputStream inputStream = body.byteStream();

        sendMessageSync(user,inputStream);
    }

    @OnClose
    public void onClose(){
        log.info("connection has been closed ,target is {},total is {}" ,this.target, websockets.size());
        this.destroy();
    }

    @OnError
    public void onError(Throwable throwable){
        this.destroy();
        log.info("websocket connect error , target is {} ,total is {}, error is {}",this.target ,websockets.size(),throwable.getMessage());
    }

    /**
     * 根据目标身份推送消息
     * @param target
     * @param message
     * @throws IOException
     */
    public void sendMessageOnce(String target, String message) throws IOException {
        this.sendMessage(target,message,false,null);
    }

    /**
     * stream 同步日志输出,通过websocket推送至前台。
     * @param target
     * @param is
     * @throws IOException
     */
    private void sendMessageSync(String target, InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if (Objects.isNull(websocket)) {
            throw new RuntimeException("The websocket does not exist or has been closed.");
        }
        if (Objects.isNull(is)) {
            throw new RuntimeException("InputStream cannot be null.");
        } else {
            websocket.inputStream = is;
            CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
        }
    }


    /**
     * Send message.
     * @param target 通过target获取{@link WebsocketServerEndpoint}.
     * @param message message
     * @param continuous 是否通过inputStream持续推送消息。
     * @param is 输入流
     * @throws IOException
     */
    private void sendMessage(String target , String message ,Boolean continuous , InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if(Objects.isNull(websocket)){
            throw new RuntimeException("The websocket does not exists or has been closed.");
        }
        if(continuous){
            if(Objects.isNull(is)){
                throw new RuntimeException("InputStream can not be null when continuous is true.");
            }else{
                websocket.inputStream = is;
                CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
            }
        }else{
            websocket.session.getBasicRemote().sendText(message);
        }
    }

    /**
     * 通过inputStream 持续推送消息。
     * 支持文件、消息、日志等。
     */
    private void sendMessageWithInputSteam() {
        String message;
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(this.inputStream));
        try {
            while ((message = bufferedReader.readLine()) != null) {
                if(message.equals(""))
                    continue;
                if (websockets.contains(this)) {
                    System.out.println(message);
                    this.session.getBasicRemote().sendText(message);
                }
            }
        } catch (IOException e) {
            log.warn("SendMessage failed {}", e.getMessage());
        } finally {
            this.closeInputStream();
        }
    }

    /**
     * 根据目标获取对应的{@link WebsocketServerEndpoint}。
     * @param target 约定标的
     * @return WebsocketServerEndpoint
     */
    private WebsocketServerEndpoint getWebsocket(String target){
        WebsocketServerEndpoint websocket = null;
        for (WebsocketServerEndpoint ws : websockets) {
            if (target.equals(ws.target)) {
                websocket = ws;
            }
        }
        return websocket;
    }

    private void closeInputStream(){
        if(Objects.nonNull(inputStream)){
            try {
                inputStream.close();
            } catch (Exception e) {
                log.warn("websocket close failed {}",e.getMessage());
            }
        }
    }

    private void destroy(){
        websockets.remove(this);
        this.closeInputStream();
    }
}

StreamChat

package com.farm.chat;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.GetMapping;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;


@Slf4j
public class StreamChat {
 
    //历史对话,需要按照user,assistant
    List<Map<String,String>> messages = new ArrayList<>();

    private final String ACCESS_TOKEN_URI = "https://aip.baidubce.com/oauth/2.0/token";
    private final String CHAT_URI = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-preview";
//这里填入自己的识别码即可
    private String apiKey = " ";
    private String secretKey = " ";
    private int responseTimeOut = 5000;
    private OkHttpClient client ;
    private String accessToken = "";


    public boolean getAccessToken(){
        this.client = new OkHttpClient.Builder().readTimeout(responseTimeOut, TimeUnit.SECONDS).build();
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody body = RequestBody.create(mediaType, "");
        //创建一个请求
        Request request = new Request.Builder()
                .url(ACCESS_TOKEN_URI+"?client_id=" + apiKey + "&client_secret=" + secretKey + "&grant_type=client_credentials")
                .method("POST",body)
                .addHeader("Content-Type", "application/json")
                .build();
        try {
            //使用浏览器对象发起请求
            Response response = client.newCall(request).execute();
            //只能执行一次response.body().string()。下次再执行会抛出流关闭异常,因此需要一个对象存储返回结果
            String responseMessage = response.body().string();
            log.debug("获取accessToken成功");
            JSONObject jsonObject = JSON.parseObject(responseMessage);
            accessToken = (String) jsonObject.get("access_token");
            return true;
        } catch (IOException e) {
            e.printStackTrace();
        }
        return false;
    }
    public ResponseBody getAnswerStream(String question){
        getAccessToken();
        OkHttpClient client = new OkHttpClient();

        HashMap<String, String> user = new HashMap<>();
        user.put("role","user");
        user.put("content",question);
        messages.add(user);
        String requestJson = constructRequestJson(1,0.95,0.8,1.0,true,messages);
        RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
        Request request = new Request.Builder()
                .url(CHAT_URI + "?access_token="+accessToken)
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();

        StringBuilder answer = new StringBuilder();
        // 发起异步请求
        try {
            Response response = client.newCall(request).execute();
            // 检查响应是否成功
            if (response.isSuccessful()) {
                // 获取响应流
                return response.body();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return null;
    }
    
 
    /**
     * 构造请求的请求参数
     * @param userId
     * @param temperature
     * @param topP
     * @param penaltyScore
     * @param messages
     * @return
     */
    public String constructRequestJson(Integer userId,
                                       Double temperature,
                                       Double topP,
                                       Double penaltyScore,
                                       boolean stream,
                                       List<Map<String, String>> messages) {
        Map<String,Object> request = new HashMap<>();
        request.put("user_id",userId.toString());
        request.put("temperature",temperature);
        request.put("top_p",topP);
        request.put("penalty_score",penaltyScore);
        request.put("stream",stream);
        request.put("messages",messages);
        System.out.println(JSON.toJSONString(request));
        return JSON.toJSONString(request);
    }

}

效果如下(前端或许还需要对字符串进行切分,营造一个字一个字输出的效果):