refactor(app): 重构 WebSocket 聊天流程并优化文档查询功能

- 调整了 ES TopK 查询逻辑,增加了文档类型筛选
- 优化了 WebSocket 的关闭流程,增加了文档和 token 信息的发送
- 新增了 Doc 模型的 ShowById 和 ShowByIds 方法,用于查询特定文档
- 更新了 prompts.yml 文件中的提示语,将"上下文"改为"知识库"
This commit is contained in:
Havoc412 2024-11-19 13:06:39 +08:00
parent 3b719c3add
commit 91073fdf7f
5 changed files with 118 additions and 14 deletions

View File

@ -6,6 +6,7 @@ import (
"catface/app/model"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/service/rag/curd"
"catface/app/utils/llm_factory"
"catface/app/utils/micro_service"
"catface/app/utils/response"
@ -138,15 +139,6 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
return
}
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
tokenMsg := model.CreateNlpWebSocketResult("token", token)
err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
}
ws.Close()
}()
// 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() {
@ -181,8 +173,8 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
// 2. ES TopK // TODO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
@ -194,6 +186,24 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
return
}
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
// 0. 传递参考资料的信息
docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
}
// 1. 传递 token 信息; // UPDATE 临时方案
tokenMsg := model.CreateNlpWebSocketResult("token", token)
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
}
ws.Close()
}()
// 3.
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
ch := make(chan string) // TIP 建立通道。

View File

@ -35,3 +35,30 @@ func (d *Doc) InsertDocumentData(c *gin.Context) (int64, bool) {
}
return 0, false
}
func (d *Doc) ShowById(id int64, attrs ...string) *Doc {
var temp Doc
db := d.DB.Table(d.TableName()).Where("id = ?", id)
if len(attrs) > 0 {
db.Select(attrs)
}
err := db.First(&temp)
if err.Error != nil {
variable.ZapLog.Error("Doc ShowById Error", zap.Error(err.Error))
}
return &temp
}
func (d *Doc) ShowByIds(ids []int64, attrs ...string) (temp []Doc) {
db := d.DB.Table(d.TableName()).Where("id in (?)", ids)
if len(attrs) > 0 {
db.Select(attrs)
}
err := db.Find(&temp)
if err.Error != nil {
variable.ZapLog.Error("Doc ShowByIds Error", zap.Error(err.Error))
}
return
}

23
app/model_res/doc.go Normal file
View File

@ -0,0 +1,23 @@
package model_res
import (
"catface/app/model"
"catface/app/model_es"
)
// BUG 存在 依賴循環
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
return &DocResult{
Type: "doc",
Id: doc.Id,
Name: doc.Name,
Content: doc_es.Content,
}
}
type DocResult struct {
Type string `json:"type"`
Id int64 `json:"id"`
Name string `json:"name"`
Content string `json:"content"`
}

View File

@ -0,0 +1,44 @@
package curd
import (
"catface/app/model"
"catface/app/model_es"
"catface/app/model_res"
)
func CreateDocCurdFactory() *DocCurd {
return &DocCurd{
doc: model.CreateDocFactory(""),
doc_es: model_es.CreateDocESFactory()}
}
type DocCurd struct { // 组合方法的使用
doc *model.Doc
doc_es *model_es.Doc
}
func (d *DocCurd) TopK(embedding []float64, k int) (temp []model_res.DocResult, err error) {
// ESTopK
docs_es, err := d.doc_es.TopK(embedding, k)
if err != nil {
return
}
// MySQL补充基本信息
var ids []int64
for _, doc := range docs_es {
ids = append(ids, doc.Id)
}
docs := d.doc.ShowByIds(ids, "id", "name")
// 装载
for _, doc := range docs {
for _, doc_es := range docs_es {
if doc.Id == doc_es.Id {
temp = append(temp, *model_res.NewDocResult(&doc, &doc_es))
}
}
}
return
}

View File

@ -12,10 +12,10 @@ Prompt:
Title: "请根据以下长文本生成一个合适的标题不需要书名号长度10字内"
KnoledgeRAG: "使用以上下文来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
KnoledgeRAG: "使用以知识库来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
问题: {question}
可参考的上下文
可参考的知识库
···
{context}
···
如果给定的上下文无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"