diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go index 065e7a8..753da5f 100644 --- a/app/http/controller/web/rag_controller.go +++ b/app/http/controller/web/rag_controller.go @@ -1,12 +1,12 @@ package web import ( - "catface/app/global/consts" "catface/app/global/errcode" "catface/app/global/variable" "catface/app/model_es" "catface/app/service/nlp" "catface/app/utils/response" + "io" "github.com/gin-gonic/gin" "go.uber.org/zap" @@ -15,9 +15,40 @@ import ( type Rag struct { } -func (r *Rag) Chat(context *gin.Context) { +// v1 Http-POST 版本 +// func (r *Rag) Chat(context *gin.Context) { +// // 1. query embedding +// query := context.GetString(consts.ValidatorPrefix + "query") +// embedding, ok := nlp.GetEmbedding(query) +// if !ok { +// code := errcode.ErrPythonService +// response.Fail(context, code, errcode.ErrMsg[code], "") +// return +// } + +// // 2. ES TopK +// docs, err := model_es.CreateDocESFactory().TopK(embedding, 1) +// if err != nil || len(docs) == 0 { +// variable.ZapLog.Error("ES TopK error", zap.Error(err)) + +// code := errcode.ErrNoDocFound +// response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code]) +// } + +// // 3. LLM answer +// if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil { +// response.Success(context, consts.CurdStatusOkMsg, gin.H{ +// "answer": answer, +// }) +// } else { +// response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "") +// } +// } + +func (r *Rag) ChatSSE(context *gin.Context) { + query := context.Query("query") + // 1. query embedding - query := context.GetString(consts.ValidatorPrefix + "query") embedding, ok := nlp.GetEmbedding(query) if !ok { code := errcode.ErrPythonService @@ -34,14 +65,31 @@ func (r *Rag) Chat(context *gin.Context) { response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code]) } + // UPDATE + closeEventFromVue := context.Request.Context().Done() + ch := make(chan string) // TIP 建立通道。 + // 3. LLM answer - if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil { - response.Success(context, consts.CurdStatusOkMsg, gin.H{ - "answer": answer, - }) - } else { - response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "") - } + go func() { + err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch) + if err != nil { + variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) + } + close(ch) + }() + + context.Stream(func(w io.Writer) bool { + select { + case c, ok := <-ch: + if !ok { + return false + } + context.SSEvent("chat", c) + return true + case <-closeEventFromVue: + return false + } + }) } func (r *Rag) HelpDetectCat(context *gin.Context) { diff --git a/app/http/validator/web/rag/chat.go b/app/http/validator/web/rag/chat.go index 2b9adbd..94cdada 100644 --- a/app/http/validator/web/rag/chat.go +++ b/app/http/validator/web/rag/chat.go @@ -25,7 +25,7 @@ func (c Chat) CheckParams(context *gin.Context) { if extraAddBindDataContext == nil { response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "") } else { - (&web.Rag{}).Chat(extraAddBindDataContext) + (&web.Rag{}).ChatSSE(extraAddBindDataContext) } } diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go index 6be43c1..4d7cfaf 100644 --- a/app/service/nlp/func.go +++ b/app/service/nlp/func.go @@ -14,7 +14,7 @@ func GenerateTitle(content string) string { } // ChatKnoledgeRAG 使用 RAG 模型进行知识问答 -func ChatKnoledgeRAG(doc, query string) (string, error) { +func ChatKnoledgeRAG(doc, query string, ch chan<- string) error { // 读取配置文件中的 KnoledgeRAG 模板 promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG") @@ -23,10 +23,10 @@ func ChatKnoledgeRAG(doc, query string) (string, error) { message = strings.Replace(message, "{context}", doc, -1) // 调用聊天接口 - response, err := glm.Chat(message) + err := glm.ChatStream(message, ch) if err != nil { - return "", fmt.Errorf("调用聊天接口失败: %w", err) + return fmt.Errorf("调用聊天接口失败: %w", err) } - return response, nil + return nil } diff --git a/app/service/nlp/glm/glm.go b/app/service/nlp/glm/glm.go index a03a70f..22ddda3 100644 --- a/app/service/nlp/glm/glm.go +++ b/app/service/nlp/glm/glm.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/yankeguo/zhipu" + ) // ChatWithGLM 封装了与GLM模型进行对话的逻辑 @@ -24,3 +25,24 @@ func Chat(message string) (string, error) { return res.Choices[0].Message.Content, nil } + +// ChatStream 接收一个消息和一个通道,将流式响应发送到通道中 +func ChatStream(message string, ch chan<- string) error { + service := variable.GlmClient.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{Role: "user", Content: message}). + SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + content := chunk.Choices[0].Delta.Content + if content != "" { + ch <- content // 将内容发送到通道 + } + return nil + }) + + // 执行服务调用 + _, err := service.Do(context.Background()) + if err != nil { + return err + } + + return nil +} diff --git a/routers/web.go b/routers/web.go index 2cff193..ce86326 100644 --- a/routers/web.go +++ b/routers/web.go @@ -152,7 +152,7 @@ func InitWebRouter() *gin.Engine { rag := backend.Group("rag") { - rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat")) + rag.GET("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat")) } search := backend.Group("search") diff --git a/test/sse/sse.html b/test/sse/sse.html new file mode 100644 index 0000000..ab88b10 --- /dev/null +++ b/test/sse/sse.html @@ -0,0 +1,41 @@ + + + +
+