From 81cd287109d0a65841ec8df2dfcc94c595b7bd75 Mon Sep 17 00:00:00 2001 From: Havoc412 <2993167370@qq.com> Date: Wed, 20 Nov 2024 17:32:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E6=96=B0=E5=A2=9E=20RAG=20?= =?UTF-8?q?=E8=81=8A=E5=A4=A9=E6=A8=A1=E5=BC=8F=E5=92=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20ES=20=E6=9F=A5=E8=AF=A2=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 RAG 聊天模式常量和前端字段设定 - 修改 Encounters Create 方法中的 ES 同步逻辑 - 更新 Rag ChatSSE 和 ChatWebSocket 方法,支持新的聊天模式 - 重构 NlpWebSocketResult 创建函数,使用新增的常量 - 新增 Encounter 的 TopK 方法,用于 ES 向量搜索 - 更新 DocResult 结构,实现 DocInterface 接口 - 修改 prompts.yml,增加 Diary 模式的提示模板 --- app/global/consts/api_mode.go | 7 ++ .../controller/web/encounter_controller.go | 2 +- app/http/controller/web/rag_controller.go | 17 ++-- app/model/rag_websocket_result.go | 7 +- app/model_es/encounter.go | 95 ++++++++++++++----- app/model_res/base_model.go | 13 +++ app/model_res/doc.go | 19 +++- app/model_res/encounter.go | 29 ++++++ app/service/rag/curd/base_curd.go | 37 ++++++++ app/service/rag/curd/encounter_curd.go | 42 ++++++++ config/prompts.yml | 9 +- 11 files changed, 238 insertions(+), 39 deletions(-) create mode 100644 app/model_res/base_model.go create mode 100644 app/model_res/encounter.go create mode 100644 app/service/rag/curd/base_curd.go create mode 100644 app/service/rag/curd/encounter_curd.go diff --git a/app/global/consts/api_mode.go b/app/global/consts/api_mode.go index a68d308..d1a023b 100644 --- a/app/global/consts/api_mode.go +++ b/app/global/consts/api_mode.go @@ -9,3 +9,10 @@ const ( RagChatModeDiary string = "Diary" // 查询路遇资料等 RagChatModeDetect string = "Detect" // 辅助 catface 的辨认功能; ) + +// 前端的字段设定 +const ( + AiMessageTypeText string = "text" + AiMessageTypeDoc string = "doc" + AiMessageTypeToken string = "token" +) diff --git a/app/http/controller/web/encounter_controller.go b/app/http/controller/web/encounter_controller.go index 1b1928c..c5e763c 100644 --- 
a/app/http/controller/web/encounter_controller.go +++ b/app/http/controller/web/encounter_controller.go @@ -74,7 +74,7 @@ func (e *Encounters) Create(context *gin.Context) { go model.CreateEncounterAnimalLinkFactory("").Insert(encounter.Id, animals_id) // 3. ES speed // TODO 这里如何实现 不同 DB 之间的 “事务” 概念。 - if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 1 { + if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 0 { // TEST 暂时全部数据都同步到 ES,不做 level 过滤。 go model_es.CreateEncounterESFactory(&encounter).InsertDocument() } diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go index 3bb7f8d..b683da1 100644 --- a/app/http/controller/web/rag_controller.go +++ b/app/http/controller/web/rag_controller.go @@ -71,6 +71,11 @@ func (r *Rag) ChatSSE(context *gin.Context) { query := context.Query("query") token := context.Query("token") + mode := context.Query("mode") + if mode == "" { + mode = consts.RagChatModeKnowledge + } + // 0-1. 测试 python if !micro_service.TestLinkPythonService() { code := errcode.ErrPythonService @@ -98,7 +103,7 @@ func (r *Rag) ChatSSE(context *gin.Context) { } // 2. ES TopK - docs, err := model_es.CreateDocESFactory().TopK(embedding, 1) + docs, err := model_es.CreateDocESFactory().TopK(embedding, 2) if err != nil || len(docs) == 0 { variable.ZapLog.Error("ES TopK error", zap.Error(err)) @@ -112,7 +117,7 @@ func (r *Rag) ChatSSE(context *gin.Context) { // 3. LLM answer go func() { - err := nlp.ChatRAG(docs[0].Content, query, ch, client) + err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client) if err != nil { variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) } @@ -199,7 +204,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { } // 2. 
ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。 - docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1) + docs, err := curd.TopK(mode, embedding, 1) if err != nil || len(docs) == 0 { variable.ZapLog.Error("ES TopK error", zap.Error(err)) @@ -214,14 +219,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { // STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。 // 0. 传递参考资料的信息 - docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs) + docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言 err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err)) } // 1. 传递 token 信息; // UPDATE 临时方案 - tokenMsg := model.CreateNlpWebSocketResult("token", token) + tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token) err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err)) @@ -234,7 +239,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { ch := make(chan string) // TIP 建立通道。 go func() { - err := nlp.ChatRAG(docs[0].Content, query, mode, ch, clientInfo.Client) + err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口 if err != nil { variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) } diff --git a/app/model/rag_websocket_result.go b/app/model/rag_websocket_result.go index cb88984..f7d86dd 100644 --- a/app/model/rag_websocket_result.go +++ b/app/model/rag_websocket_result.go @@ -1,10 +1,13 @@ package model -import "encoding/json" +import ( + "catface/app/global/consts" + "encoding/json" +) func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult { if t == "" { - t = "chat" + t = consts.AiMessageTypeText } return &NlpWebSocketResult{ diff --git a/app/model_es/encounter.go b/app/model_es/encounter.go 
index d5d1187..d401752 100644 --- a/app/model_es/encounter.go +++ b/app/model_es/encounter.go @@ -142,32 +142,32 @@ func (e *Encounter) UpdateDocument(client *elasticsearch.Client, encounter *Enco */ func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) { body := fmt.Sprintf(`{ - "size": %d, - "query": { - "bool": { - "should": [ - {"match": {"tags": "%s"}}, - {"match": {"content": "%s"}}, - {"match": {"title": "%s"}} - ] - } - }, - "highlight": { - "pre_tags": ["%v"], - "post_tags": ["%v"], - "fields": { - "title": {}, - "content": { - "fragment_size" : 15 - }, - "tags": { - "pre_tags": [""], - "post_tags": [""] - } - } - }, - "_source": ["id", "title", "content", "tags"] -}`, num, query, query, query, consts.PreTags, consts.PostTags) + "size": %d, + "query": { + "bool": { + "should": [ + {"match": {"tags": "%s"}}, + {"match": {"content": "%s"}}, + {"match": {"title": "%s"}} + ] + } + }, + "highlight": { + "pre_tags": ["%v"], + "post_tags": ["%v"], + "fields": { + "title": {}, + "content": { + "fragment_size" : 15 + }, + "tags": { + "pre_tags": [""], + "post_tags": [""] + } + } + }, + "_source": ["id", "title", "content", "tags"] + }`, num, query, query, query, consts.PreTags, consts.PostTags) hits, err := model_handler.SearchRequest(body, e.IndexName()) if err != nil { @@ -188,3 +188,46 @@ func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, return encounters, nil } + +func (e *Encounter) TopK(embedding []float64, k int) ([]Encounter, error) { + // 同理 Doc + params := map[string]interface{}{ + "query_vector": embedding, + } + paramsJSON, err := json.Marshal(params) + if err != nil { + return nil, err + } + + body := fmt.Sprintf(`{ + "size": %d, + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", + "params": %s + } + } + }, + "_source":["id"] + }`, k, string(paramsJSON)) + + hits, err := 
model_handler.SearchRequest(body, e.IndexName()) + if err != nil { + return nil, err + } + + var encounters []Encounter + for _, hit := range hits { + hitMap := hit.(map[string]interface{}) + source := hitMap["_source"].(map[string]interface{}) + var encounter Encounter + if err := data_bind.ShouldBindFormMapToModel(source, &encounter); err != nil { + continue + } + + encounters = append(encounters, encounter) + } + return encounters, nil +} diff --git a/app/model_res/base_model.go b/app/model_res/base_model.go new file mode 100644 index 0000000..69bdf3b --- /dev/null +++ b/app/model_res/base_model.go @@ -0,0 +1,13 @@ +package model_res + +type DocInterface interface { + ToString() string +} + +type DocBase struct { + Type string `json:"type"` +} + +func (d DocBase) ToString() string { + return "" +} diff --git a/app/model_res/doc.go b/app/model_res/doc.go index 00845f9..554db73 100644 --- a/app/model_res/doc.go +++ b/app/model_res/doc.go @@ -6,10 +6,10 @@ import ( "time" ) -// BUG 存在 依賴循環 +// INFO 由于直接放到 model 中会导致循环引用,所以放到 model_res 中 func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult { return &DocResult{ - Type: "doc", + DocBase: DocBase{Type: "doc"}, Id: doc.Id, Name: doc.Name, Content: doc_es.Content, @@ -18,9 +18,22 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult { } type DocResult struct { - Type string `json:"type"` + DocBase Id int64 `json:"id"` Name string `json:"name"` Content string `json:"content"` UpdatedAt *time.Time `json:"updated_at"` } + +// GetType implements DocInterface. 
+func (d DocResult) GetType() string { + return d.Type // Type is promoted from the embedded DocBase; replaces the generated panic("unimplemented") stub +} + +/** + * @description: Implements DocInterface; the returned text is used as reference content for the LLM. + * @return {*} + */ +func (d DocResult) ToString() string { + return d.Content +} diff --git a/app/model_res/encounter.go b/app/model_res/encounter.go new file mode 100644 index 0000000..1929fa4 --- /dev/null +++ b/app/model_res/encounter.go @@ -0,0 +1,29 @@ +package model_res + +import ( + "catface/app/model" + "catface/app/model_es" + "fmt" + "time" +) + +func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encounter) *EncounterResult { + return &EncounterResult{ + DocBase: DocBase{Type: "encounter"}, + Id: encounter.Id, + Title: encounter.Title, + Content: encounter.Content, + UpdatedAt: encounter.UpdatedAt} +} + +type EncounterResult struct { + DocBase + Id int64 `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + UpdatedAt *time.Time `json:"updated_at"` +} + +func (e EncounterResult) ToString() string { + return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;`, e.Title, e.Content) +} diff --git a/app/service/rag/curd/base_curd.go b/app/service/rag/curd/base_curd.go new file mode 100644 index 0000000..a5765a2 --- /dev/null +++ b/app/service/rag/curd/base_curd.go @@ -0,0 +1,37 @@ +package curd + +import ( + "catface/app/global/consts" + "catface/app/model_res" + "fmt" +) + +func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) { + switch mode { + case consts.RagChatModeKnowledge: + results, err := CreateDocCurdFactory().TopK(embedding, k) + if err != nil { + return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", err) + } + for _, result := range results { + temp = append(temp, result) + } + + case consts.RagChatModeDiary: + results, err := CreateEncounterCurdFactory().TopK(embedding, k) + if err != nil { + return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", err) + } + for _, result := range results { + temp = append(temp, result) + } + + default: + if mode == "" { + err =
fmt.Errorf("TopK: mode不能为空") + } else { + err = fmt.Errorf("TopK: 不支持的mode: %s", mode) + } + } + return temp, err +} diff --git a/app/service/rag/curd/encounter_curd.go b/app/service/rag/curd/encounter_curd.go new file mode 100644 index 0000000..dede0a8 --- /dev/null +++ b/app/service/rag/curd/encounter_curd.go @@ -0,0 +1,42 @@ +package curd + +import ( + "catface/app/model" + "catface/app/model_es" + "catface/app/model_res" +) + +func CreateEncounterCurdFactory() *EncounterCurd { + return &EncounterCurd{ + enc: model.CreateEncounterFactory(""), + enc_es: model_es.CreateEncounterESFactory(nil), + } +} + +type EncounterCurd struct { + enc *model.Encounter + enc_es *model_es.Encounter +} + +func (e *EncounterCurd) TopK(embedding []float64, k int) (temp []model_res.EncounterResult, err error) { + // ES: TopK + encounters_es, err := e.enc_es.TopK(embedding, k) + if err != nil { + return + } + + // MySQL: fetch the supplementary fields for the ES hits + var ids []int64 + for _, encounter := range encounters_es { + ids = append(ids, encounter.Id) + } + encounters := e.enc.ShowByIDs(ids, "id", "title", "content", "updated_at") + for _, encounter_es := range encounters_es { // outer loop follows the ES hits so results keep similarity rank order + for _, encounter := range encounters { + if encounter.Id == encounter_es.Id { + temp = append(temp, *model_res.NewEncounterResult(&encounter, &encounter_es)) + } + } + } + return +} diff --git a/config/prompts.yml b/config/prompts.yml index c32116e..4cbd1c5 100644 --- a/config/prompts.yml +++ b/config/prompts.yml @@ -22,4 +22,11 @@ Prompt: ··· {context} ··· - 如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。" \ No newline at end of file + 如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。" + Diary: "使用以知识库中找到的猫猫路遇日记来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。 + 问题: {question} + 可参考的路遇日记: + ··· + {context} + ··· + 如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。" \ No newline at end of file