Havoc412 f263ead51d refactor(llm_factory): 新增 GLM 客户端未使用资源释放功能
- 在 NlpController 和 RagController 中添加了释放 GLM 客户端资源的逻辑
- 在 GlmClientHub 中增加了 UnavtiveOneGlmClient 方法,用于将客户端标记为未使用状态
- 优化了资源管理,提高了 GLM 客户端的利用率
2024-11-20 09:01:53 +08:00

257 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package web
import (
"catface/app/global/consts"
"catface/app/global/errcode"
"catface/app/global/variable"
"catface/app/model"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/service/rag/curd"
"catface/app/utils/llm_factory"
"catface/app/utils/micro_service"
"catface/app/utils/response"
"io"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
type Rag struct {
}
func (r *Rag) Release(context *gin.Context) {
token := context.GetString(consts.ValidatorPrefix + "token")
if ok := variable.GlmClientHub.ReleaseOneGlmClient(token); ok {
variable.ZapLog.Info("释放一个 GLM Client",
zap.String("token", token),
zap.String("当前空闲连接数", strconv.Itoa(variable.GlmClientHub.Idle)))
} else {
variable.ZapLog.Warn("尝试释放一个 GLM Client但是 token 无效",
zap.String("当前空闲连接数", strconv.Itoa(variable.GlmClientHub.Idle)))
}
response.Success(context, consts.CurdStatusOkMsg, "")
}
// v1 Http-POST 版本; chat 需要不使用 ch 的版本。
// func (r *Rag) Chat(context *gin.Context) {
// // 1. query embedding
// query := context.GetString(consts.ValidatorPrefix + "query")
// embedding, ok := nlp.GetEmbedding(query)
// if !ok {
// code := errcode.ErrPythonService
// response.Fail(context, code, errcode.ErrMsg[code], "")
// return
// }
// // 2. ES TopK
// docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
// if err != nil || len(docs) == 0 {
// variable.ZapLog.Error("ES TopK error", zap.Error(err))
// code := errcode.ErrNoDocFound
// response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
// }
// // 3. LLM answer
// if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
// response.Success(context, consts.CurdStatusOkMsg, gin.H{
// "answer": answer,
// })
// } else {
// response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
// }
// }
func (r *Rag) ChatSSE(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
// 0-1. 测试 python
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 0-2. 获取一个 GLM Client
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
response.Fail(context, ercode, errcode.ErrMsg[ercode], errcode.ErrMsgForUser[ercode])
return
}
defer variable.GlmClientHub.UnavtiveOneGlmClient(token) // INFO ws 结束时,取消 Avtive 的占用。
// 1. query embedding
embedding, ok := nlp.GetEmbedding([]string{query})
if !ok {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
}
// UPDATE
closeEventFromVue := context.Request.Context().Done()
ch := make(chan string) // TIP 建立通道。
// 3. LLM answer
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch)
}()
context.Stream(func(w io.Writer) bool {
select {
case c, ok := <-ch:
if !ok {
return false
}
context.SSEvent("chat", c)
return true
case <-closeEventFromVue:
return false
}
})
}
var upgrader = websocket.Upgrader{ // TEST 测试,先写一个裸的 wss
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // info 在生产环境中可能需要更安全的检查
},
}
func (r *Rag) ChatWebSocket(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
// 0-1. 协议升级
ws, err := upgrader.Upgrade(context.Writer, context.Request, nil)
if err != nil {
variable.ZapLog.Error("OnOpen error", zap.Error(err))
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
return
}
// 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
defer variable.GlmClientHub.UnavtiveOneGlmClient(token) // INFO ws 结束时,取消 Avtive 的占用。
// 1. query embedding
clientInfo.AddQuery(query)
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
if !ok {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 2. ES TopK // TODO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
// 0. 传递参考资料的信息
docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
}
// 1. 传递 token 信息; // UPDATE 临时方案
tokenMsg := model.CreateNlpWebSocketResult("token", token)
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
}
ws.Close()
}()
// 3.
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
ch := make(chan string) // TIP 建立通道。
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, clientInfo.Client)
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch) // 这里 close使得下方 for 结束。
}()
for {
select {
case c, ok := <-ch:
if !ok {
return
}
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
if err != nil {
return
}
case <-closeEventFromVue:
return
}
}
}
func (r *Rag) HelpDetectCat(context *gin.Context) {
// TODO
}