Improve embedding by using all of the user's queries
This commit is contained in:
parent
57f7e3f872
commit
f8a4f1d042
@ -3,12 +3,12 @@ package web
|
|||||||
import (
|
import (
|
||||||
"catface/app/global/errcode"
|
"catface/app/global/errcode"
|
||||||
"catface/app/global/variable"
|
"catface/app/global/variable"
|
||||||
|
"catface/app/model"
|
||||||
"catface/app/model_es"
|
"catface/app/model_es"
|
||||||
"catface/app/service/nlp"
|
"catface/app/service/nlp"
|
||||||
"catface/app/utils/llm_factory"
|
"catface/app/utils/llm_factory"
|
||||||
"catface/app/utils/micro_service"
|
"catface/app/utils/micro_service"
|
||||||
"catface/app/utils/response"
|
"catface/app/utils/response"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 1. query embedding
|
// 1. query embedding
|
||||||
embedding, ok := nlp.GetEmbedding(query)
|
embedding, ok := nlp.GetEmbedding([]string{query})
|
||||||
if !ok {
|
if !ok {
|
||||||
code := errcode.ErrPythonService
|
code := errcode.ErrPythonService
|
||||||
response.Fail(context, code, errcode.ErrMsg[code], "")
|
response.Fail(context, code, errcode.ErrMsg[code], "")
|
||||||
@ -138,17 +138,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。
|
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||||
tokenMsg := struct {
|
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
||||||
Type string `json:"type"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
}{
|
|
||||||
Type: "token",
|
|
||||||
Token: token,
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenBytes, _ := json.Marshal(tokenMsg)
|
err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||||
err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -158,7 +151,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
// 0-2. 测试 Python 微服务是否启动
|
// 0-2. 测试 Python 微服务是否启动
|
||||||
if !micro_service.TestLinkPythonService() {
|
if !micro_service.TestLinkPythonService() {
|
||||||
code := errcode.ErrPythonServierDown
|
code := errcode.ErrPythonServierDown
|
||||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -166,10 +159,10 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
|
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
|
||||||
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
|
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
|
||||||
if ercode != 0 {
|
if ercode != 0 {
|
||||||
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
|
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
|
||||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode]))
|
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -177,10 +170,11 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 1. query embedding
|
// 1. query embedding
|
||||||
embedding, ok := nlp.GetEmbedding(query)
|
clientInfo.AddQuery(query)
|
||||||
|
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
|
||||||
if !ok {
|
if !ok {
|
||||||
code := errcode.ErrPythonServierDown
|
code := errcode.ErrPythonServierDown
|
||||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -193,7 +187,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
code := errcode.ErrNoDocFound
|
code := errcode.ErrNoDocFound
|
||||||
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
|
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -205,7 +199,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
ch := make(chan string) // TIP 建立通道。
|
ch := make(chan string) // TIP 建立通道。
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
|
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, clientInfo.Client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -219,7 +213,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
||||||
err := ws.WriteMessage(websocket.TextMessage, []byte(c))
|
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
24
app/model/rag_websocket_result.go
Normal file
24
app/model/rag_websocket_result.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// CreateNlpWebSocketResult builds the envelope that is serialized and written
// to the client over the NLP WebSocket.
//
// t is the message type; an empty t defaults to "chat". data is stored as-is
// and becomes the "data" field of the JSON output.
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
	if t == "" {
		t = "chat"
	}

	return &NlpWebSocketResult{
		Type: t,
		Data: data,
	}
}

// NlpWebSocketResult is the JSON envelope for a single WebSocket message:
// a type tag plus an arbitrary payload.
type NlpWebSocketResult struct {
	Type string `json:"type"`
	Data any    `json:"data"`
}

// JsonMarshal returns the JSON encoding of n.
//
// Data is an arbitrary value, so encoding can fail (for example for channels,
// functions or NaN floats). Instead of returning nil bytes, which the caller
// would forward as an empty frame, a failed encoding falls back to an envelope
// with the same Type and a null Data so the receiver always gets valid JSON.
func (n *NlpWebSocketResult) JsonMarshal() []byte {
	payload, err := json.Marshal(n)
	if err == nil {
		return payload
	}

	// n is non-nil here: marshaling a nil pointer yields "null" without error.
	// The fallback holds only a string and a nil interface, which always
	// encode successfully, so its error can be safely ignored.
	fallback, _ := json.Marshal(&NlpWebSocketResult{Type: n.Type})
	return fallback
}
|
@ -14,7 +14,7 @@ type EmbeddingRes struct {
|
|||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEmbedding(text string) ([]float64, bool) {
|
func GetEmbedding(text []string) ([]float64, bool) {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"text": text,
|
"text": text,
|
||||||
}
|
}
|
||||||
|
@ -20,8 +20,9 @@ type GlmClientHub struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ClientInfo struct {
|
type ClientInfo struct {
|
||||||
Client *zhipu.ChatCompletionService
|
Client *zhipu.ChatCompletionService
|
||||||
LastUsed time.Time
|
UserQuerys []string
|
||||||
|
LastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub {
|
func InitGlmClientHub(maxIdle, maxActive, lifetime int, apiKey, defaultModelName, initPrompt string) *GlmClientHub {
|
||||||
@ -50,10 +51,10 @@ const (
|
|||||||
* @param {string} token: // TODO 如何在 token 中保存信息?
|
* @param {string} token: // TODO 如何在 token 中保存信息?
|
||||||
* @return {*}
|
* @return {*}
|
||||||
*/
|
*/
|
||||||
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.ChatCompletionService, code int) {
|
func (g *GlmClientHub) GetOneGlmClientInfo(token string, mode int) (clientInfo *ClientInfo, code int) {
|
||||||
if info, ok := g.Clients[token]; ok {
|
if info, ok := g.Clients[token]; ok {
|
||||||
info.LastUsed = time.Now() // INFO 刷新生命周期
|
info.LastUsed = time.Now() // INFO 刷新生命周期
|
||||||
return info.Client, 0
|
return info, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// 空闲数检查
|
// 空闲数检查
|
||||||
@ -70,7 +71,7 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
|
|||||||
code = errcode.ErrGlmNewClientFail
|
code = errcode.ErrGlmNewClientFail
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
client = preClient.ChatCompletion(g.DefaultModelName)
|
client := preClient.ChatCompletion(g.DefaultModelName)
|
||||||
|
|
||||||
if mode == GlmModeKnowledgeHub {
|
if mode == GlmModeKnowledgeHub {
|
||||||
client.AddMessage(zhipu.ChatCompletionMessage{
|
client.AddMessage(zhipu.ChatCompletionMessage{
|
||||||
@ -79,13 +80,28 @@ func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (client *zhipu.Ch
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Clients[token] = &ClientInfo{
|
clientInfo = &ClientInfo{
|
||||||
Client: client,
|
Client: client,
|
||||||
LastUsed: time.Now(),
|
LastUsed: time.Now(),
|
||||||
}
|
}
|
||||||
|
g.Clients[token] = clientInfo
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description: 获取并返回 ClientInfo 的 Client 和 code。
|
||||||
|
* @param {string} token
|
||||||
|
* @param {int} mode
|
||||||
|
* @return {(*zhipu.ChatCompletionService, int)}
|
||||||
|
*/
|
||||||
|
func (g *GlmClientHub) GetOneGlmClient(token string, mode int) (*zhipu.ChatCompletionService, int) {
|
||||||
|
clientInfo, code := g.GetOneGlmClientInfo(token, mode)
|
||||||
|
if clientInfo == nil || code != 0 {
|
||||||
|
return nil, code
|
||||||
|
}
|
||||||
|
return clientInfo.Client, code
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client
|
// cleanupRoutine 定期检查并清理超过 1 小时未使用的 Client
|
||||||
func (g *GlmClientHub) cleanupRoutine() {
|
func (g *GlmClientHub) cleanupRoutine() {
|
||||||
ticker := time.NewTicker(10 * time.Minute)
|
ticker := time.NewTicker(10 * time.Minute)
|
||||||
@ -114,3 +130,8 @@ func (g *GlmClientHub) ReleaseOneGlmClient(token string) {
|
|||||||
delete(g.Clients, token)
|
delete(g.Clients, token)
|
||||||
g.MaxIdle += 1
|
g.MaxIdle += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TAG ClientInfo
|
||||||
|
func (c *ClientInfo) AddQuery(query string) {
|
||||||
|
c.UserQuerys = append(c.UserQuerys, query)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user