✨ SSE 版本
This commit is contained in:
parent
d330b6b74c
commit
2af03cbf13
@ -1,12 +1,12 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"catface/app/global/consts"
|
||||
"catface/app/global/errcode"
|
||||
"catface/app/global/variable"
|
||||
"catface/app/model_es"
|
||||
"catface/app/service/nlp"
|
||||
"catface/app/utils/response"
|
||||
"io"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@ -15,9 +15,40 @@ import (
|
||||
type Rag struct {
|
||||
}
|
||||
|
||||
func (r *Rag) Chat(context *gin.Context) {
|
||||
// v1 Http-POST 版本
|
||||
// func (r *Rag) Chat(context *gin.Context) {
|
||||
// // 1. query embedding
|
||||
// query := context.GetString(consts.ValidatorPrefix + "query")
|
||||
// embedding, ok := nlp.GetEmbedding(query)
|
||||
// if !ok {
|
||||
// code := errcode.ErrPythonService
|
||||
// response.Fail(context, code, errcode.ErrMsg[code], "")
|
||||
// return
|
||||
// }
|
||||
|
||||
// // 2. ES TopK
|
||||
// docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
||||
// if err != nil || len(docs) == 0 {
|
||||
// variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||
|
||||
// code := errcode.ErrNoDocFound
|
||||
// response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
|
||||
// }
|
||||
|
||||
// // 3. LLM answer
|
||||
// if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
|
||||
// response.Success(context, consts.CurdStatusOkMsg, gin.H{
|
||||
// "answer": answer,
|
||||
// })
|
||||
// } else {
|
||||
// response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
|
||||
// }
|
||||
// }
|
||||
|
||||
func (r *Rag) ChatSSE(context *gin.Context) {
|
||||
query := context.Query("query")
|
||||
|
||||
// 1. query embedding
|
||||
query := context.GetString(consts.ValidatorPrefix + "query")
|
||||
embedding, ok := nlp.GetEmbedding(query)
|
||||
if !ok {
|
||||
code := errcode.ErrPythonService
|
||||
@ -34,14 +65,31 @@ func (r *Rag) Chat(context *gin.Context) {
|
||||
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
|
||||
}
|
||||
|
||||
// UPDATE
|
||||
closeEventFromVue := context.Request.Context().Done()
|
||||
ch := make(chan string) // TIP 建立通道。
|
||||
|
||||
// 3. LLM answer
|
||||
if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
|
||||
response.Success(context, consts.CurdStatusOkMsg, gin.H{
|
||||
"answer": answer,
|
||||
})
|
||||
} else {
|
||||
response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
|
||||
go func() {
|
||||
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch)
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case c, ok := <-ch:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
context.SSEvent("chat", c)
|
||||
return true
|
||||
case <-closeEventFromVue:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
||||
|
@ -25,7 +25,7 @@ func (c Chat) CheckParams(context *gin.Context) {
|
||||
if extraAddBindDataContext == nil {
|
||||
response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
|
||||
} else {
|
||||
(&web.Rag{}).Chat(extraAddBindDataContext)
|
||||
(&web.Rag{}).ChatSSE(extraAddBindDataContext)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ func GenerateTitle(content string) string {
|
||||
}
|
||||
|
||||
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
|
||||
func ChatKnoledgeRAG(doc, query string) (string, error) {
|
||||
func ChatKnoledgeRAG(doc, query string, ch chan<- string) error {
|
||||
// 读取配置文件中的 KnoledgeRAG 模板
|
||||
promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG")
|
||||
|
||||
@ -23,10 +23,10 @@ func ChatKnoledgeRAG(doc, query string) (string, error) {
|
||||
message = strings.Replace(message, "{context}", doc, -1)
|
||||
|
||||
// 调用聊天接口
|
||||
response, err := glm.Chat(message)
|
||||
err := glm.ChatStream(message, ch)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("调用聊天接口失败: %w", err)
|
||||
return fmt.Errorf("调用聊天接口失败: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return nil
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/yankeguo/zhipu"
|
||||
|
||||
)
|
||||
|
||||
// ChatWithGLM 封装了与GLM模型进行对话的逻辑
|
||||
@ -24,3 +25,24 @@ func Chat(message string) (string, error) {
|
||||
|
||||
return res.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// ChatStream 接收一个消息和一个通道,将流式响应发送到通道中
|
||||
func ChatStream(message string, ch chan<- string) error {
|
||||
service := variable.GlmClient.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{Role: "user", Content: message}).
|
||||
SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
if content != "" {
|
||||
ch <- content // 将内容发送到通道
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// 执行服务调用
|
||||
_, err := service.Do(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -152,7 +152,7 @@ func InitWebRouter() *gin.Engine {
|
||||
|
||||
rag := backend.Group("rag")
|
||||
{
|
||||
rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
|
||||
rag.GET("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
|
||||
}
|
||||
|
||||
search := backend.Group("search")
|
||||
|
41
test/sse/sse.html
Normal file
41
test/sse/sse.html
Normal file
@ -0,0 +1,41 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<title>SSE test</title>
|
||||
<script type="text/javascript">
|
||||
// 向后端服务器发起sse请求
|
||||
const es = new EventSource("http://127.0.0.1:20201/admin/rag/default_talk");
|
||||
// 监听事件流
|
||||
es.onmessage = function (e) {
|
||||
document.getElementById("test")
|
||||
.insertAdjacentHTML("beforeend", "<li>" + e.data + "</li>");
|
||||
console.log(e);
|
||||
}
|
||||
// 监听”chat“事件流
|
||||
es.addEventListener("chat", (e) => {
|
||||
document.getElementById("test")
|
||||
.insertAdjacentHTML("beforeend", "<a>" + e.data + "</a>");
|
||||
console.log(e)
|
||||
});
|
||||
es.onerror = function (e) {
|
||||
// readyState说明
|
||||
// 0:浏览器与服务端尚未建立连接或连接已被关闭
|
||||
// 1:浏览器与服务端已成功连接,浏览器正在处理接收到的事件及数据
|
||||
// 2:浏览器与服务端建立连接失败,客户端不再继续建立与服务端之间的连接
|
||||
console.log("readyState = " + e.currentTarget.readyState);
|
||||
// 关闭连接
|
||||
es.close();
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<h1>SSE test</h1>
|
||||
<div>
|
||||
<ul id="test">
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
97
test/sse/sse_test.go
Normal file
97
test/sse/sse_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
//后端代码
|
||||
|
||||
//注意 **我注释的代码,是不使用gin框架封装的Stream方法,也就是C.Stream(func())和C.ssevent(),只是C.Stream要改成for循环持续的从通道里面进行读,直到通道关闭,结束for循环**
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"catface/app/service/nlp/glm"
|
||||
_ "catface/bootstrap"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
// "time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
)
|
||||
|
||||
func SSE(c *gin.Context) {
|
||||
// 设置响应头,告诉前端适用event-stream事件流交互
|
||||
//c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
//c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
//c.Writer.Header().Set("Connection", "keep-alive")
|
||||
|
||||
// 判断是否支持sse
|
||||
//w := c.Writer
|
||||
//flusher, _ := w.(http.Flusher)
|
||||
query := c.Query("query")
|
||||
|
||||
// 接收前端页面关闭连接通知
|
||||
closeNotify := c.Request.Context().Done()
|
||||
|
||||
// 开启协程监听前端页面是否关闭了连接,关闭连接会触发此方法
|
||||
go func() {
|
||||
<-closeNotify
|
||||
fmt.Println("SSE关闭了")
|
||||
return
|
||||
}()
|
||||
|
||||
//新建一个通道,用于数据接收和响应
|
||||
Chan := make(chan string)
|
||||
|
||||
// 异步接收GPT响应,然后把响应的数据发送到通道Chan
|
||||
go func() {
|
||||
err := glm.ChatStream(query, Chan)
|
||||
if err != nil {
|
||||
fmt.Println("Error", err)
|
||||
}
|
||||
|
||||
close(Chan)
|
||||
}()
|
||||
|
||||
// gin框架封装的stream,会持续的调用这个func方法,记得返回true;返回false代表结束调用func方法
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case i, ok := <-Chan:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
c.SSEvent("chat", i) // c.SSEvent会自动修改响应头为事件流,并发送”test“事件流给前端监听”test“的回调方法
|
||||
//flusher.Flush() // 确保立即发送
|
||||
return true
|
||||
case <-closeNotify:
|
||||
fmt.Println("SSE关闭了")
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSE(t *testing.T) {
|
||||
engine := gin.Default()
|
||||
// 设置跨域中间件
|
||||
engine.Use(func(context *gin.Context) {
|
||||
origin := context.GetHeader("Origin")
|
||||
// 允许 Origin 字段中的域发送请求
|
||||
context.Writer.Header().Add("Access-Control-Allow-Origin", origin) // 这边我的前端页面在63342,会涉及跨域,这个根据自己情况设置,或者直接设置为”*“,放行所有的
|
||||
// 设置预验请求有效期为 86400 秒
|
||||
context.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
// 设置允许请求的方法
|
||||
context.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, U`PDATE, PATCH")
|
||||
// 设置允许请求的 Header
|
||||
context.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Apitoken")
|
||||
// 设置拿到除基本字段外的其他字段,如上面的Apitoken, 这里通过引用Access-Control-Expose-Headers,进行配置,效果是一样的。
|
||||
context.Writer.Header().Set("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Headers")
|
||||
// 配置是否可以带认证信息
|
||||
context.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
// OPTIONS请求返回200
|
||||
if context.Request.Method == "OPTIONS" {
|
||||
fmt.Println(context.Request.Header)
|
||||
context.AbortWithStatus(200)
|
||||
} else {
|
||||
context.Next()
|
||||
}
|
||||
})
|
||||
engine.GET("/admin/rag/default_talk", SSE) // TIP 记得适用get请求,我用post前端报404,资料说是SSE只支持get请求
|
||||
engine.Run(":20201")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user