From 2af03cbf13befb664235f39c00ca6547e216b3de Mon Sep 17 00:00:00 2001
From: Havoc412 <2993167370@qq.com>
Date: Sat, 16 Nov 2024 14:00:57 +0800
Subject: [PATCH] ✨ SSE version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/http/controller/web/rag_controller.go | 68 +++++++++++++---
 app/http/validator/web/rag/chat.go        |  2 +-
 app/service/nlp/func.go                   |  8 +-
 app/service/nlp/glm/glm.go                | 22 +++++
 routers/web.go                            |  2 +-
 test/sse/sse.html                         | 41 ++++++++++
 test/sse/sse_test.go                      | 97 +++++++++++++++++++++++
 7 files changed, 224 insertions(+), 16 deletions(-)
 create mode 100644 test/sse/sse.html
 create mode 100644 test/sse/sse_test.go

diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go
index 065e7a8..753da5f 100644
--- a/app/http/controller/web/rag_controller.go
+++ b/app/http/controller/web/rag_controller.go
@@ -1,12 +1,12 @@
 package web
 
 import (
-    "catface/app/global/consts"
     "catface/app/global/errcode"
     "catface/app/global/variable"
     "catface/app/model_es"
     "catface/app/service/nlp"
     "catface/app/utils/response"
+    "io"
 
     "github.com/gin-gonic/gin"
     "go.uber.org/zap"
@@ -15,9 +15,40 @@ import (
 type Rag struct {
 }
 
-func (r *Rag) Chat(context *gin.Context) {
+// v1: HTTP-POST version
+// func (r *Rag) Chat(context *gin.Context) {
+//     // 1. query embedding
+//     query := context.GetString(consts.ValidatorPrefix + "query")
+//     embedding, ok := nlp.GetEmbedding(query)
+//     if !ok {
+//         code := errcode.ErrPythonService
+//         response.Fail(context, code, errcode.ErrMsg[code], "")
+//         return
+//     }
+
+//     // 2. ES TopK
+//     docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
+//     if err != nil || len(docs) == 0 {
+//         variable.ZapLog.Error("ES TopK error", zap.Error(err))
+
+//         code := errcode.ErrNoDocFound
+//         response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
+//     }
+
+//     // 3. LLM answer
+//     if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
+//         response.Success(context, consts.CurdStatusOkMsg, gin.H{
+//             "answer": answer,
+//         })
+//     } else {
+//         response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
+//     }
+// }
+
+func (r *Rag) ChatSSE(context *gin.Context) {
+    query := context.Query("query")
+
     // 1. query embedding
-    query := context.GetString(consts.ValidatorPrefix + "query")
     embedding, ok := nlp.GetEmbedding(query)
     if !ok {
         code := errcode.ErrPythonService
@@ -34,14 +65,31 @@ func (r *Rag) Chat(context *gin.Context) {
         response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
     }
 
+    // UPDATE
+    closeEventFromVue := context.Request.Context().Done()
+    ch := make(chan string) // TIP: create the channel.
+
     // 3. LLM answer
-    if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
-        response.Success(context, consts.CurdStatusOkMsg, gin.H{
-            "answer": answer,
-        })
-    } else {
-        response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
-    }
+    go func() {
+        err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch)
+        if err != nil {
+            variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
+        }
+        close(ch)
+    }()
+
+    context.Stream(func(w io.Writer) bool {
+        select {
+        case c, ok := <-ch:
+            if !ok {
+                return false
+            }
+            context.SSEvent("chat", c)
+            return true
+        case <-closeEventFromVue:
+            return false
+        }
+    })
 }
 
 func (r *Rag) HelpDetectCat(context *gin.Context) {
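
The controller change above hinges on one detail: the producer goroutine must close(ch) once nlp.ChatKnoledgeRAG returns, because context.Stream only stops when the receive reports ok == false (or the client disconnects). A stand-alone sketch of that channel contract, not project code, just an illustration of the pattern ChatSSE relies on:

package main

import "fmt"

func main() {
    ch := make(chan string)

    // Producer: emits chunks, then closes the channel to signal completion.
    go func() {
        for _, chunk := range []string{"Hel", "lo, ", "SSE"} {
            ch <- chunk
        }
        close(ch) // without this, the consumer below would block forever
    }()

    // Consumer: mirrors what context.Stream does, stopping once ok is false.
    for {
        chunk, ok := <-ch
        if !ok {
            break // channel closed: stream finished
        }
        fmt.Print(chunk)
    }
    fmt.Println()
}
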
diff --git a/app/http/validator/web/rag/chat.go b/app/http/validator/web/rag/chat.go
index 2b9adbd..94cdada 100644
--- a/app/http/validator/web/rag/chat.go
+++ b/app/http/validator/web/rag/chat.go
@@ -25,7 +25,7 @@ func (c Chat) CheckParams(context *gin.Context) {
     if extraAddBindDataContext == nil {
         response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
     } else {
-        (&web.Rag{}).Chat(extraAddBindDataContext)
+        (&web.Rag{}).ChatSSE(extraAddBindDataContext)
     }
 }
 
diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go
index 6be43c1..4d7cfaf 100644
--- a/app/service/nlp/func.go
+++ b/app/service/nlp/func.go
@@ -14,7 +14,7 @@ func GenerateTitle(content string) string {
 }
 
 // ChatKnoledgeRAG does knowledge Q&A with the RAG pipeline
-func ChatKnoledgeRAG(doc, query string) (string, error) {
+func ChatKnoledgeRAG(doc, query string, ch chan<- string) error {
     // Read the KnoledgeRAG template from the config file
     promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG")
@@ -23,10 +23,10 @@
     message = strings.Replace(message, "{context}", doc, -1)
 
     // Call the chat API
-    response, err := glm.Chat(message)
+    err := glm.ChatStream(message, ch)
     if err != nil {
-        return "", fmt.Errorf("调用聊天接口失败: %w", err)
+        return fmt.Errorf("调用聊天接口失败: %w", err)
     }
 
-    return response, nil
+    return nil
 }
diff --git a/app/service/nlp/glm/glm.go b/app/service/nlp/glm/glm.go
index a03a70f..22ddda3 100644
--- a/app/service/nlp/glm/glm.go
+++ b/app/service/nlp/glm/glm.go
@@ -6,6 +6,7 @@ import (
     "errors"
 
     "github.com/yankeguo/zhipu"
+
 )
 
 // ChatWithGLM wraps the conversation logic with the GLM model
@@ -24,3 +25,24 @@ func Chat(message string) (string, error) {
 
     return res.Choices[0].Message.Content, nil
 }
+
+// ChatStream takes a message and a channel, and sends the streamed response into the channel
+func ChatStream(message string, ch chan<- string) error {
+    service := variable.GlmClient.ChatCompletion("glm-4-flash").
+        AddMessage(zhipu.ChatCompletionMessage{Role: "user", Content: message}).
+        SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
+            content := chunk.Choices[0].Delta.Content
+            if content != "" {
+                ch <- content // send the content into the channel
+            }
+            return nil
+        })
+
+    // Execute the service call
+    _, err := service.Do(context.Background())
+    if err != nil {
+        return err
+    }
+
+    return nil
+}
diff --git a/routers/web.go b/routers/web.go
index 2cff193..ce86326 100644
--- a/routers/web.go
+++ b/routers/web.go
@@ -152,7 +152,7 @@ func InitWebRouter() *gin.Engine {
 
         rag := backend.Group("rag")
         {
-            rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
+            rag.GET("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
         }
 
         search := backend.Group("search")
diff --git a/test/sse/sse.html b/test/sse/sse.html
new file mode 100644
index 0000000..ab88b10
--- /dev/null
+++ b/test/sse/sse.html
@@ -0,0 +1,41 @@

+[The 41 added lines of this HTML test page did not survive extraction; only the page
+ title, "SSE test", is still recognizable. It was evidently a small page whose script
+ consumed the stream from the endpoint above; a Go sketch of an equivalent client
+ follows.]
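
Since the HTML client is not reproduced above, here is a minimal Go sketch of an equivalent consumer. It is not part of the patch; the host, port (:20201) and path (/admin/rag/default_talk) are taken from the test server in test/sse/sse_test.go below, and the parsing simply collects the "data:" lines of each SSE frame emitted by gin's SSEvent.

// sse_client_sketch.go: hypothetical, for illustration only.
package main

import (
    "bufio"
    "fmt"
    "net/http"
    "net/url"
    "strings"
)

func main() {
    // Assumed endpoint: the test server below listens on :20201 and registers
    // GET /admin/rag/default_talk; adjust host, port and path for the real app.
    endpoint := "http://127.0.0.1:20201/admin/rag/default_talk?query=" + url.QueryEscape("hello")

    resp, err := http.Get(endpoint)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // Each SSE frame arrives as "event:chat" and "data:..." lines followed by a blank
    // line; printing the data lines as they arrive reproduces the streamed answer.
    scanner := bufio.NewScanner(resp.Body)
    for scanner.Scan() {
        line := scanner.Text()
        if strings.HasPrefix(line, "data:") {
            fmt.Print(strings.TrimSpace(strings.TrimPrefix(line, "data:")))
        }
    }
    fmt.Println()
}

A quick check without any client code is also possible: curl -N "http://127.0.0.1:20201/admin/rag/default_talk?query=hello" prints the raw frames as they arrive (-N turns off curl's output buffering).
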
diff --git a/test/sse/sse_test.go b/test/sse/sse_test.go
new file mode 100644
index 0000000..47f57a7
--- /dev/null
+++ b/test/sse/sse_test.go
@@ -0,0 +1,97 @@
+// Backend code
+
+// NOTE: the commented-out lines are the variant that does not use gin's Stream wrapper,
+// i.e. c.Stream(func()) and c.SSEvent(); in that case c.Stream is replaced by a for loop
+// that keeps reading from the channel until the channel is closed, which ends the loop.
+
+package main
+
+import (
+    "catface/app/service/nlp/glm"
+    _ "catface/bootstrap"
+    "fmt"
+    "io"
+    "testing"
+    // "time"
+
+    "github.com/gin-gonic/gin"
+)
+
+func SSE(c *gin.Context) {
+    // Set the response headers that tell the frontend to use event-stream interaction
+    //c.Writer.Header().Set("Content-Type", "text/event-stream")
+    //c.Writer.Header().Set("Cache-Control", "no-cache")
+    //c.Writer.Header().Set("Connection", "keep-alive")
+
+    // Check whether SSE is supported
+    //w := c.Writer
+    //flusher, _ := w.(http.Flusher)
+    query := c.Query("query")
+
+    // Notification that the frontend page closed the connection
+    closeNotify := c.Request.Context().Done()
+
+    // Start a goroutine that watches whether the frontend closed the connection;
+    // closing the connection triggers it
+    go func() {
+        <-closeNotify
+        fmt.Println("SSE关闭了")
+        return
+    }()
+
+    // Create a channel used to receive data and respond with it
+    Chan := make(chan string)
+
+    // Receive the GPT response asynchronously, then send the response data into channel Chan
+    go func() {
+        err := glm.ChatStream(query, Chan)
+        if err != nil {
+            fmt.Println("Error", err)
+        }
+
+        close(Chan)
+    }()
+
+    // gin's Stream wrapper keeps calling this func; remember to return true,
+    // returning false ends the calls
+    c.Stream(func(w io.Writer) bool {
+        select {
+        case i, ok := <-Chan:
+            if !ok {
+                return false
+            }
+            c.SSEvent("chat", i) // c.SSEvent switches the response headers to an event stream and sends the "chat" event to the frontend callback listening for that event
+            //flusher.Flush() // make sure it is sent immediately
+            return true
+        case <-closeNotify:
+            fmt.Println("SSE关闭了")
+            return false
+        }
+    })
+}
+
+func TestSSE(t *testing.T) {
+    engine := gin.Default()
+    // Set up the CORS middleware
+    engine.Use(func(context *gin.Context) {
+        origin := context.GetHeader("Origin")
+        // Allow the domain in the Origin header to send requests
+        context.Writer.Header().Add("Access-Control-Allow-Origin", origin) // My frontend page runs on 63342, so CORS is involved; adjust to your own setup, or set "*" to allow everything
+        // Set the preflight request validity to 86400 seconds
+        context.Writer.Header().Set("Access-Control-Max-Age", "86400")
+        // Set the allowed request methods
+        context.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE, PATCH")
+        // Set the allowed request headers
+        context.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Apitoken")
+        // Expose headers beyond the basic ones, such as Apitoken above; configuring it here via Access-Control-Expose-Headers has the same effect
+        context.Writer.Header().Set("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Headers")
+        // Configure whether credentials may be carried
+        context.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+        // Return 200 for OPTIONS requests
+        if context.Request.Method == "OPTIONS" {
+            fmt.Println(context.Request.Header)
+            context.AbortWithStatus(200)
+        } else {
+            context.Next()
+        }
+    })
+    engine.GET("/admin/rag/default_talk", SSE) // TIP: remember to use a GET request; with POST the frontend got a 404, and the references say SSE only supports GET
+    engine.Run(":20201")
+}
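
The NOTE at the top of sse_test.go describes a variant that skips gin's c.Stream and c.SSEvent helpers and instead drains the channel in a plain for loop, flushing after each write. A rough, untested sketch of that variant, reusing glm.ChatStream from this patch; the route and port here are made up for the example:

package main

import (
    "fmt"
    "net/http"

    "catface/app/service/nlp/glm"
    _ "catface/bootstrap" // assumed to initialise variable.GlmClient, as in the test above

    "github.com/gin-gonic/gin"
)

// SSEManual streams without c.Stream/c.SSEvent: headers are set by hand and the
// channel is drained in a for loop until it is closed.
func SSEManual(c *gin.Context) {
    c.Writer.Header().Set("Content-Type", "text/event-stream")
    c.Writer.Header().Set("Cache-Control", "no-cache")
    c.Writer.Header().Set("Connection", "keep-alive")

    flusher, ok := c.Writer.(http.Flusher)
    if !ok {
        c.String(http.StatusInternalServerError, "streaming unsupported")
        return
    }

    ch := make(chan string)
    go func() {
        if err := glm.ChatStream(c.Query("query"), ch); err != nil {
            fmt.Println("Error", err)
        }
        close(ch)
    }()

    closeNotify := c.Request.Context().Done()
    for {
        select {
        case chunk, ok := <-ch:
            if !ok {
                return // channel closed: the model finished streaming
            }
            // Write one SSE frame by hand and push it to the client immediately.
            fmt.Fprintf(c.Writer, "event: chat\ndata: %s\n\n", chunk)
            flusher.Flush()
        case <-closeNotify:
            return // client closed the connection
        }
    }
}

func main() {
    engine := gin.Default()
    engine.GET("/admin/rag/manual_talk", SSEManual) // hypothetical route, separate from the test above
    engine.Run(":20202")                            // hypothetical port
}
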