235 lines
6.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package web
import (
"catface/app/global/errcode"
"catface/app/global/variable"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/utils/llm_factory"
"catface/app/utils/micro_service"
"catface/app/utils/response"
"encoding/json"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
type Rag struct {
}
// v1 Http-POST 版本; chat 需要不使用 ch 的版本。
// func (r *Rag) Chat(context *gin.Context) {
// // 1. query embedding
// query := context.GetString(consts.ValidatorPrefix + "query")
// embedding, ok := nlp.GetEmbedding(query)
// if !ok {
// code := errcode.ErrPythonService
// response.Fail(context, code, errcode.ErrMsg[code], "")
// return
// }
// // 2. ES TopK
// docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
// if err != nil || len(docs) == 0 {
// variable.ZapLog.Error("ES TopK error", zap.Error(err))
// code := errcode.ErrNoDocFound
// response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
// }
// // 3. LLM answer
// if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
// response.Success(context, consts.CurdStatusOkMsg, gin.H{
// "answer": answer,
// })
// } else {
// response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
// }
// }
func (r *Rag) ChatSSE(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
// 0-1. 测试 python
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 0-2. 获取一个 GLM Client
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
response.Fail(context, ercode, errcode.ErrMsg[ercode], errcode.ErrMsgForUser[ercode])
return
}
// 1. query embedding
embedding, ok := nlp.GetEmbedding(query)
if !ok {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
}
// UPDATE
closeEventFromVue := context.Request.Context().Done()
ch := make(chan string) // TIP 建立通道。
// 3. LLM answer
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch)
}()
context.Stream(func(w io.Writer) bool {
select {
case c, ok := <-ch:
if !ok {
return false
}
context.SSEvent("chat", c)
return true
case <-closeEventFromVue:
return false
}
})
}
var upgrader = websocket.Upgrader{ // TEST 测试,先写一个裸的 wss
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // info 在生产环境中可能需要更安全的检查
},
}
func (r *Rag) ChatWebSocket(context *gin.Context) {
query := context.Query("query")
token := context.Query("token")
if token == "" {
token = variable.SnowFlake.GetIdAsString()
}
// 0-1. 协议升级
ws, err := upgrader.Upgrade(context.Writer, context.Request, nil)
if err != nil {
variable.ZapLog.Error("OnOpen error", zap.Error(err))
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
return
}
defer func() { // UPDATE 临时方案,之后考虑结合 jwt 维护的 token 处理。
tokenMsg := struct {
Type string `json:"type"`
Token string `json:"token"`
}{
Type: "token",
Token: token,
}
tokenBytes, _ := json.Marshal(tokenMsg)
err := ws.WriteMessage(websocket.TextMessage, tokenBytes)
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
}
ws.Close()
}()
// 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 0-3. 从 GLM_HUB 中获取一个可用的 glm client;
client, ercode := variable.GlmClientHub.GetOneGlmClient(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[ercode]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 1. query embedding
embedding, ok := nlp.GetEmbedding(query)
if !ok {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
err := ws.WriteMessage(websocket.TextMessage, []byte(errcode.ErrMsgForUser[code]))
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
return
}
// 3.
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
ch := make(chan string) // TIP 建立通道。
go func() {
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch, client)
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
close(ch) // 这里 close使得下方 for 结束。
}()
for {
select {
case c, ok := <-ch:
if !ok {
return
}
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
err := ws.WriteMessage(websocket.TextMessage, []byte(c))
if err != nil {
return
}
case <-closeEventFromVue:
return
}
}
}
func (r *Rag) HelpDetectCat(context *gin.Context) {
// TODO
}