better Embedding with All query
This commit is contained in:
parent
57f7e3f872
commit
f8a4f1d042
@ -3,12 +3,12 @@ package web
|
||||
import (
|
||||
"catface/app/global/errcode"
|
||||
"catface/app/global/variable"
|
||||
"catface/app/model"
|
||||
"catface/app/model_es"
|
||||
"catface/app/service/nlp"
|
||||
"catface/app/utils/llm_factory"
|
||||
"catface/app/utils/micro_service"
|
||||
"catface/app/utils/response"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@ -72,7 +72,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
||||
}
|
||||
|
||||
// 1. query embedding
|
||||
embedding, ok := nlp.GetEmbedding(query)
|
||||
embedding, ok := nlp.GetEmbedding([]string{query})
|
||||
if !ok {
|
||||
code := errcode.ErrPythonService
|
||||
response.Fail(context, code, errcode.ErrMsg[code], "")
|
||||
@ -138,17 +138,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
||||
return
|
||||
}
|
||||
defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。
|
||||
tokenMsg := struct {
|
||||
Type string `json:"type"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
Type: "token",
|
||||
Token: token,
|
||||
}
|
||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
||||
|
||||
tokenBytes, _ := json.Marshal(tokenMsg)
|
||||
err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
|
||||
err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||
}
|
||||
@ -158,7 +151,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
// 0-2. 测试 Python 微服务是否启动
|
||||
if !micro_service.TestLinkPythonService() {
|
||||
code := errcode.ErrPythonServierDown
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||
}
|
||||
@ -166,10 +159,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
}
|
||||
|
||||
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
|
||||
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
|
||||
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
|
||||
if ercode != 0 {
|
||||
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode]))
|
||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||
}
|
||||
@ -177,10 +170,11 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
}
|
||||
|
||||
// 1. query embedding
|
||||
embedding, ok := nlp.GetEmbedding(query)
|
||||
clientInfo.AddQuery(query)
|
||||
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
|
||||
if !ok {
|
||||
code := errcode.ErrPythonServierDown
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||
}
|
||||
@ -193,7 +187,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||
|
||||
code := errcode.ErrNoDocFound
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||
}
|
||||
@ -205,7 +199,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
ch := make(chan string) // TIP 建立通道。
|
||||
|
||||
go func() {
|
||||
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
|
||||
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, clientInfo.Client)
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||
}
|
||||
@ -219,7 +213,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
return
|
||||
}
|
||||
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(c))
|
||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
24
app/model/rag_websocket_result.go
Normal file
24
app/model/rag_websocket_result.go
Normal file
@ -0,0 +1,24 @@
|
||||
package model
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
|
||||
if t == "" {
|
||||
t = "chat"
|
||||
}
|
||||
|
||||
return &NlpWebSocketResult{
|
||||
Type: t,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
type NlpWebSocketResult struct {
|
||||
Type string `json:"type"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
func (n *NlpWebSocketResult) JsonMarshal() []byte {
|
||||
data, _ := json.Marshal(n)
|
||||
return data
|
||||
}
|
@ -14,7 +14,7 @@ type EmbeddingRes struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func GetEmbedding(text string) ([]float64, bool) {
|
||||
func GetEmbedding(text []string) ([]float64, bool) {
|
||||
body := map[string]interface{}{
|
||||
"text": text,
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ type GlmClientHub struct {
|
||||
|
||||
type ClientInfo struct {
|
||||
Client *zhipu.ChatCompletionService
|
||||
UserQuerys []string
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
@ -50,10 +51,10 @@ const (
|
||||
* @param {string} token: // TODO 如何在 token 中保存信息?
|
||||
* @return {*}
|
||||
*/
|
||||
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.ChatCompletionService, code int) {
|
||||
func (g *GlmClientHub) GetOneGlmClientInfo(token string, mode int) (clientInfo *ClientInfo, code int) {
|
||||
if info, ok := g.Clients[token]; ok {
|
||||
info.LastUsed = time.Now() // INFO 刷新生命周期
|
||||
return info.Client, 0
|
||||
return info, 0
|
||||
}
|
||||
|
||||
// 空闲数检查
|
||||
@ -70,7 +71,7 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
|
||||
code = errcode.ErrGlmNewClientFail
|
||||
return
|
||||
}
|
||||
client = preClient.ChatCompletion(g.DefaultModelName)
|
||||
client := preClient.ChatCompletion(g.DefaultModelName)
|
||||
|
||||
if mode == GlmModeKnowledgeHub {
|
||||
client.AddMessage(zhipu.ChatCompletionMessage{
|
||||
@ -79,13 +80,28 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
|
||||
})
|
||||
}
|
||||
|
||||
g.Clients[token] = &ClientInfo{
|
||||
clientInfo = &ClientInfo{
|
||||
Client: client,
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
g.Clients[token] = clientInfo
|
||||
return
|
||||
}
|
||||
|
||||
/**
|
||||
* @description: 获取并返回 ClientInfo 的 Client 和 code。
|
||||
* @param {string} token
|
||||
* @param {int} mode
|
||||
* @return {(*zhipu.ChatCompletionService, int)}
|
||||
*/
|
||||
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (*zhipu.ChatCompletionService, int) {
|
||||
clientInfo, code := g.GetOneGlmClientInfo(token, mode)
|
||||
if clientInfo == nil || code != 0 {
|
||||
return nil, code
|
||||
}
|
||||
return clientInfo.Client, code
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client
|
||||
func (g *GlmClientHub) cleanupRoutine() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
@ -114,3 +130,8 @@ func (g *GlmClientHub) ReleaseOneGlmClient(token string) {
|
||||
delete(g.Clients, token)
|
||||
g.MaxIdle += 1
|
||||
}
|
||||
|
||||
// TAG ClientInfo
|
||||
func (c *ClientInfo) AddQuery(query string) {
|
||||
c.UserQuerys = append(c.UserQuerys, query)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user