Flutter+SpringBoot实现ChatGPT流式输出

时间:2025-02-16 18:27:19
import com.alibaba.fastjson2.JSON; import com.squareup.okhttp.Call; import com.squareup.okhttp.MediaType; import com.squareup.okhttp.OkHttpClient; import com.squareup.okhttp.Request; import com.squareup.okhttp.RequestBody; import com.squareup.okhttp.Response; import com.squareup.okhttp.ResponseBody; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import vip.ailtw.common.utils.StringUtil; import javax.annotation.PostConstruct; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Serializable; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; @Slf4j @Component public class ChatGptStreamUtil { /** * 修改为自己的密钥 */ private final String apiKey = "xxxxxxxxxxxxxx"; public final String gptCompletionsUrl = "/v1/chat/completions"; private static final OkHttpClient client = new OkHttpClient(); private static MediaType mediaType; private static Request.Builder requestBuilder; public final static Pattern contentPattern = Pattern.compile("\"content\":\"(.*?)\"}"); /** * 对话符号 */ public final static String EVENT_DATA = "d"; /** * 错误结束符号 */ public final static String EVENT_ERROR = "e"; /** * 响应结束符号 */ public final static String END = "<<END>>"; @PostConstruct public void init() { client.setConnectTimeout(60, TimeUnit.SECONDS); client.setReadTimeout(60, TimeUnit.SECONDS); mediaType = MediaType.parse("application/json; charset=utf-8"); requestBuilder = new Request.Builder() .url(gptCompletionsUrl) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + apiKey); } /** * 流式对话 * * @param 
talkList 上下文对话,最早的对话放在首位 * @param callable 消费者,流式对话每次响应的内容 */ public GptChatResultDTO chatStream(List<ChatGptDTO> talkList, Consumer<String> callable) throws Exception { long start = System.currentTimeMillis(); StringBuilder resp = new StringBuilder(); Response response = chatStream(talkList); //解析对话内容 try (ResponseBody responseBody = response.body(); InputStream inputStream = responseBody.byteStream(); BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) { String line; while ((line = bufferedReader.readLine()) != null) { if (!StringUtils.hasLength(line)) { continue; } Matcher matcher = contentPattern.matcher(line); if (matcher.find()) { String content = matcher.group(1); resp.append(content); callable.accept(content); } } } int wordSize = 0; for (ChatGptDTO dto : talkList) { String content = dto.getContent(); wordSize += content.toCharArray().length; } wordSize += resp.toString().toCharArray().length; long end = System.currentTimeMillis(); return GptChatResultDTO.builder().resContent(resp.toString()).time(end - start).wordSize(wordSize).build(); } /** * 流式对话 * * @param talkList 上下文对话 * @return 接口请求响应 */ private Response chatStream(List<ChatGptDTO> talkList) throws Exception { ChatStreamDTO chatStreamDTO = new ChatStreamDTO(talkList); RequestBody bodyOk = RequestBody.create(mediaType, chatStreamDTO.toString()); Request requestOk = requestBuilder.post(bodyOk).build(); Call call = client.newCall(requestOk); Response response; try { response = call.execute(); } catch (IOException e) { throw new IOException("请求时IO异常: " + e.getMessage()); } if (response.isSuccessful()) { return response; } try (ResponseBody body = response.body()) { if (429 == response.code()) { String msg = "Open Api key 已过期,msg: " + body.string(); log.error(msg); } throw new RuntimeException("chat api 请求异常, code: " + response.code() + "body: " + body.string()); } } private boolean sendToClient(String event, String data, SseEmitter emitter) { try { 
emitter.send(SseEmitter.event().name(event).data("{" + data + "}")); return true; } catch (IOException e) { log.error("向客户端发送消息时出现异常", e); } return false; } /** * 发送事件给客户端 */ public boolean sendData(String data, SseEmitter emitter) { if (StringUtil.isBlank(data)) { return true; } return sendToClient(EVENT_DATA, data, emitter); } /** * 发送结束事件,会关闭emitter */ public void sendEnd(SseEmitter emitter) { try { sendToClient(EVENT_DATA, END, emitter); } finally { emitter.complete(); } } /** * 发送异常事件,会关闭emitter */ public void sendError(SseEmitter emitter) { try { sendToClient(EVENT_ERROR, "我累垮了", emitter); } finally { emitter.complete(); } } /** * gpt请求结果 */ @Data @NoArgsConstructor @AllArgsConstructor @Builder public static class GptChatResultDTO implements Serializable { /** * gpt请求返回的全部内容 */ private String resContent; /** * 上下文消耗的字数 */ private int wordSize; /** * 耗时 */ private long time; } /** * 连续对话DTO */ @Data @Builder @NoArgsConstructor @AllArgsConstructor public static class ChatGptDTO implements Serializable { /** * 对话内容 */ private String content; /** * 角色 {@link GptRoleEnum} */ private String role; } /** * gpt连续对话角色 */ @Getter public static enum GptRoleEnum { USER_ROLE("user", "用户"), GPT_ROLE("assistant", "ChatGPT本身"), /** * message里role为system,是为了让ChatGPT在对话过程中设定自己的行为 * 可以理解为对话的设定,如你是谁,要什么语气、等级 */ SYSTEM_ROLE("system", "对话设定"), ; private final String value; private final String desc; GptRoleEnum(String value, String desc) { this.value = value; this.desc = desc; } } /** * gpt请求body */ @Data public static class ChatStreamDTO { private static final String model = "gpt-3.5-turbo"; private static final boolean stream = true; private List<ChatGptDTO> messages; public ChatStreamDTO(List<ChatGptDTO> messages) { this.messages = messages; } @Override public String toString() { return "{\"model\":\"" + model + "\"," + "\"messages\":" + JSON.toJSONString(messages) + "," + "\"stream\":" + stream + "}"; } } }

相关文章