refactor(app): 重构 WebSocket 聊天流程并优化文档查询功能
- 调整了 ES TopK 查询逻辑,增加了文档类型筛选 - 优化了 WebSocket 的关闭流程,增加了文档和 token 信息的发送 - 新增了 Doc 模型的 ShowById 和 ShowByIds 方法,用于查询特定文档 - 更新了 prompts.yml 文件中的提示语,将"上下文"改为"知识库"
This commit is contained in:
parent
3b719c3add
commit
91073fdf7f
@ -6,6 +6,7 @@ import (
|
|||||||
"catface/app/model"
|
"catface/app/model"
|
||||||
"catface/app/model_es"
|
"catface/app/model_es"
|
||||||
"catface/app/service/nlp"
|
"catface/app/service/nlp"
|
||||||
|
"catface/app/service/rag/curd"
|
||||||
"catface/app/utils/llm_factory"
|
"catface/app/utils/llm_factory"
|
||||||
"catface/app/utils/micro_service"
|
"catface/app/utils/micro_service"
|
||||||
"catface/app/utils/response"
|
"catface/app/utils/response"
|
||||||
@ -138,15 +139,6 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
|
||||||
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
|
||||||
|
|
||||||
err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
|
||||||
if err != nil {
|
|
||||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
|
||||||
}
|
|
||||||
ws.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 0-2. 测试 Python 微服务是否启动
|
// 0-2. 测试 Python 微服务是否启动
|
||||||
if !micro_service.TestLinkPythonService() {
|
if !micro_service.TestLinkPythonService() {
|
||||||
@ -181,8 +173,8 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. ES TopK
|
// 2. ES TopK // TODO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
||||||
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
|
||||||
if err != nil || len(docs) == 0 {
|
if err != nil || len(docs) == 0 {
|
||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
@ -194,6 +186,24 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
||||||
|
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||||
|
// 0. 传递参考资料的信息
|
||||||
|
docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
|
||||||
|
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
||||||
|
if err != nil {
|
||||||
|
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 传递 token 信息; // UPDATE 临时方案
|
||||||
|
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
||||||
|
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||||
|
if err != nil {
|
||||||
|
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||||
|
}
|
||||||
|
ws.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
// 3.
|
// 3.
|
||||||
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
|
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
|
||||||
ch := make(chan string) // TIP 建立通道。
|
ch := make(chan string) // TIP 建立通道。
|
||||||
|
@ -35,3 +35,30 @@ func (d *Doc) InsertDocumentData(c *gin.Context) (int64, bool) {
|
|||||||
}
|
}
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Doc) ShowById(id int64, attrs ...string) *Doc {
|
||||||
|
var temp Doc
|
||||||
|
db := d.DB.Table(d.TableName()).Where("id = ?", id)
|
||||||
|
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
db.Select(attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.First(&temp)
|
||||||
|
if err.Error != nil {
|
||||||
|
variable.ZapLog.Error("Doc ShowById Error", zap.Error(err.Error))
|
||||||
|
}
|
||||||
|
return &temp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Doc) ShowByIds(ids []int64, attrs ...string) (temp []Doc) {
|
||||||
|
db := d.DB.Table(d.TableName()).Where("id in (?)", ids)
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
db.Select(attrs)
|
||||||
|
}
|
||||||
|
err := db.Find(&temp)
|
||||||
|
if err.Error != nil {
|
||||||
|
variable.ZapLog.Error("Doc ShowByIds Error", zap.Error(err.Error))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
23
app/model_res/doc.go
Normal file
23
app/model_res/doc.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package model_res
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/model"
|
||||||
|
"catface/app/model_es"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BUG 存在 依賴循環
|
||||||
|
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
||||||
|
return &DocResult{
|
||||||
|
Type: "doc",
|
||||||
|
Id: doc.Id,
|
||||||
|
Name: doc.Name,
|
||||||
|
Content: doc_es.Content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type DocResult struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Id int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
44
app/service/rag/curd/doc_curd.go
Normal file
44
app/service/rag/curd/doc_curd.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package curd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/model"
|
||||||
|
"catface/app/model_es"
|
||||||
|
"catface/app/model_res"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateDocCurdFactory() *DocCurd {
|
||||||
|
return &DocCurd{
|
||||||
|
doc: model.CreateDocFactory(""),
|
||||||
|
doc_es: model_es.CreateDocESFactory()}
|
||||||
|
}
|
||||||
|
|
||||||
|
type DocCurd struct { // 组合方法的使用
|
||||||
|
doc *model.Doc
|
||||||
|
doc_es *model_es.Doc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DocCurd) TopK(embedding []float64, k int) (temp []model_res.DocResult, err error) {
|
||||||
|
// ES:TopK
|
||||||
|
docs_es, err := d.doc_es.TopK(embedding, k)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MySQL:补充基本信息
|
||||||
|
var ids []int64
|
||||||
|
for _, doc := range docs_es {
|
||||||
|
ids = append(ids, doc.Id)
|
||||||
|
}
|
||||||
|
docs := d.doc.ShowByIds(ids, "id", "name")
|
||||||
|
|
||||||
|
// 装载
|
||||||
|
for _, doc := range docs {
|
||||||
|
for _, doc_es := range docs_es {
|
||||||
|
if doc.Id == doc_es.Id {
|
||||||
|
temp = append(temp, *model_res.NewDocResult(&doc, &doc_es))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -12,10 +12,10 @@ Prompt:
|
|||||||
|
|
||||||
Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:"
|
Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:"
|
||||||
|
|
||||||
KnoledgeRAG: "使用以上下文来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
|
KnoledgeRAG: "使用以知识库来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
|
||||||
问题: {question}
|
问题: {question}
|
||||||
可参考的上下文:
|
可参考的知识库:
|
||||||
···
|
···
|
||||||
{context}
|
{context}
|
||||||
···
|
···
|
||||||
如果给定的上下文无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
Loading…
x
Reference in New Issue
Block a user