refactor(app): 重构 WebSocket 聊天流程并优化文档查询功能
- 调整了 ES TopK 查询逻辑,增加了文档类型筛选 - 优化了 WebSocket 的关闭流程,增加了文档和 token 信息的发送 - 新增了 Doc 模型的 ShowById 和 ShowByIds 方法,用于查询特定文档 - 更新了 prompts.yml 文件中的提示语,将"上下文"改为"知识库"
This commit is contained in:
parent
3b719c3add
commit
91073fdf7f
@ -6,6 +6,7 @@ import (
|
||||
"catface/app/model"
|
||||
"catface/app/model_es"
|
||||
"catface/app/service/nlp"
|
||||
"catface/app/service/rag/curd"
|
||||
"catface/app/utils/llm_factory"
|
||||
"catface/app/utils/micro_service"
|
||||
"catface/app/utils/response"
|
||||
@ -138,15 +139,6 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
response.Fail(context, errcode.ErrWebsocketUpgradeFail, errcode.ErrMsg[errcode.ErrWebsocketUpgradeFail], "")
|
||||
return
|
||||
}
|
||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
||||
|
||||
err := ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||
}
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
// 0-2. 测试 Python 微服务是否启动
|
||||
if !micro_service.TestLinkPythonService() {
|
||||
@ -181,8 +173,8 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. ES TopK
|
||||
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
||||
// 2. ES TopK // TODO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
||||
docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
|
||||
if err != nil || len(docs) == 0 {
|
||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||
|
||||
@ -194,6 +186,24 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||
// 0. 传递参考资料的信息
|
||||
docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
|
||||
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
||||
}
|
||||
|
||||
// 1. 传递 token 信息; // UPDATE 临时方案
|
||||
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
||||
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||
}
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
// 3.
|
||||
closeEventFromVue := context.Request.Context().Done() // 接收前端传来的中断信号。
|
||||
ch := make(chan string) // TIP 建立通道。
|
||||
|
@ -35,3 +35,30 @@ func (d *Doc) InsertDocumentData(c *gin.Context) (int64, bool) {
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (d *Doc) ShowById(id int64, attrs ...string) *Doc {
|
||||
var temp Doc
|
||||
db := d.DB.Table(d.TableName()).Where("id = ?", id)
|
||||
|
||||
if len(attrs) > 0 {
|
||||
db.Select(attrs)
|
||||
}
|
||||
|
||||
err := db.First(&temp)
|
||||
if err.Error != nil {
|
||||
variable.ZapLog.Error("Doc ShowById Error", zap.Error(err.Error))
|
||||
}
|
||||
return &temp
|
||||
}
|
||||
|
||||
func (d *Doc) ShowByIds(ids []int64, attrs ...string) (temp []Doc) {
|
||||
db := d.DB.Table(d.TableName()).Where("id in (?)", ids)
|
||||
if len(attrs) > 0 {
|
||||
db.Select(attrs)
|
||||
}
|
||||
err := db.Find(&temp)
|
||||
if err.Error != nil {
|
||||
variable.ZapLog.Error("Doc ShowByIds Error", zap.Error(err.Error))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
23
app/model_res/doc.go
Normal file
23
app/model_res/doc.go
Normal file
@ -0,0 +1,23 @@
|
||||
package model_res
|
||||
|
||||
import (
|
||||
"catface/app/model"
|
||||
"catface/app/model_es"
|
||||
)
|
||||
|
||||
// BUG 存在 依賴循環
|
||||
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
||||
return &DocResult{
|
||||
Type: "doc",
|
||||
Id: doc.Id,
|
||||
Name: doc.Name,
|
||||
Content: doc_es.Content,
|
||||
}
|
||||
}
|
||||
|
||||
type DocResult struct {
|
||||
Type string `json:"type"`
|
||||
Id int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
}
|
44
app/service/rag/curd/doc_curd.go
Normal file
44
app/service/rag/curd/doc_curd.go
Normal file
@ -0,0 +1,44 @@
|
||||
package curd
|
||||
|
||||
import (
|
||||
"catface/app/model"
|
||||
"catface/app/model_es"
|
||||
"catface/app/model_res"
|
||||
)
|
||||
|
||||
func CreateDocCurdFactory() *DocCurd {
|
||||
return &DocCurd{
|
||||
doc: model.CreateDocFactory(""),
|
||||
doc_es: model_es.CreateDocESFactory()}
|
||||
}
|
||||
|
||||
type DocCurd struct { // 组合方法的使用
|
||||
doc *model.Doc
|
||||
doc_es *model_es.Doc
|
||||
}
|
||||
|
||||
func (d *DocCurd) TopK(embedding []float64, k int) (temp []model_res.DocResult, err error) {
|
||||
// ES:TopK
|
||||
docs_es, err := d.doc_es.TopK(embedding, k)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// MySQL:补充基本信息
|
||||
var ids []int64
|
||||
for _, doc := range docs_es {
|
||||
ids = append(ids, doc.Id)
|
||||
}
|
||||
docs := d.doc.ShowByIds(ids, "id", "name")
|
||||
|
||||
// 装载
|
||||
for _, doc := range docs {
|
||||
for _, doc_es := range docs_es {
|
||||
if doc.Id == doc_es.Id {
|
||||
temp = append(temp, *model_res.NewDocResult(&doc, &doc_es))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -12,10 +12,10 @@ Prompt:
|
||||
|
||||
Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:"
|
||||
|
||||
KnoledgeRAG: "使用以上下文来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
|
||||
KnoledgeRAG: "使用以知识库来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
|
||||
问题: {question}
|
||||
可参考的上下文:
|
||||
可参考的知识库:
|
||||
···
|
||||
{context}
|
||||
···
|
||||
如果给定的上下文无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
||||
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
Loading…
x
Reference in New Issue
Block a user