使用 HTTP 请求调用本地部署的 Stable Diffusion 接口

时间:2025-03-28 07:04:02
public class StableDiffusionUtil { private static final String BASE_URL = "http://127.0.0.1:7860"; private static final OkHttpClient CLIENT = new OkHttpClient(); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); /** * 获取可用模型列表 * * @return 模型列表的 JSON 字符串 * @throws IOException 如果请求失败 */ public static List<ModelInfo> listModels() throws IOException { Request request = new Request.Builder() .url(BASE_URL + "/sdapi/v1/sd-models") .get() .build(); try (Response response = CLIENT.newCall(request).execute()) { if (response.isSuccessful()) { String jsonResponse = response.body().string(); // 将 JSON 字符串解析为 ModelInfo 对象列表 return OBJECT_MAPPER.readValue(jsonResponse, new TypeReference<List<ModelInfo>>() {}); } else { throw new IOException("请求失败: " + response.code()); } } } /** * 生成图片并保存 * * @param prompt 生成图片的描述 * @param negativePrompt 负面描述 * @param steps 生成步骤 * @param cfgScale CFG 参数 * @param width 图片宽度 * @param height 图片高度 * @param samplerIndex 采样器 * @param modelCheckpoint 模型名称 * @param outputFilePath 保存图片的文件路径 * @throws IOException 如果请求失败或保存图片失败 */ public static byte[] generateImage(String prompt, String negativePrompt, int steps, int cfgScale, int width, int height, String samplerIndex, String modelCheckpoint) throws IOException { // 请求参数 String json = "{" + "\"prompt\": \"" + prompt + "\"," + "\"negative_prompt\": \"" + negativePrompt + "\"," + "\"steps\": " + steps + "," + "\"cfg_scale\": " + cfgScale + "," + "\"width\": " + width + "," + "\"height\": " + height + "," + "\"sampler_index\": \"" + samplerIndex + "\"," + "\"sd_model_checkpoint\": \"" + modelCheckpoint + "\"" + "}"; RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8")); Request request = new Request.Builder() .url(BASE_URL + "/sdapi/v1/txt2img") .post(body) .build(); try (Response response = CLIENT.newCall(request).execute()) { if (response.isSuccessful()) { String jsonResponse = response.body().string(); String imageBase64 = 
jsonResponse.split("\"images\":\\[\"")[1].split("\"")[0]; // 解码 Base64 图片 byte[] imageBytes = Base64.getDecoder().decode(imageBase64); ByteArrayInputStream bis = new ByteArrayInputStream(imageBytes); BufferedImage image = ImageIO.read(bis); // 将图片转换为二进制数组 ByteArrayOutputStream baos = new ByteArrayOutputStream(); ImageIO.write(image, "png", baos); return baos.toByteArray(); } else { throw new IOException("请求失败: " + response.code()); } } } /** * 切换模型 * * @param modelCheckpoint 模型名称 * @throws IOException 如果请求失败 */ public static void switchModel(String modelCheckpoint) throws IOException { // 更新模型配置 String json = "{" + "\"sd_model_checkpoint\": \"" + modelCheckpoint + "\"" + "}"; RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8")); Request request = new Request.Builder() .url(BASE_URL + "/sdapi/v1/options") .post(body) .build(); try (Response response = CLIENT.newCall(request).execute()) { if (response.isSuccessful()) { System.out.println("模型已切换到 " + modelCheckpoint); } else { throw new IOException("请求失败: " + response.code()); } } } }