235 lines
6.2 KiB
Go
Raw Normal View History

2024-11-16 02:38:34 +08:00
package web
import (
"catface/app/global/errcode"
"catface/app/global/variable"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/utils/llm_factory"
"catface/app/utils/micro_service"
2024-11-16 02:38:34 +08:00
"catface/app/utils/response"
"encoding/json"
2024-11-16 14:00:57 +08:00
"io"
2024-11-16 18:18:07 +08:00
"net/http"
2024-11-16 02:38:34 +08:00
"github.com/gin-gonic/gin"
2024-11-16 18:18:07 +08:00
"github.com/gorilla/websocket"
2024-11-16 02:38:34 +08:00
"go.uber.org/zap"
)
// Rag has no fields; it only serves as the receiver that groups the
// retrieval-augmented-generation (RAG) chat HTTP handlers of this package.
type Rag struct{}
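// v1 Http-POST 版本; chat 需要不使用 ch 的版本。
// (v1: plain HTTP POST handler, kept for reference only. Re-enabling it needs a
// non-streaming variant of nlp.ChatKnoledgeRAG that does not take a channel.
// Note: its ES TopK failure branch below is missing a `return` before docs[0].)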
// func (r *Rag) Chat(context *gin.Context) {
// // 1. query embedding
// query := context.GetString(consts.ValidatorPrefix + "query")
// embedding, ok := nlp.GetEmbedding(query)
// if !ok {
// code := errcode.ErrPythonService
// response.Fail(context, code, errcode.ErrMsg[code], "")
// return
// }
// // 2. ES TopK
// docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
// if err != nil || len(docs) == 0 {
// variable.ZapLog.Error("ES TopK error", zap.Error(err))
// code := errcode.ErrNoDocFound
// response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
// }
// // 3. LLM answer
// if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
// response.Success(context, consts.CurdStatusOkMsg, gin.H{
// "answer": answer,
// })
// } else {
// response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
// }
// }
func (r *Rag) ChatSSE(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
// 0-1. 测试 python
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 0-2. 获取一个 GLM Client
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
response.Fail(context, ercode, errcode.ErrMsg[ercode], errcode.ErrMsgForUser[ercode])
return
}
2024-11-16 14:00:57 +08:00
2024-11-16 02:38:34 +08:00
// 1. query embedding
embedding, ok := nlp.GetEmbedding(query)
if !ok {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
}
2024-11-16 14:00:57 +08:00
// UPDATE
closeEventFromVue := context.Request.Context().Done()
ch := make(chan string) // TIP 建立通道。
2024-11-16 02:38:34 +08:00
// 3. LLM answer
2024-11-16 14:00:57 +08:00
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
2024-11-16 14:00:57 +08:00
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch)
}()
context.Stream(func(w io.Writer) bool {
select {
case c, ok := <-ch:
if !ok {
return false
}
context.SSEvent("chat", c)
return true
case <-closeEventFromVue:
return false
}
})
2024-11-16 02:38:34 +08:00
}
2024-11-16 18:18:07 +08:00
var upgrader = websocket.Upgrader{ // TEST 测试,先写一个裸的 wss
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // info 在生产环境中可能需要更安全的检查
},
}
func (r *Rag) ChatWebSocket(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
2024-11-16 18:18:07 +08:00
// 0-1. 协议升级
2024-11-16 18:18:07 +08:00
ws, err := upgrader.Upgrade(context.Writer, context.Request, nil)
if err != nil {
variable.ZapLog.Error("OnOpen error", zap.Error(err))
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
return
}
defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。
tokenMsg := struct {
Type string `json:"type"`
Token string `json:"token"`
}{
Type: "token",
Token: token,
}
tokenBytes, _ := json.Marshal(tokenMsg)
err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
}
ws.Close()
}()
// 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
2024-11-16 18:18:07 +08:00
// 1. query embedding
embedding, ok := nlp.GetEmbedding(query)
if !ok {
2024-11-18 00:39:36 +08:00
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
2024-11-16 18:18:07 +08:00
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
2024-11-16 18:18:07 +08:00
}
// 3.
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
ch := make(chan string) // TIP 建立通道。
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
2024-11-16 18:18:07 +08:00
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch) // 这里 close使得下方 for 结束。
}()
for {
select {
case c, ok := <-ch:
if !ok {
return
}
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
err := ws.WriteMessage(websocket.TextMessage, []byte(c))
if err != nil {
return
}
case <-closeEventFromVue:
return
}
}
}
2024-11-16 02:38:34 +08:00
func (r *Rag) HelpDetectCat(context *gin.Context) {
// TODO
}