diff --git a/app/global/errcode/nlp.go b/app/global/errcode/nlp.go
index 56d0e24..d63d426 100644
--- a/app/global/errcode/nlp.go
+++ b/app/global/errcode/nlp.go
@@ -4,15 +4,22 @@ const (
 	ErrNoContent = ErrNlp + iota
 	ErrNoDocFound
 	ErrPythonServierDown
+	ErrGlmBusy
+	ErrGlmHistoryLoss
+	ErrGlmNewClientFail
 )
 
 func NlpMsgInit(m msg) {
 	m[ErrNoContent] = "内容为空"
 	m[ErrNoDocFound] = "没有找到相关文档"
+	m[ErrGlmNewClientFail] = "GLM 新建客户端失败"
 }
 
 func NlpMsgUserInit(m msg) {
 	m[ErrNoContent] = "请输入内容"
 	m[ErrNoDocFound] = "小护没有在知识库中找到相关文档。😿"
 	m[ErrPythonServierDown] = "小护的🐍python服务挂了,此功能暂时无法使用。😿"
+	m[ErrGlmBusy] = "现在有太多人咨询小护,请稍后再来。"
+	m[ErrGlmHistoryLoss] = "抱歉!小护找不到之前的会话记录了,我们重新开始新的对话吧。"
+	// m[ErrGlmNewClientFail] = "小护新建客户端失败了,请稍后再来。"
 }
diff --git a/app/global/variable/variable.go b/app/global/variable/variable.go
index 355b1cc..a60939f 100644
--- a/app/global/variable/variable.go
+++ b/app/global/variable/variable.go
@@ -2,6 +2,7 @@ package variable
 
 import (
 	"catface/app/global/my_errors"
+	"catface/app/utils/llm_factory"
 	"catface/app/utils/snow_flake/snowflake_interf"
 	"catface/app/utils/yml_config/ymlconfig_interf"
 	"log"
@@ -10,7 +11,6 @@ import (
 
 	"github.com/casbin/casbin/v2"
 	"github.com/elastic/go-elasticsearch/v8"
-	"github.com/yankeguo/zhipu"
 	"go.uber.org/zap"
 	"gorm.io/gorm"
 )
@@ -44,8 +44,8 @@ var (
 	// casbin global enforcer pointer
 	Enforcer *casbin.SyncedEnforcer
 
-	// GLM global client
-	GlmClient *zhipu.Client
+	// Centralized management of GLM clients
+	GlmClientHub *llm_factory.GlmClientHub
 
 	// ES global client
 	ElasticClient *elasticsearch.Client
diff --git a/app/http/controller/web/nlp_controller.go b/app/http/controller/web/nlp_controller.go
index 9d5f90f..4e5ee3c 100644
--- a/app/http/controller/web/nlp_controller.go
+++ b/app/http/controller/web/nlp_controller.go
@@ -2,7 +2,10 @@ package web
 
 import (
 	"catface/app/global/consts"
+	"catface/app/global/errcode"
+	"catface/app/global/variable"
 	"catface/app/service/nlp"
+	"catface/app/utils/llm_factory"
 	"catface/app/utils/response"
 
 	"github.com/gin-gonic/gin"
@@ -14,7 +17,15 @@ type Nlp struct {
 
 func (n *Nlp) Title(context *gin.Context) {
 	content := context.GetString(consts.ValidatorPrefix + "content")
-	newTitle := nlp.GenerateTitle(content)
+	tempGlmKey := variable.SnowFlake.GetIdAsString()
+	client, ercode := variable.GlmClientHub.GetOneGlmClient(tempGlmKey, llm_factory.GlmModeSimple)
+	if ercode > 0 {
+		response.Fail(context, ercode, errcode.ErrMsg[ercode], errcode.ErrMsgForUser[ercode])
+		return
+	}
+	defer variable.GlmClientHub.ReleaseOneGlmClient(tempGlmKey) // Borrowed temporarily: release as soon as we are done.
+
+	newTitle := nlp.GenerateTitle(content, client)
 	if newTitle != "" {
 		response.Success(context, consts.CurdStatusOkMsg, gin.H{"title": newTitle})
 	} else {
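The `Title` handler above shows the one-shot lifecycle this PR establishes: mint a throwaway snowflake key, borrow a client from the hub, and `defer` the release. A minimal sketch of that pattern pulled out as a helper; `withTempGlmClient` is hypothetical and not part of this PR:

```go
package web

import (
	"catface/app/global/variable"
	"catface/app/utils/llm_factory"

	"github.com/yankeguo/zhipu"
)

// withTempGlmClient borrows a hub client under a throwaway snowflake key and
// releases the slot as soon as fn returns. Returns 0 on success, otherwise an
// errcode value the caller can map through ErrMsg / ErrMsgForUser.
func withTempGlmClient(fn func(client *zhipu.ChatCompletionService)) int {
	key := variable.SnowFlake.GetIdAsString()
	client, ercode := variable.GlmClientHub.GetOneGlmClient(key, llm_factory.GlmModeSimple)
	if ercode > 0 {
		return ercode
	}
	defer variable.GlmClientHub.ReleaseOneGlmClient(key)
	fn(client)
	return 0
}
```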
diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go
index ef291c0..a29cc30 100644
--- a/app/http/controller/web/rag_controller.go
+++ b/app/http/controller/web/rag_controller.go
@@ -5,7 +5,10 @@ import (
 	"catface/app/global/variable"
 	"catface/app/model_es"
 	"catface/app/service/nlp"
+	"catface/app/utils/llm_factory"
+	"catface/app/utils/micro_service"
 	"catface/app/utils/response"
+	"encoding/json"
 
 	"io"
 	"net/http"
@@ -49,6 +52,24 @@ type Rag struct {
 
 func (r *Rag) ChatSSE(context *gin.Context) {
 	query := context.Query("query")
+	token := context.Query("token")
+
+	// 0-1. Check that the Python microservice is reachable.
+	if !micro_service.TestLinkPythonService() {
+		code := errcode.ErrPythonServierDown
+		response.Fail(context, code, errcode.ErrMsg[code], "")
+		return
+	}
+
+	// 0-2. Fetch a GLM client.
+	if token == "" {
+		token = variable.SnowFlake.GetIdAsString()
+	}
+	client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
+	if ercode != 0 {
+		response.Fail(context, ercode, errcode.ErrMsg[ercode], errcode.ErrMsgForUser[ercode])
+		return
+	}
 
 	// 1. query embedding
 	embedding, ok := nlp.GetEmbedding(query)
@@ -73,7 +94,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
 
 	// 3. LLM answer
 	go func() {
-		err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch)
+		err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
 		if err != nil {
 			variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
 		}
@@ -104,15 +125,56 @@ var upgrader = websocket.Upgrader{
 
 // TEST For now, a bare wss handler.
 func (r *Rag) ChatWebSocket(context *gin.Context) {
 	query := context.Query("query")
+	token := context.Query("token")
 
-	// 0. Protocol upgrade
+	if token == "" {
+		token = variable.SnowFlake.GetIdAsString()
+	}
+
+	// 0-1. Protocol upgrade
 	ws, err := upgrader.Upgrade(context.Writer, context.Request, nil)
 	if err != nil {
 		variable.ZapLog.Error("OnOpen error", zap.Error(err))
 		response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
 		return
 	}
-	defer ws.Close()
+	defer func() { // UPDATE Interim solution; later consider folding this into the JWT-maintained token.
+		tokenMsg := struct {
+			Type  string `json:"type"`
+			Token string `json:"token"`
+		}{
+			Type:  "token",
+			Token: token,
+		}
+
+		tokenBytes, _ := json.Marshal(tokenMsg)
+		err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
+		if err != nil {
+			variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
+		}
+		ws.Close()
+	}()
+
+	// 0-2. Check whether the Python microservice is up.
+	if !micro_service.TestLinkPythonService() {
+		code := errcode.ErrPythonServierDown
+		err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
+		if err != nil {
+			variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
+		}
+		return
+	}
+
+	// 0-3. Fetch an available glm client from the GLM_HUB.
+	client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
+	if ercode != 0 {
+		variable.ZapLog.Error("GetOneGlmClient error", zap.Int("errcode", ercode))
+		err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode]))
+		if err != nil {
+			variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
+		}
+		return
+	}
 
 	// 1. query embedding
 	embedding, ok := nlp.GetEmbedding(query)
@@ -143,7 +205,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
 	ch := make(chan string) // TIP Create the channel.
 
 	go func() {
-		err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch)
+		err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
 		if err != nil {
 			variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
 		}
diff --git a/app/http/validator/web/rag/chat.go b/app/http/validator/web/rag/chat.go
index febc5ea..b2da586 100644
--- a/app/http/validator/web/rag/chat.go
+++ b/app/http/validator/web/rag/chat.go
@@ -12,7 +12,7 @@ import (
 
 // INFO Although it is named Chat, it queries the knowledge base by default; it is not meant to be used as a plain LLM chat.
 type Chat struct {
 	Query string `form:"query" json:"query" binding:"required"`
-	// TODO Does chat history still need to be handled here?
+	Token string `form:"token" json:"token"` // UPDATE Not ready to reuse the user's auth token yet, so this one is handled separately for now.
 }
 
 func (c Chat) CheckParams(context *gin.Context) {
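On the wire, `ChatWebSocket` streams plain-text answer chunks and then, from the deferred close handler, a final JSON frame of the form `{"type":"token","token":"..."}`. A client that wants to keep its GLM session across reconnects must capture that frame and pass the token back as a query parameter. A sketch in Go with gorilla/websocket; the host and route are assumptions, only the frame shape comes from this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/url"

	"github.com/gorilla/websocket"
)

type tokenFrame struct {
	Type  string `json:"type"`
	Token string `json:"token"`
}

// chatOnce sends one query and returns the session token announced by the
// server, so the next call can resume the same GLM conversation.
func chatOnce(query, token string) (string, error) {
	u := fmt.Sprintf("ws://localhost:20201/admin/rag/chat_ws?query=%s&token=%s",
		url.QueryEscape(query), url.QueryEscape(token))
	conn, _, err := websocket.DefaultDialer.Dial(u, nil)
	if err != nil {
		return token, err
	}
	defer conn.Close()

	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			return token, nil // Server closed; keep the last token we saw.
		}
		var f tokenFrame
		if json.Unmarshal(data, &f) == nil && f.Type == "token" {
			return f.Token, nil // Final frame: the session key to reuse.
		}
		fmt.Print(string(data)) // Everything else is streamed answer text.
	}
}

func main() {
	tok, _ := chatOnce("什么是猫传腹?", "")
	fmt.Println("\nsession token:", tok)
}
```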
diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go
index f24a6be..653dc25 100644
--- a/app/service/nlp/func.go
+++ b/app/service/nlp/func.go
@@ -5,16 +5,18 @@ import (
 	"catface/app/service/nlp/glm"
 	"fmt"
 	"strings"
+
+	"github.com/yankeguo/zhipu"
 )
 
-func GenerateTitle(content string) string {
+func GenerateTitle(content string, client *zhipu.ChatCompletionService) string {
 	message := variable.PromptsYml.GetString("Prompt.Title") + content
-	title, _ := glm.Chat(message)
+	title, _ := glm.Chat(message, client)
 	return title
 }
 
 // ChatKnoledgeRAG answers knowledge-base questions with the RAG model.
-func ChatKnoledgeRAG(doc, query string, ch chan<- string) error {
+func ChatKnoledgeRAG(doc, query string, ch chan<- string, client *zhipu.ChatCompletionService) error {
 	// Read the KnoledgeRAG template from the prompts config.
 	promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG")
 
@@ -24,7 +26,7 @@ func ChatKnoledgeRAG(doc, query string, ch chan<- string) error {
 
 	// Call the chat API.
 	// err := glm.ChatStream(message, ch)
-	err := glm.BufferedChatStream(message, ch)
+	err := glm.BufferedChatStream(message, ch, client)
 	if err != nil {
 		return fmt.Errorf("调用聊天接口失败: %w", err)
 	}
diff --git a/app/service/nlp/glm/glm.go b/app/service/nlp/glm/glm.go
index ba7b731..29ea303 100644
--- a/app/service/nlp/glm/glm.go
+++ b/app/service/nlp/glm/glm.go
@@ -4,19 +4,20 @@ import (
 	"catface/app/global/variable"
 	"context"
 	"errors"
+	"fmt"
 	"strings"
 	"time"
 
 	"github.com/yankeguo/zhipu"
+	"go.uber.org/zap"
 )
 
 // Chat wraps one round of conversation with the GLM model.
-func Chat(message string) (string, error) {
-	service := variable.GlmClient.ChatCompletion("glm-4-flash").
-		AddMessage(zhipu.ChatCompletionMessage{
-			Role:    "user",
-			Content: message,
-		})
+func Chat(message string, client *zhipu.ChatCompletionService) (string, error) {
+	service := client.AddMessage(zhipu.ChatCompletionMessage{
+		Role:    "user",
+		Content: message,
+	})
 
 	res, err := service.Do(context.Background())
 	if err != nil {
@@ -28,9 +29,8 @@ func Chat(message string) (string, error) {
 }
 
 // ChatStream takes a message and a channel, and streams the model's response into the channel.
-func ChatStream(message string, ch chan<- string) error {
-	service := variable.GlmClient.ChatCompletion("glm-4-flash").
-		AddMessage(zhipu.ChatCompletionMessage{Role: "user", Content: message}).
+func ChatStream(message string, ch chan<- string, client *zhipu.ChatCompletionService) error {
+	service := client.AddMessage(zhipu.ChatCompletionMessage{Role: "user", Content: message}).
 		SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
 			content := chunk.Choices[0].Delta.Content
 			if content != "" {
@@ -39,22 +39,30 @@ func ChatStream(message string, ch chan<- string) error {
 			return nil
 		})
 
+	// Test: log the accumulated message history.
+	messages := client.GetMessages()
+	for id, message := range messages {
+		variable.ZapLog.Info(fmt.Sprintf("message-%d", id+1), zap.String("message", message.(zhipu.ChatCompletionMessage).Role), zap.String("content", message.(zhipu.ChatCompletionMessage).Content))
+	}
+
 	// Execute the service call.
-	_, err := service.Do(context.Background())
+	res, err := service.Do(context.Background())
 	if err != nil {
 		return err
 	}
 
+	// Record the AI reply in the message history.
+	client.AddMessage(zhipu.ChatCompletionMessage{Role: "assistant", Content: res.Choices[0].Message.Content})
 	return nil
 }
 
 // BufferedChatStream is ChatStream with buffering: it flushes on a count & time double trigger.
-func BufferedChatStream(message string, ch chan<- string) error {
+func BufferedChatStream(message string, ch chan<- string, client *zhipu.ChatCompletionService) error {
 	bufferedCh := make(chan string)                // Unbuffered; the actual buffering happens in the flush logic below.
 	timer := time.NewTimer(500 * time.Millisecond) // 500 ms flush timer.
 
 	go func() {
-		err := ChatStream(message, bufferedCh)
+		err := ChatStream(message, bufferedCh, client)
 		if err != nil {
 			return
 		}
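`BufferedChatStream`'s flush loop is unchanged context and therefore not visible in the hunk above; only its signature changed. For orientation, the "count & time double trigger" its comment names generally has the shape sketched below (the chunk threshold, interval, and names are assumptions, not this repository's actual code):

```go
package glm

import (
	"strings"
	"time"
)

// bufferPump drains bufferedCh into ch, flushing whenever either maxChunks
// fragments have accumulated or flushEvery has elapsed, whichever comes first.
func bufferPump(bufferedCh <-chan string, ch chan<- string) {
	const maxChunks = 10
	const flushEvery = 500 * time.Millisecond

	var buf strings.Builder
	count := 0
	timer := time.NewTimer(flushEvery)
	defer timer.Stop()

	flush := func() {
		if buf.Len() > 0 {
			ch <- buf.String()
			buf.Reset()
			count = 0
		}
		timer.Reset(flushEvery)
	}

	for {
		select {
		case s, ok := <-bufferedCh:
			if !ok { // Upstream finished: flush the tail and stop.
				flush()
				close(ch)
				return
			}
			buf.WriteString(s)
			count++
			if count >= maxChunks {
				flush()
			}
		case <-timer.C:
			flush()
		}
	}
}
```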
diff --git a/app/utils/llm_factory/glm_client.go b/app/utils/llm_factory/glm_client.go
new file mode 100644
index 0000000..06a38a8
--- /dev/null
+++ b/app/utils/llm_factory/glm_client.go
@@ -0,0 +1,130 @@
+package llm_factory
+
+import (
+	"catface/app/global/errcode"
+	"sync"
+	"time"
+
+	"github.com/yankeguo/zhipu"
+)
+
+// INFO Maintains the GLM clients and their per-user message state, i.e. a thin management layer on top of "github.com/yankeguo/zhipu".
+
+type GlmClientHub struct {
+	MaxIdle          int
+	MaxActive        int
+	ApiKey           string
+	DefaultModelName string
+	InitPrompt       string
+	Clients          map[string]*ClientInfo
+	LifeTime         time.Duration
+	mu               sync.Mutex // Guards Clients and MaxIdle: handlers and the cleanup goroutine touch them concurrently.
+}
+
+type ClientInfo struct {
+	Client   *zhipu.ChatCompletionService
+	LastUsed time.Time
+}
+
+func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub {
+	hub := &GlmClientHub{
+		MaxIdle:          maxIdle,
+		MaxActive:        maxActive,
+		ApiKey:           apiKey,
+		DefaultModelName: defaultModelName,
+		InitPrompt:       initPrompt,
+		Clients:          make(map[string]*ClientInfo),
+		LifeTime:         time.Duration(lifetime) * time.Second,
+	}
+	go hub.cleanupRoutine() // Start the ticker that cleans up expired sessions.
+	return hub
+}
+
+const (
+	GlmModeSimple = iota
+	GlmModeKnowledgeHub
+)
+
+/**
+ * @description: After the user is authenticated, look up their existing connection in the map pool by ID.
+ * // UPDATE Single connection per user for now (i.e. only one conversation at a time); a [message queue] wrapper could extend this later.
+ * The default is an empty client with no preset prompt.
+ * @param {string} token: // TODO How to carry information inside the token?
+ * @return {*}
+ */
+func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.ChatCompletionService, code int) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	if info, ok := g.Clients[token]; ok {
+		info.LastUsed = time.Now() // INFO Refresh the lifetime.
+		return info.Client, 0
+	}
+
+	// Idle-slot check.
+	if g.MaxIdle > 0 {
+		g.MaxIdle -= 1
+	} else {
+		code = errcode.ErrGlmBusy
+		return
+	}
+
+	// Client init.
+	preClient, err := zhipu.NewClient(zhipu.WithAPIKey(g.ApiKey))
+	if err != nil {
+		g.MaxIdle += 1 // Hand the slot back on failure.
+		code = errcode.ErrGlmNewClientFail
+		return
+	}
+	client = preClient.ChatCompletion(g.DefaultModelName)
+
+	if mode == GlmModeKnowledgeHub {
+		client.AddMessage(zhipu.ChatCompletionMessage{
+			Role:    zhipu.RoleSystem, // TIP Seed the conversation with the System role.
+			Content: g.InitPrompt,
+		})
+	}
+
+	g.Clients[token] = &ClientInfo{
+		Client:   client,
+		LastUsed: time.Now(),
+	}
+	return
+}
+
+// cleanupRoutine periodically sweeps Clients that have been idle for longer than LifeTime.
+func (g *GlmClientHub) cleanupRoutine() {
+	ticker := time.NewTicker(10 * time.Minute)
+	for range ticker.C {
+		g.cleanupClients()
+	}
+}
+
+// cleanupClients removes Clients that have been idle for longer than LifeTime.
+func (g *GlmClientHub) cleanupClients() {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	now := time.Now()
+	for token, info := range g.Clients {
+		if now.Sub(info.LastUsed) > g.LifeTime {
+			delete(g.Clients, token)
+			g.MaxIdle += 1
+		}
+	}
+}
+
+/**
+ * @description: Release the resources explicitly.
+ * @param {string} token
+ * @return {*}
+ */
+func (g *GlmClientHub) ReleaseOneGlmClient(token string) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	if _, ok := g.Clients[token]; ok {
+		delete(g.Clients, token)
+		g.MaxIdle += 1
+	}
+}
diff --git a/app/utils/micro_service/micro_service.go b/app/utils/micro_service/micro_service.go
index 187c176..8c9c526 100644
--- a/app/utils/micro_service/micro_service.go
+++ b/app/utils/micro_service/micro_service.go
@@ -2,10 +2,18 @@ package micro_service
 
 import (
 	"catface/app/global/variable"
+	"context"
 	"fmt"
 	"strings"
+
+	"github.com/carlmjohnson/requests"
 )
 
+func TestLinkPythonService() bool {
+	err := requests.URL(FetchPythonServiceUrl("link_test")).Fetch(context.Background())
+	return err == nil
+}
+
 func FetchPythonServiceUrl(url string) string {
 	// Strip the leading / from url, if present.
 	if strings.HasPrefix(url, "/") {
diff --git a/app/utils/snow_flake/snow_flake.go b/app/utils/snow_flake/snow_flake.go
index 9182b1b..d5ed83b 100644
--- a/app/utils/snow_flake/snow_flake.go
+++ b/app/utils/snow_flake/snow_flake.go
@@ -4,6 +4,7 @@ import (
 	"catface/app/global/consts"
 	"catface/app/global/variable"
 	"catface/app/utils/snow_flake/snowflake_interf"
+	"strconv"
 	"sync"
 	"time"
 )
@@ -45,3 +46,9 @@ func (s *snowflake) GetId() int64 {
 	r := (now-consts.StartTimeStamp)<