better Embedding with All query

This commit is contained in:
Havoc412 2024-11-19 03:21:28 +08:00
parent 57f7e3f872
commit f8a4f1d042
4 changed files with 66 additions and 27 deletions

View File

@ -3,12 +3,12 @@ package web
import ( import (
"catface/app/global/errcode" "catface/app/global/errcode"
"catface/app/global/variable" "catface/app/global/variable"
"catface/app/model"
"catface/app/model_es" "catface/app/model_es"
"catface/app/service/nlp" "catface/app/service/nlp"
"catface/app/utils/llm_factory" "catface/app/utils/llm_factory"
"catface/app/utils/micro_service" "catface/app/utils/micro_service"
"catface/app/utils/response" "catface/app/utils/response"
"encoding/json"
"io" "io"
"net/http" "net/http"
@ -72,7 +72,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
} }
// 1. query embedding // 1. query embedding
embedding, ok := nlp.GetEmbedding(query) embedding, ok := nlp.GetEmbedding([]string{query})
if !ok { if !ok {
code := errcode.ErrPythonService code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "") response.Fail(context, code, errcode.ErrMsg[code], "")
@ -138,17 +138,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "") response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
return return
} }
defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。 defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
tokenMsg := struct { tokenMsg := model.CreateNlpWebSocketResult("token", token)
Type string `json:"type"`
Token string `json:"token"`
}{
Type: "token",
Token: token,
}
tokenBytes, _ := json.Marshal(tokenMsg) err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
if err != nil { if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err)) variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
} }
@ -158,7 +151,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
// 0-2. 测试 Python 微服务是否启动 // 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() { if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonServierDown code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil { if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
} }
@ -166,10 +159,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
} }
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client; // 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub) clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 { if ercode != 0 {
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err)) variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode])) err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
if err != nil { if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
} }
@ -177,10 +170,11 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
} }
// 1. query embedding // 1. query embedding
embedding, ok := nlp.GetEmbedding(query) clientInfo.AddQuery(query)
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
if !ok { if !ok {
code := errcode.ErrPythonServierDown code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil { if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
} }
@ -193,7 +187,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
variable.ZapLog.Error("ES TopK error", zap.Error(err)) variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound code := errcode.ErrNoDocFound
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code])) err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil { if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
} }
@ -205,7 +199,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
ch := make(chan string) // TIP 建立通道。 ch := make(chan string) // TIP 建立通道。
go func() { go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client) err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, clientInfo.Client)
if err != nil { if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
} }
@ -219,7 +213,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
return return
} }
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c)) // variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
err := ws.WriteMessage(websocket.TextMessage, []byte(c)) err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
if err != nil { if err != nil {
return return
} }

View File

@ -0,0 +1,24 @@
package model
import "encoding/json"
// NlpWebSocketResult is the envelope serialized by JsonMarshal: a message
// kind under the "type" key and an arbitrary payload under the "data" key.
type NlpWebSocketResult struct {
	Type string `json:"type"`
	Data any    `json:"data"`
}

// CreateNlpWebSocketResult builds an envelope carrying data.
// An empty t is replaced by the default message type "chat".
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
	result := &NlpWebSocketResult{Type: "chat", Data: data}
	if t != "" {
		result.Type = t
	}
	return result
}

// JsonMarshal returns the JSON encoding of the envelope.
// The encoding error is deliberately dropped: when Data cannot be
// marshaled the caller receives a nil slice instead of an error.
func (n *NlpWebSocketResult) JsonMarshal() []byte {
	encoded, err := json.Marshal(n)
	if err != nil {
		return nil
	}
	return encoded
}

View File

@ -14,7 +14,7 @@ type EmbeddingRes struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
func GetEmbedding(text string) ([]float64, bool) { func GetEmbedding(text []string) ([]float64, bool) {
body := map[string]interface{}{ body := map[string]interface{}{
"text": text, "text": text,
} }

View File

@ -20,8 +20,9 @@ type GlmClientHub struct {
} }
type ClientInfo struct { type ClientInfo struct {
Client *zhipu.ChatCompletionService Client *zhipu.ChatCompletionService
LastUsed time.Time UserQuerys []string
LastUsed time.Time
} }
func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub { func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub {
@ -50,10 +51,10 @@ const (
* @param {string} token // TODO 如何在 token 中保存信息? * @param {string} token // TODO 如何在 token 中保存信息?
* @return {*} * @return {*}
*/ */
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.ChatCompletionService, code int) { func (g *GlmClientHub) GetOneGlmClientInfo(token string, mode int) (clientInfo *ClientInfo, code int) {
if info, ok := g.Clients[token]; ok { if info, ok := g.Clients[token]; ok {
info.LastUsed = time.Now() // INFO 刷新生命周期 info.LastUsed = time.Now() // INFO 刷新生命周期
return info.Client, 0 return info, 0
} }
// 空闲数检查 // 空闲数检查
@ -70,7 +71,7 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
code = errcode.ErrGlmNewClientFail code = errcode.ErrGlmNewClientFail
return return
} }
client = preClient.ChatCompletion(g.DefaultModelName) client := preClient.ChatCompletion(g.DefaultModelName)
if mode == GlmModeKnowledgeHub { if mode == GlmModeKnowledgeHub {
client.AddMessage(zhipu.ChatCompletionMessage{ client.AddMessage(zhipu.ChatCompletionMessage{
@ -79,13 +80,28 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
}) })
} }
g.Clients[token] = &ClientInfo{ clientInfo = &ClientInfo{
Client: client, Client: client,
LastUsed: time.Now(), LastUsed: time.Now(),
} }
g.Clients[token] = clientInfo
return return
} }
/**
 * @description: Thin wrapper over GetOneGlmClientInfo that hands back only the
 *               Client field of the resolved ClientInfo together with the status code.
 * @param {string} token  forwarded unchanged to GetOneGlmClientInfo
 * @param {int} mode      forwarded unchanged to GetOneGlmClientInfo
 * @return {(*zhipu.ChatCompletionService, int)} the client and the code; the client
 *         is nil whenever the code is non-zero or no ClientInfo was returned
 */
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (*zhipu.ChatCompletionService, int) {
	info, code := g.GetOneGlmClientInfo(token, mode)
	if code == 0 && info != nil {
		return info.Client, code
	}
	return nil, code
}
// cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client // cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client
func (g *GlmClientHub) cleanupRoutine() { func (g *GlmClientHub) cleanupRoutine() {
ticker := time.NewTicker(10 * time.Minute) ticker := time.NewTicker(10 * time.Minute)
@ -114,3 +130,8 @@ func (g *GlmClientHub) ReleaseOneGlmClient(token string) {
delete(g.Clients, token) delete(g.Clients, token)
g.MaxIdle += 1 g.MaxIdle += 1
} }
// TAG ClientInfo

// AddQuery appends query to the end of c.UserQuerys, keeping earlier
// entries in their original order.
func (c *ClientInfo) AddQuery(query string) {
	history := c.UserQuerys
	c.UserQuerys = append(history, query)
}