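Background
This post connects Spring Boot to a GPT model and streams the chat response back to the client.
Backend
Controller layer
To stream the conversation, declare the endpoint as producing text/event-stream and return an SseEmitter. The two response headers (Cache-Control: no-cache and Connection: keep-alive) are optional, and the endpoint also works without them.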
@PostMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter sse(@Validated @RequestBody AnalyzeChatVO vo, HttpServletResponse response) {
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
return gptService.analyzeChatStream(vo);
}
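For reference, a text/event-stream body is a sequence of events: each event is one or more data: lines followed by a blank line, and the client joins the data lines of one event with a newline. The stdlib-only sketch below illustrates that framing. The class SseFrames and its methods are hypothetical, it assumes plain \n line endings, and it is not how Spring's SseEmitter formats events internally.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper illustrating the text/event-stream wire format (assumes "\n" line endings).
final class SseFrames {
    private SseFrames() {}

    /** Encodes one event: every line of data gets its own "data: " prefix, and a blank line ends the event. */
    static String format(String data) {
        StringBuilder sb = new StringBuilder();
        // limit -1 keeps trailing empty strings, so "a\n" becomes ["a", ""]
        for (String line : data.split("\n", -1)) {
            sb.append("data: ").append(line).append('\n');
        }
        return sb.append('\n').toString();
    }

    /** Decodes one event block, joining its data lines with '\n' as an SSE client does. */
    static String parse(String event) {
        List<String> lines = new ArrayList<>();
        for (String line : event.split("\n")) {
            if (line.startsWith("data:")) {
                String value = line.substring(5);
                // a single space after the colon is not part of the value
                if (value.startsWith(" ")) {
                    value = value.substring(1);
                }
                lines.add(value);
            }
        }
        return String.join("\n", lines);
    }
}
```

The multi-line case is the one to watch: if a chunk containing a newline were written on a single data: line, the text after the newline would no longer carry the data: prefix and the client would drop it. Wrapping each chunk in JSON, as the listener in this post does, avoids that because the newline becomes an escaped \n inside one line.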
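Implementation
Note that the output must be produced asynchronously: the method has to hand the SseEmitter back to Spring right away while another thread pushes the data. A thread pool or a plain thread both work; all that matters is that the work runs asynchronously.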
public SseEmitter analyzeChatStream(AnalyzeChatVO vo) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
if (StringUtils.isEmpty(vo.getUser())) {
vo.setUser(UsernameHolder.getUsername());
}
ThreadPoolExecutor executor = ThreadPoolUtil.simpleThreadPool("chat", 1, 1);
try {
// run the upstream call off the request thread so the emitter can be returned immediately
CompletableFuture.runAsync(() -> streamRequest(vo, new SseListener(emitter)), executor).whenComplete((r, t) -> {
if (t != null) {
emitter.completeWithError(t);
log.error("Stream request start error,", t);
}
});
} finally {
executor.shutdown();
}
return emitter;
}
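The point of the asynchrony is that analyzeChatStream returns the emitter straight away and the chunks are pushed later from another thread. Below is a stdlib-only analogue of the same shape, with a BlockingQueue standing in for SseEmitter; AsyncStreamDemo, startStream and drain are hypothetical names used only for illustration.

```java
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical analogue of analyzeChatStream: return the sink first, fill it on another thread.
final class AsyncStreamDemo {
    static final String STREAM_END = "[DONE]";

    private AsyncStreamDemo() {}

    /** Returns the sink immediately; a background task fills it and then writes the end marker. */
    static BlockingQueue<String> startStream(List<String> chunks) {
        BlockingQueue<String> sink = new LinkedBlockingQueue<>(); // stands in for SseEmitter
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            CompletableFuture.runAsync(() -> {
                for (String chunk : chunks) {
                    sink.add(chunk); // like emitter.send(...)
                }
                sink.add(STREAM_END);
            }, executor).whenComplete((r, t) -> {
                if (t != null) {
                    sink.add("[ERROR] " + t); // like emitter.completeWithError(t)
                    sink.add(STREAM_END);
                }
            });
        } finally {
            // refuses new tasks but lets the task already submitted run to completion
            executor.shutdown();
        }
        return sink;
    }

    /** Consumer side: collects chunks until the end marker, giving up after 5 seconds of silence. */
    static String drain(BlockingQueue<String> sink) {
        StringBuilder full = new StringBuilder();
        try {
            String msg;
            while ((msg = sink.poll(5, TimeUnit.SECONDS)) != null && !STREAM_END.equals(msg)) {
                full.append(msg);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return full.toString();
    }
}
```

Calling executor.shutdown() in the finally block, as the original code does, is safe in this pattern: shutdown() only rejects new submissions, and the task already handed to runAsync still finishes.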
The listener is built on OkHttp, so first add OkHttp's SSE module:
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>4.9.1</version>
</dependency>
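Next, register the custom listener. baseUrl is the address of the model service, and the request carries the corresponding token in the Authorization header.
The chatId sent with the request is mainly used to keep different users' conversations isolated.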
private void streamRequest(AnalyzeChatVO vo, EventSourceListener listener) {
GptClient client = getStreamClient();
String url = client.getAttribute().getBaseUrl() + "/api/v1/chat/completions";
log.info("Stream url:{}", url);
OkHttpClient okHttpClient = client.getOkHttpClient();
EventSource.Factory factory = EventSources.createFactory(okHttpClient);
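// serialize with fastjson so quotes, backslashes and newlines in the question are escaped correctly
Map<String, Object> message = new HashMap<>();
message.put("role", "user");
message.put("content", vo.getQuestion());
Map<String, Object> body = new HashMap<>();
body.put("chatId", vo.getUser());
body.put("stream", true);
body.put("messages", Collections.singletonList(message));
String requestBody = JSON.toJSONString(body);
Request.Builder builder = new Request.Builder()
.url(url)
.header("Authorization", client.getToken())
.post(RequestBody.create(requestBody, okhttp3.MediaType.parse(MediaType.APPLICATION_JSON.toString())));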
Request request = builder.build();
factory.newEventSource(request, listener);
}
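Assembling the request body by hand with String.format is fragile: a double quote or backslash in the question produces invalid JSON, and stripping newlines only covers one case while also altering the user's text. The sketch below shows the escaping a JSON library performs for you. ChatBody, escape and build are hypothetical names; in practice a library such as fastjson, already used for JSON.parseObject in the listener, should do this.

```java
// Hypothetical stdlib helper showing JSON string escaping (RFC 8259) for the chat request body.
final class ChatBody {
    private ChatBody() {}

    /** Escapes a string for use inside a JSON string literal. */
    static String escape(String s) {
        StringBuilder sb = new StringBuilder(s.length() + 16);
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '"': sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        // remaining control characters need the six-character hex escape form
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }

    /** Builds the same request shape as streamRequest, with every string value escaped. */
    static String build(String chatId, String question) {
        return String.format(
                "{\"chatId\":\"%s\",\"stream\":true,\"messages\":[{\"role\":\"user\",\"content\":\"%s\"}]}",
                escape(chatId), escape(question));
    }
}
```

With escaping in place, a question such as say "hi" followed by a line break is sent intact instead of breaking the JSON or losing its newline.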
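Listener
The custom listener mainly implements the EventSourceListener callbacks. ChatCompletion defines the structure of a chat response. lastMessage accumulates the complete reply: the model streams its answer piece by piece, so the pieces are concatenated here to obtain the full text (you can drop this if you don't need it).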
// Lombok's @Slf4j generates the log field used below
@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {
protected String lastMessage = "";
private static final String STREAM_END = "[DONE]";
@Setter
@Getter
protected Consumer<String> onComplete = s -> {
};
public abstract void onMsg(String message);
public abstract void onError(Throwable throwable, String response);
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("Open");
}
@Override
public void onClosed(EventSource eventSource) {
log.info("Closed");
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("Event:{}", data);
if (STREAM_END.equals(data)) {
onMsg(data);
onComplete.accept(lastMessage);
return;
}
ChatCompletion response = JSON.parseObject(data, ChatCompletion.class);
String text = response.toPlainStringStream();
Map<String, String> dataToSend = Maps.newHashMap();
dataToSend.put("content", text);
if (StringUtils.isNotEmpty(text)) {
lastMessage += text;
// fix to raw data, avoid '\n' messages be resolved
onMsg(JSON.toJSONString(dataToSend));
}
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
log.info("Fail", throwable);
try {
String responseText = "";
if (Objects.nonNull(response) && Objects.nonNull(response.body())) {
responseText = response.body().string();
}
log.error("Listener failure response:{}", responseText);
this.onError(throwable, responseText);
} catch (Exception e) {
log.error("Listener on failure error,", e);
} finally {
eventSource.cancel();
}
}
}
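Because the reply arrives in pieces, it helps to see what the SSE client does before onEvent fires: it buffers data: lines, dispatches an event on each blank line, and here the loop stops at the [DONE] sentinel. The sketch below is a simplified stdlib reader; SseStreamReader is a hypothetical name, it ignores event:, id: and retry: fields, and it treats each payload as plain text, whereas in the real listener each payload is a ChatCompletion JSON chunk from which the text is extracted.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Consumer;

// Hypothetical, simplified SSE reader: collect "data:" lines, dispatch on a blank line, stop at "[DONE]".
final class SseStreamReader {
    private static final String STREAM_END = "[DONE]";

    private SseStreamReader() {}

    /** Passes each event payload to onChunk and returns the concatenation of all payloads. */
    static String read(BufferedReader in, Consumer<String> onChunk) {
        StringBuilder full = new StringBuilder(); // plays the role of lastMessage
        StringBuilder data = new StringBuilder(); // data lines of the current event
        boolean hasData = false;
        try {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.isEmpty()) {
                    // blank line: dispatch the buffered event, if there is one
                    if (!hasData) {
                        continue;
                    }
                    String payload = data.toString();
                    data.setLength(0);
                    hasData = false;
                    if (STREAM_END.equals(payload)) {
                        break;
                    }
                    onChunk.accept(payload);
                    full.append(payload);
                } else if (line.startsWith("data:")) {
                    String value = line.substring(5);
                    if (value.startsWith(" ")) {
                        value = value.substring(1);
                    }
                    if (hasData) {
                        data.append('\n'); // multiple data lines of one event are joined with '\n'
                    }
                    data.append(value);
                    hasData = true;
                }
                // other fields and ":" comment lines are ignored in this sketch
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return full.toString();
    }
}
```

Accumulating into a StringBuilder instead of using += on a String avoids copying the whole reply again on every chunk, which matters a little for long answers.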
The concrete implementation is SseListener, which forwards every message received by the listener to the SseEmitter and, once the reply is complete, logs the full message.
@Slf4j
public class SseListener extends AbstractStreamListener {
private final SseEmitter emitter;
public SseListener(SseEmitter emitter) {
this.emitter = emitter;
super.setOnComplete((s) -> {
log.info("Complete message:{}", s);
emitter.complete();
});
}
@Override
public void onMsg(String message) {
log.info(message);
try {
emitter.send(message);
} catch (IOException e) {
log.error("Send message error,", e);
}
}
@Override
public void onError(Throwable throwable, String response) {
log.error("Listener error: {}", response, throwable);
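// OkHttp can report a failed (non-2xx) response with a null throwable; pass a real exception on instead
emitter.completeWithError(throwable != null ? throwable : new IllegalStateException("Upstream error: " + response));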
}
}
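The forwarding logic in SseListener can be exercised without Spring or OkHttp by swapping in a small sink interface. In the sketch below, Sink, RecordingSink and ForwardingListener are hypothetical stand-ins for SseEmitter and the two listener classes, and each chunk's text is assumed to be already extracted. It forwards every non-empty chunk, passes [DONE] through before completing, and substitutes an exception when the failure callback receives a null throwable, which, as far as I know, is how OkHttp reports non-2xx responses.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical stand-in for the SseEmitter methods the listener uses.
interface Sink {
    void send(String message);
    void complete();
    void completeWithError(Throwable t);
}

// Hypothetical test double that records everything it receives.
class RecordingSink implements Sink {
    final List<String> sent = new ArrayList<>();
    boolean completed;
    Throwable error;

    public void send(String message) { sent.add(message); }
    public void complete() { completed = true; }
    public void completeWithError(Throwable t) { error = t; }
}

// Sketch of the AbstractStreamListener + SseListener behaviour, operating on extracted text.
final class ForwardingListener {
    private static final String STREAM_END = "[DONE]";
    private final Sink sink;
    private final Consumer<String> onComplete;
    private final StringBuilder lastMessage = new StringBuilder();

    ForwardingListener(Sink sink, Consumer<String> onComplete) {
        this.sink = sink;
        this.onComplete = onComplete;
    }

    void onText(String text) {
        if (STREAM_END.equals(text)) {
            sink.send(text);                           // tell the browser the stream has ended
            onComplete.accept(lastMessage.toString()); // full reply, e.g. for logging
            sink.complete();
            return;
        }
        if (text != null && !text.isEmpty()) {
            lastMessage.append(text);
            sink.send(text);
        }
    }

    void onFailure(Throwable t, String responseBody) {
        // never hand a null error downstream
        sink.completeWithError(t != null ? t : new IllegalStateException("Upstream error: " + responseBody));
    }
}
```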
That completes a basic backend implementation.
Nginx
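If requests reach the application through an Nginx reverse proxy, the Nginx configuration needs adjusting as well. The directive that matters most is proxy_buffering off;, which stops Nginx from buffering the upstream response so each event is passed to the client as soon as it arrives; the configuration below also sets proxy_redirect off; and proxy_cache off;.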
location /sse {
add_header Access-Control-Allow-Origin *;
add_header Access-Control-Allow-Methods 'GET, POST, OPTIONS';
add_header Access-Control-Allow-Headers 'DNT,X-Mx-ReqToken,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization';
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header Host $http_host;
proxy_redirect off;
proxy_buffering off;
proxy_cache off;
proxy_pass http://upstream;
}