🎏 finish RAG
This commit is contained in:
parent
2af03cbf13
commit
3413d52316
@ -8,6 +8,7 @@ const (
|
|||||||
ErrNlp
|
ErrNlp
|
||||||
ErrKnowledge
|
ErrKnowledge
|
||||||
ErrSubService
|
ErrSubService
|
||||||
|
ErrWebSocket
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -15,6 +15,7 @@ func init() {
|
|||||||
NlpMsgInit(ErrMsg)
|
NlpMsgInit(ErrMsg)
|
||||||
KnowledgeMsgInit(ErrMsg)
|
KnowledgeMsgInit(ErrMsg)
|
||||||
SubServiceMsgInit(ErrMsg)
|
SubServiceMsgInit(ErrMsg)
|
||||||
|
WsMsgInit(ErrMsg)
|
||||||
|
|
||||||
// INGO
|
// INGO
|
||||||
ErrMsgForUser = make(msg)
|
ErrMsgForUser = make(msg)
|
||||||
|
9
app/global/errcode/websocket.go
Normal file
9
app/global/errcode/websocket.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package errcode
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrWebsocketUpgradeFail = ErrWebSocket + iota
|
||||||
|
)
|
||||||
|
|
||||||
|
func WsMsgInit(m msg) {
|
||||||
|
m[ErrWebsocketUpgradeFail] = "websocket升级失败"
|
||||||
|
}
|
@ -7,8 +7,10 @@ import (
|
|||||||
"catface/app/service/nlp"
|
"catface/app/service/nlp"
|
||||||
"catface/app/utils/response"
|
"catface/app/utils/response"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -92,6 +94,72 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{ // TEST 测试,先写一个裸的 wss
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true // info 在生产环境中可能需要更安全的检查
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||||
|
query := context.Query("query")
|
||||||
|
|
||||||
|
// 0. 协议升级
|
||||||
|
ws, err := upgrader.Upgrade(context.Writer, context.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
variable.ZapLog.Error("OnOpen error", zap.Error(err))
|
||||||
|
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
// 1. query embedding
|
||||||
|
embedding, ok := nlp.GetEmbedding(query)
|
||||||
|
if !ok {
|
||||||
|
code := errcode.ErrPythonService
|
||||||
|
response.Fail(context, code, errcode.ErrMsg[code], "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. ES TopK
|
||||||
|
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
||||||
|
if err != nil || len(docs) == 0 {
|
||||||
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
|
code := errcode.ErrNoDocFound
|
||||||
|
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3.
|
||||||
|
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
|
||||||
|
ch := make(chan string) // TIP 建立通道。
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := nlp.ChatKnoledgeRAG(docs[0].Content, query, ch)
|
||||||
|
if err != nil {
|
||||||
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
|
}
|
||||||
|
close(ch) // 这里 close,使得下方 for 结束。
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c, ok := <-ch:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
||||||
|
err := ws.WriteMessage(websocket.TextMessage, []byte(c))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-closeEventFromVue:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
||||||
// TODO
|
// TODO
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ func (c Chat) CheckParams(context *gin.Context) {
|
|||||||
if extraAddBindDataContext == nil {
|
if extraAddBindDataContext == nil {
|
||||||
response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
|
response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
|
||||||
} else {
|
} else {
|
||||||
(&web.Rag{}).ChatSSE(extraAddBindDataContext)
|
(&web.Rag{}).ChatWebSocket(extraAddBindDataContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,8 @@ func ChatKnoledgeRAG(doc, query string, ch chan<- string) error {
|
|||||||
message = strings.Replace(message, "{context}", doc, -1)
|
message = strings.Replace(message, "{context}", doc, -1)
|
||||||
|
|
||||||
// 调用聊天接口
|
// 调用聊天接口
|
||||||
err := glm.ChatStream(message, ch)
|
// err := glm.ChatStream(message, ch)
|
||||||
|
err := glm.BufferedChatStream(message, ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("调用聊天接口失败: %w", err)
|
return fmt.Errorf("调用聊天接口失败: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -4,9 +4,10 @@ import (
|
|||||||
"catface/app/global/variable"
|
"catface/app/global/variable"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/yankeguo/zhipu"
|
"github.com/yankeguo/zhipu"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatWithGLM 封装了与GLM模型进行对话的逻辑
|
// ChatWithGLM 封装了与GLM模型进行对话的逻辑
|
||||||
@ -46,3 +47,42 @@ func ChatStream(message string, ch chan<- string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 带缓冲机制的 ChatStream;计数 & 计时 双判定。
|
||||||
|
func BufferedChatStream(message string, ch chan<- string) error {
|
||||||
|
bufferedCh := make(chan string) // 带缓冲的通道,缓冲大小为10
|
||||||
|
timer := time.NewTimer(500 * time.Millisecond) // 定时器,500毫秒
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := ChatStream(message, bufferedCh)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(bufferedCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var buffer strings.Builder
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c, ok := <-bufferedCh:
|
||||||
|
if !ok {
|
||||||
|
if buffer.Len() > 0 {
|
||||||
|
ch <- buffer.String()
|
||||||
|
}
|
||||||
|
return nil // 依靠这里停止函数。
|
||||||
|
}
|
||||||
|
buffer.WriteString(c)
|
||||||
|
if buffer.Len() >= 10 {
|
||||||
|
ch <- buffer.String()
|
||||||
|
buffer.Reset()
|
||||||
|
timer.Reset(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
case <-timer.C:
|
||||||
|
if buffer.Len() > 0 {
|
||||||
|
ch <- buffer.String()
|
||||||
|
buffer.Reset()
|
||||||
|
}
|
||||||
|
timer.Reset(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -44,7 +44,7 @@ Logs:
|
|||||||
Compress: false #日志备份时,是否进行压缩
|
Compress: false #日志备份时,是否进行压缩
|
||||||
|
|
||||||
Websocket: #该服务与Http具有相同的ip、端口,因此不需要额外设置端口
|
Websocket: #该服务与Http具有相同的ip、端口,因此不需要额外设置端口
|
||||||
Start: 0 #默认不启动该服务(1=启动;0=不启动)
|
Start: 0 # 默认不启动该服务(1=启动;0=不启动)
|
||||||
WriteReadBufferSize: 20480 # 读写缓冲区分配字节,大概能存储 6800 多一点的文字
|
WriteReadBufferSize: 20480 # 读写缓冲区分配字节,大概能存储 6800 多一点的文字
|
||||||
MaxMessageSize: 65535 # 从消息管道读取消息的最大字节
|
MaxMessageSize: 65535 # 从消息管道读取消息的最大字节
|
||||||
PingPeriod: 20 #心跳包频率,单位:秒
|
PingPeriod: 20 #心跳包频率,单位:秒
|
||||||
|
Loading…
x
Reference in New Issue
Block a user