From f8a4f1d042838d38f24d348ecb543dfed7baf98b Mon Sep 17 00:00:00 2001 From: Havoc412 <2993167370@qq.com> Date: Tue, 19 Nov 2024 03:21:28 +0800 Subject: [PATCH] better Embedding with All query --- app/http/controller/web/rag_controller.go | 34 ++++++++++------------- app/model/rag_websocket_result.go | 24 ++++++++++++++++ app/service/nlp/embedding.go | 2 +- app/utils/llm_factory/glm_client.go | 33 ++++++++++++++++++---- 4 files changed, 66 insertions(+), 27 deletions(-) create mode 100644 app/model/rag_websocket_result.go diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go index a29cc30..879c2e5 100644 --- a/app/http/controller/web/rag_controller.go +++ b/app/http/controller/web/rag_controller.go @@ -3,12 +3,12 @@ package web import ( "catface/app/global/errcode" "catface/app/global/variable" + "catface/app/model" "catface/app/model_es" "catface/app/service/nlp" "catface/app/utils/llm_factory" "catface/app/utils/micro_service" "catface/app/utils/response" - "encoding/json" "io" "net/http" @@ -72,7 +72,7 @@ func (r *Rag) ChatSSE(context *gin.Context) { } // 1. query embedding - embedding, ok := nlp.GetEmbedding(query) + embedding, ok := nlp.GetEmbedding([]string{query}) if !ok { code := errcode.ErrPythonService response.Fail(context, code, errcode.ErrMsg[code], "") @@ -138,17 +138,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "") return } - defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。 - tokenMsg := struct { - Type string `json:"type"` - Token string `json:"token"` - }{ - Type: "token", - Token: token, - } + defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。 + tokenMsg := model.CreateNlpWebSocketResult("token", token) - tokenBytes, _ := json.Marshal(tokenMsg) - err := ws.WriteMessage(websocket.TextMessage, tokenBytes) + err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err)) } @@ -158,7 +151,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { // 0-2. 测试 Python 微服务是否启动 if !micro_service.TestLinkPythonService() { code := errcode.ErrPythonServierDown - err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) + err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -166,10 +159,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { } // 0-3. 从 GLM_HUB 中获取一个可用的 glm client; - client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub) + clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub) if ercode != 0 { variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err)) - err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode])) + err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -177,10 +170,11 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { } // 1. query embedding - embedding, ok := nlp.GetEmbedding(query) + clientInfo.AddQuery(query) + embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys) if !ok { code := errcode.ErrPythonServierDown - err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) + err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -193,7 +187,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { variable.ZapLog.Error("ES TopK error", zap.Error(err)) code := errcode.ErrNoDocFound - err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) + err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -205,7 +199,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { ch := make(chan string) // TIP 建立通道。 go func() { - err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client) + err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, clientInfo.Client) if err != nil { variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) } @@ -219,7 +213,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { return } // variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c)) - err := ws.WriteMessage(websocket.TextMessage, []byte(c)) + err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal()) if err != nil { return } diff --git a/app/model/rag_websocket_result.go b/app/model/rag_websocket_result.go new file mode 100644 index 0000000..cb88984 --- /dev/null +++ b/app/model/rag_websocket_result.go @@ -0,0 +1,24 @@ +package model + +import "encoding/json" + +func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult { + if t == "" { + t = "chat" + } + + return &NlpWebSocketResult{ + Type: t, + Data: data, + } +} + +type NlpWebSocketResult struct { + Type string `json:"type"` + Data any `json:"data"` +} + +func (n *NlpWebSocketResult) JsonMarshal() []byte { + data, _ := json.Marshal(n) + return data +} diff --git a/app/service/nlp/embedding.go b/app/service/nlp/embedding.go index 5922243..fbddac5 100644 --- a/app/service/nlp/embedding.go +++ b/app/service/nlp/embedding.go @@ -14,7 +14,7 @@ type EmbeddingRes struct { Embedding []float64 `json:"embedding"` } -func GetEmbedding(text string) ([]float64, bool) { +func GetEmbedding(text []string) ([]float64, bool) { body := map[string]interface{}{ "text": text, } diff --git a/app/utils/llm_factory/glm_client.go b/app/utils/llm_factory/glm_client.go index 06a38a8..68a2d0f 100644 --- a/app/utils/llm_factory/glm_client.go +++ b/app/utils/llm_factory/glm_client.go @@ -20,8 +20,9 @@ type GlmClientHub struct { } type ClientInfo struct { - Client *zhipu.ChatCompletionService - LastUsed time.Time + Client *zhipu.ChatCompletionService + UserQuerys []string + LastUsed time.Time } func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub { @@ -50,10 +51,10 @@ const ( * @param {string} token: // TODO 如何在 token 中保存信息? * @return {*} */ -func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.ChatCompletionService, code int) { +func (g *GlmClientHub) GetOneGlmClientInfo(token string, mode int) (clientInfo *ClientInfo, code int) { if info, ok := g.Clients[token]; ok { info.LastUsed = time.Now() // INFO 刷新生命周期 - return info.Client, 0 + return info, 0 } // 空闲数检查 @@ -70,7 +71,7 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch code = errcode.ErrGlmNewClientFail return } - client = preClient.ChatCompletion(g.DefaultModelName) + client := preClient.ChatCompletion(g.DefaultModelName) if mode == GlmModeKnowledgeHub { client.AddMessage(zhipu.ChatCompletionMessage{ @@ -79,13 +80,28 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch }) } - g.Clients[token] = &ClientInfo{ + clientInfo = &ClientInfo{ Client: client, LastUsed: time.Now(), } + g.Clients[token] = clientInfo return } +/** + * @description: 获取并返回 ClientInfo 的 Client 和 code。 + * @param {string} token + * @param {int} mode + * @return {(*zhipu.ChatCompletionService, int)} + */ +func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (*zhipu.ChatCompletionService, int) { + clientInfo, code := g.GetOneGlmClientInfo(token, mode) + if clientInfo == nil || code != 0 { + return nil, code + } + return clientInfo.Client, code +} + // cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client func (g *GlmClientHub) cleanupRoutine() { ticker := time.NewTicker(10 * time.Minute) @@ -114,3 +130,8 @@ func (g *GlmClientHub) ReleaseOneGlmClient(token string) { delete(g.Clients, token) g.MaxIdle += 1 } + +// TAG ClientInfo +func (c *ClientInfo) AddQuery(query string) { + c.UserQuerys = append(c.UserQuerys, query) +}