From 3413d52316878bf51ad00a523a97e5bf3f36be42 Mon Sep 17 00:00:00 2001 From: Havoc412 <2993167370@qq.com> Date: Sat, 16 Nov 2024 18:18:07 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=8F=20finish=20RAG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/global/errcode/code.go | 1 + app/global/errcode/msg.go | 1 + app/global/errcode/websocket.go | 9 +++ app/http/controller/web/rag_controller.go | 68 +++++++++++++++++++++++ app/http/validator/web/rag/chat.go | 2 +- app/service/nlp/func.go | 3 +- app/service/nlp/glm/glm.go | 42 +++++++++++++- config/config.yml | 2 +- 8 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 app/global/errcode/websocket.go diff --git a/app/global/errcode/code.go b/app/global/errcode/code.go index 522463f..da4ddd1 100644 --- a/app/global/errcode/code.go +++ b/app/global/errcode/code.go @@ -8,6 +8,7 @@ const ( ErrNlp ErrKnowledge ErrSubService + ErrWebSocket ) const ( diff --git a/app/global/errcode/msg.go b/app/global/errcode/msg.go index be49f11..6d70cb1 100644 --- a/app/global/errcode/msg.go +++ b/app/global/errcode/msg.go @@ -15,6 +15,7 @@ func init() { NlpMsgInit(ErrMsg) KnowledgeMsgInit(ErrMsg) SubServiceMsgInit(ErrMsg) + WsMsgInit(ErrMsg) // INGO ErrMsgForUser = make(msg) diff --git a/app/global/errcode/websocket.go b/app/global/errcode/websocket.go new file mode 100644 index 0000000..75bba99 --- /dev/null +++ b/app/global/errcode/websocket.go @@ -0,0 +1,9 @@ +package errcode + +const ( + ErrWebsocketUpgradeFail = ErrWebSocket + iota +) + +func WsMsgInit(m msg) { + m[ErrWebsocketUpgradeFail] = "websocket升级失败" +} diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go index 753da5f..34e85ac 100644 --- a/app/http/controller/web/rag_controller.go +++ b/app/http/controller/web/rag_controller.go @@ -7,8 +7,10 @@ import ( "catface/app/service/nlp" "catface/app/utils/response" "io" + "net/http" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "go.uber.org/zap" ) @@ -92,6 +94,72 @@ func (r *Rag) ChatSSE(context *gin.Context) { }) } +var upgrader = websocket.Upgrader{ // TEST 测试,先写一个裸的 wss + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // info 在生产环境中可能需要更安全的检查 + }, +} + +func (r *Rag) ChatWebSocket(context *gin.Context) { + query := context.Query("query") + + // 0. 协议升级 + ws, err := upgrader.Upgrade(context.Writer, context.Request, nil) + if err != nil { + variable.ZapLog.Error("OnOpen error", zap.Error(err)) + response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "") + return + } + defer ws.Close() + + // 1. query embedding + embedding, ok := nlp.GetEmbedding(query) + if !ok { + code := errcode.ErrPythonService + response.Fail(context, code, errcode.ErrMsg[code], "") + return + } + + // 2. ES TopK + docs, err := model_es.CreateDocESFactory().TopK(embedding, 1) + if err != nil || len(docs) == 0 { + variable.ZapLog.Error("ES TopK error", zap.Error(err)) + + code := errcode.ErrNoDocFound + response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code]) + } + + // 3. + closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。 + ch := make(chan string) // TIP 建立通道。 + + go func() { + err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch) + if err != nil { + variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) + } + close(ch) // 这里 close,使得下方 for 结束。 + }() + + for { + select { + case c, ok := <-ch: + if !ok { + return + } + // variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c)) + err := ws.WriteMessage(websocket.TextMessage, []byte(c)) + if err != nil { + return + } + case <-closeEventFromVue: + return + } + } +} + func (r *Rag) HelpDetectCat(context *gin.Context) { // TODO } diff --git a/app/http/validator/web/rag/chat.go b/app/http/validator/web/rag/chat.go index 94cdada..febc5ea 100644 --- a/app/http/validator/web/rag/chat.go +++ b/app/http/validator/web/rag/chat.go @@ -25,7 +25,7 @@ func (c Chat) CheckParams(context *gin.Context) { if extraAddBindDataContext == nil { response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "") } else { - (&web.Rag{}).ChatSSE(extraAddBindDataContext) + (&web.Rag{}).ChatWebSocket(extraAddBindDataContext) } } diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go index 4d7cfaf..f24a6be 100644 --- a/app/service/nlp/func.go +++ b/app/service/nlp/func.go @@ -23,7 +23,8 @@ func ChatKnoledgeRAG(doc, query string, ch chan<- string) error { message = strings.Replace(message, "{context}", doc, -1) // 调用聊天接口 - err := glm.ChatStream(message, ch) + // err := glm.ChatStream(message, ch) + err := glm.BufferedChatStream(message, ch) if err != nil { return fmt.Errorf("调用聊天接口失败: %w", err) } diff --git a/app/service/nlp/glm/glm.go b/app/service/nlp/glm/glm.go index 22ddda3..ba7b731 100644 --- a/app/service/nlp/glm/glm.go +++ b/app/service/nlp/glm/glm.go @@ -4,9 +4,10 @@ import ( "catface/app/global/variable" "context" "errors" + "strings" + "time" "github.com/yankeguo/zhipu" - ) // ChatWithGLM 封装了与GLM模型进行对话的逻辑 @@ -46,3 +47,42 @@ func ChatStream(message string, ch chan<- string) error { return nil } + +// 带缓冲机制的 ChatStream;计数 & 计时 双判定。 +func BufferedChatStream(message string, ch chan<- string) error { + bufferedCh := make(chan string) // 带缓冲的通道,缓冲大小为10 + timer := time.NewTimer(500 * time.Millisecond) // 定时器,500毫秒 + + go func() { + err := ChatStream(message, bufferedCh) + if err != nil { + return + } + close(bufferedCh) + }() + + var buffer strings.Builder + for { + select { + case c, ok := <-bufferedCh: + if !ok { + if buffer.Len() > 0 { + ch <- buffer.String() + } + return nil // 依靠这里停止函数。 + } + buffer.WriteString(c) + if buffer.Len() >= 10 { + ch <- buffer.String() + buffer.Reset() + timer.Reset(500 * time.Millisecond) + } + case <-timer.C: + if buffer.Len() > 0 { + ch <- buffer.String() + buffer.Reset() + } + timer.Reset(500 * time.Millisecond) + } + } +} diff --git a/config/config.yml b/config/config.yml index cdc6d57..9b1ab71 100644 --- a/config/config.yml +++ b/config/config.yml @@ -44,7 +44,7 @@ Logs: Compress: false #日志备份时,是否进行压缩 Websocket: #该服务与Http具有相同的ip、端口,因此不需要额外设置端口 - Start: 0 #默认不启动该服务(1=启动;0=不启动) + Start: 0 # 默认不启动该服务(1=启动;0=不启动) WriteReadBufferSize: 20480 # 读写缓冲区分配字节,大概能存储 6800 多一点的文字 MaxMessageSize: 65535 # 从消息管道读取消息的最大字节 PingPeriod: 20 #心跳包频率,单位:秒