diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go index b683da1..48f9d35 100644 --- a/app/http/controller/web/rag_controller.go +++ b/app/http/controller/web/rag_controller.go @@ -4,8 +4,7 @@ import ( "catface/app/global/consts" "catface/app/global/errcode" "catface/app/global/variable" - "catface/app/model" - "catface/app/model_es" + "catface/app/model_res" "catface/app/service/nlp" "catface/app/service/rag/curd" "catface/app/utils/llm_factory" @@ -103,8 +102,8 @@ func (r *Rag) ChatSSE(context *gin.Context) { } // 2. ES TopK - docs, err := model_es.CreateDocESFactory().TopK(embedding, 2) - if err != nil || len(docs) == 0 { + dochub, err := curd.TopK(mode, embedding, 1) + if err != nil || dochub.Length() == 0 { variable.ZapLog.Error("ES TopK error", zap.Error(err)) code := errcode.ErrNoDocFound @@ -117,7 +116,7 @@ func (r *Rag) ChatSSE(context *gin.Context) { // 3. LLM answer go func() { - err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client) + err := nlp.ChatRAG(query, mode, dochub, ch, client) if err != nil { variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) } @@ -172,7 +171,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { // 0-2. 测试 Python 微服务是否启动 if !micro_service.TestLinkPythonService() { code := errcode.ErrPythonServierDown - err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) + err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -183,7 +182,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub) if ercode != 0 { variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err)) - err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal()) + err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -196,7 +195,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys) if !ok { code := errcode.ErrPythonServierDown - err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) + err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -204,12 +203,12 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { } // 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。 - docs, err := curd.TopK(mode, embedding, 1) - if err != nil || len(docs) == 0 { + dochub, err := curd.TopK(mode, embedding, 1) + if err != nil || dochub.Length() == 0 { variable.ZapLog.Error("ES TopK error", zap.Error(err)) code := errcode.ErrNoDocFound - err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) + err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err)) } @@ -219,14 +218,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { // STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。 // 0. 传递参考资料的信息 - docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言 + docMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, dochub.Docs) // TIP 断言 err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err)) } // 1. 传递 token 信息; // UPDATE 临时方案 - tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token) + tokenMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token) err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal()) if err != nil { variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err)) @@ -239,7 +238,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { ch := make(chan string) // TIP 建立通道。 go func() { - err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口 + err := nlp.ChatRAG(query, mode, dochub, ch, clientInfo.Client) // TIP 接口 if err != nil { variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err)) } @@ -253,7 +252,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { return } // variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c)) - err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal()) + err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", c).JsonMarshal()) if err != nil { return } @@ -264,5 +263,5 @@ func (r *Rag) ChatWebSocket(context *gin.Context) { } func (r *Rag) HelpDetectCat(context *gin.Context) { - // TODO 也许也可以同样掉上面那个接口了。 + // TODO 也许也可以同样用上面那个接口了。 } diff --git a/app/model_es/encounter.go b/app/model_es/encounter.go index d401752..3edaf74 100644 --- a/app/model_es/encounter.go +++ b/app/model_es/encounter.go @@ -5,7 +5,6 @@ import ( "catface/app/global/consts" "catface/app/global/variable" "catface/app/model" - "catface/app/service/nlp" "catface/app/utils/data_bind" "catface/app/utils/model_handler" "context" @@ -33,13 +32,12 @@ func CreateEncounterESFactory(encounter *model.Encounter) *Encounter { // INFO 存储能够作为索引存在的数据。 type Encounter struct { Id int64 `json:"id"` - Title string `json:"title"` - Content string `json:"content"` - Tags []string `json:"tags"` + Title string `json:"title" explain:"路遇笔记标题"` + Content string `json:"content" explain:"内容"` + Tags []string `json:"tags" explain:"标签"` Embedding []float64 `json:"embedding"` - // TagsHighlight []string `json:"tags_highlight"` // TODO 如何 insert 时忽略,query 时绑定。 - TagsHighlight []string `json:"-" bind:"tags_highlight"` // TODO 如何 insert 时忽略,query 时绑定。 + TagsHighlight []string `json:"-" bind:"tags_highlight"` } func (e *Encounter) IndexName() string { @@ -49,11 +47,6 @@ func (e *Encounter) IndexName() string { func (e *Encounter) InsertDocument() error { ctx := context.Background() - var ok bool - if e.Embedding, ok = nlp.GetEmbedding([]string{e.Title, e.Content}); !ok { - return fmt.Errorf("nlp embedding service error") - } - // 将结构体转换为 JSON 字符串 data, err := json.Marshal(e) if err != nil { diff --git a/app/model_res/doc.go b/app/model_res/doc.go index 554db73..8a44a1d 100644 --- a/app/model_res/doc.go +++ b/app/model_res/doc.go @@ -20,9 +20,9 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult { type DocResult struct { DocBase Id int64 `json:"id"` - Name string `json:"name"` - Content string `json:"content"` - UpdatedAt *time.Time `json:"updated_at"` + Name string `json:"name" explain:"文档名称"` + Content string `json:"content" explain:"文档内容"` + UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"` } // GetType implements DocInterface. diff --git a/app/model_res/encounter.go b/app/model_res/encounter.go index 1929fa4..73cca1d 100644 --- a/app/model_res/encounter.go +++ b/app/model_res/encounter.go @@ -19,11 +19,11 @@ func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encou type EncounterResult struct { DocBase Id int64 `json:"id"` - Title string `json:"title"` - Content string `json:"content"` - UpdatedAt *time.Time `json:"updated_at"` + Title string `json:"title" explain:"路遇笔记标题"` + Content string `json:"content" explain:"内容"` + UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"` } func (e EncounterResult) ToString() string { - return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;`, e.Title, e.Content) + return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, e.Title, e.Content, e.UpdatedAt) } diff --git a/app/model/rag_websocket_result.go b/app/model_res/rag_websocket.go similarity index 95% rename from app/model/rag_websocket_result.go rename to app/model_res/rag_websocket.go index f7d86dd..3b537ef 100644 --- a/app/model/rag_websocket_result.go +++ b/app/model_res/rag_websocket.go @@ -1,4 +1,4 @@ -package model +package model_res import ( "catface/app/global/consts" diff --git a/app/service/encounter/curd_es/encounter_es_curd.go b/app/service/encounter/curd_es/encounter_es_curd.go new file mode 100644 index 0000000..5c37e06 --- /dev/null +++ b/app/service/encounter/curd_es/encounter_es_curd.go @@ -0,0 +1,29 @@ +package curd_es + +import ( + "catface/app/model" + "catface/app/model_es" + "catface/app/service/nlp" + "catface/app/utils/data_explain" + "fmt" +) + +func CreateEncounterESCurdFactory(encounter *model.Encounter) *EncounterESCurd { + return &EncounterESCurd{ + model_es.CreateEncounterESFactory(encounter), + } +} + +type EncounterESCurd struct { + encounter_es *model_es.Encounter +} + +func (e *EncounterESCurd) InsertDocument() error { + var ok bool + explian := data_explain.GenerateExplainStringForEmbedding(e.encounter_es) + if e.encounter_es.Embedding, ok = nlp.GetEmbeddingOneString(explian); !ok { + return fmt.Errorf("nlp embedding service error") + } + + return e.encounter_es.InsertDocument() +} diff --git a/app/service/nlp/embedding.go b/app/service/nlp/embedding.go index fbddac5..eeacc3f 100644 --- a/app/service/nlp/embedding.go +++ b/app/service/nlp/embedding.go @@ -32,3 +32,7 @@ func GetEmbedding(text []string) ([]float64, bool) { return res.Embedding, true } } + +func GetEmbeddingOneString(str string) ([]float64, bool) { + return GetEmbedding([]string{str}) +} diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go index 54bce8b..1a4245b 100644 --- a/app/service/nlp/func.go +++ b/app/service/nlp/func.go @@ -3,6 +3,7 @@ package nlp import ( "catface/app/global/variable" "catface/app/service/nlp/glm" + "catface/app/service/rag/curd" "fmt" "strings" @@ -16,13 +17,13 @@ func GenerateTitle(content string, client *zhipu.ChatCompletionService) string { } // ChatKnoledgeRAG 使用 RAG 模型进行知识问答 -func ChatRAG(doc, query, mode string, ch chan<- string, client *zhipu.ChatCompletionService) error { +func ChatRAG(query, mode string, dochub curd.DocumentHub, ch chan<- string, client *zhipu.ChatCompletionService) error { // 读取配置文件中的 KnoledgeRAG 模板 promptTemplate := variable.PromptsYml.GetString("Prompt.RAG." + mode) // 替换模板中的占位符 message := strings.Replace(promptTemplate, "{question}", query, -1) - message = strings.Replace(message, "{context}", doc, -1) + message = strings.Replace(message, "{context}", dochub.Explain4LLM(), -1) // 调用聊天接口 // err := glm.ChatStream(message, ch) diff --git a/app/service/rag/curd/base_curd.go b/app/service/rag/curd/base_curd.go deleted file mode 100644 index a5765a2..0000000 --- a/app/service/rag/curd/base_curd.go +++ /dev/null @@ -1,37 +0,0 @@ -package curd - -import ( - "catface/app/global/consts" - "catface/app/model_res" - "fmt" -) - -func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) { - switch mode { - case consts.RagChatModeKnowledge: - results, err := CreateDocCurdFactory().TopK(embedding, k) - if err != nil { - return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", err) - } - for _, result := range results { - temp = append(temp, result) - } - - case consts.RagChatModeDiary: - results, err := CreateEncounterCurdFactory().TopK(embedding, k) - if err != nil { - return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", err) - } - for _, result := range results { - temp = append(temp, result) - } - - default: - if mode == "" { - err = fmt.Errorf("TopK: mode不能为空") - } else { - err = fmt.Errorf("TopK: 不支持的mode: %s", mode) - } - } - return temp, err -} diff --git a/app/service/rag/curd/docs_hub.go b/app/service/rag/curd/docs_hub.go new file mode 100644 index 0000000..c999051 --- /dev/null +++ b/app/service/rag/curd/docs_hub.go @@ -0,0 +1,61 @@ +package curd + +import ( + "catface/app/global/consts" + "catface/app/utils/data_explain" + "fmt" + "strconv" +) + +/** + * @description: 作为搜索到的文档的集合,目前都是单一的文档类型;// TODO 如何更好的支持多文档的 TopK ? + * 相当于 DocumentHub 的构造函数。 + * @param {string} mode + * @param {[]float64} embedding + * @param {int} k + * @return {*} + */ +func TopK(mode string, embedding []float64, k int) (dochub DocumentHub, err error) { + switch mode { + case consts.RagChatModeKnowledge: + results, errTemp := CreateDocCurdFactory().TopK(embedding, k) + if errTemp != nil { + err = fmt.Errorf("TopK: 获取知识库TopK失败: %w", errTemp) + } + for _, result := range results { + dochub.Docs = append(dochub.Docs, result) + } + + case consts.RagChatModeDiary: + results, errTemp := CreateEncounterCurdFactory().TopK(embedding, k) + if errTemp != nil { + err = fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", errTemp) + } + for _, result := range results { + dochub.Docs = append(dochub.Docs, result) + } + + default: + if mode == "" { + err = fmt.Errorf("TopK: mode不能为空") + } else { + err = fmt.Errorf("TopK: 不支持的mode: %s", mode) + } + } + return +} + +type DocumentHub struct { + Docs []interface{} +} + +func (d *DocumentHub) Length() int { + return len(d.Docs) +} + +func (d *DocumentHub) Explain4LLM() (explain string) { + for id, doc := range d.Docs { + explain += strconv.Itoa(id) + "." + data_explain.GenerateExplainStringForEmbedding(doc) + "\n" + } + return +} diff --git a/app/utils/data_explain/data_explain_rag.go b/app/utils/data_explain/data_explain_rag.go new file mode 100644 index 0000000..7ec74bf --- /dev/null +++ b/app/utils/data_explain/data_explain_rag.go @@ -0,0 +1,29 @@ +package data_explain + +import ( + "fmt" + "reflect" + "strings" +) + +/** + * @description: 集成 Struct -> Explain for RAG; + * @param {interface{}} v + * @return {*} + */ +func GenerateExplainStringForEmbedding(v interface{}) string { + val := reflect.ValueOf(v) + typ := val.Type() + + var result []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + tag := field.Tag.Get("explain") + if tag == "" { + continue + } + value := val.Field(i).Interface() + result = append(result, fmt.Sprintf("%s:%v", tag, value)) + } + return strings.Join(result, ";") +} diff --git a/test/encounter_test.go b/test/encounter_test.go new file mode 100644 index 0000000..c97eb90 --- /dev/null +++ b/test/encounter_test.go @@ -0,0 +1,44 @@ +package test + +import ( + "fmt" + "testing" + "time" +) + +// 假设的 EncounterResult 结构体 +type EncounterResult struct { + Title string + Content string + UpdatedAt time.Time +} + +// ToString 方法 +func (e EncounterResult) ToString() string { + return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, e.Title, e.Content, e.UpdatedAt) +} + +// 测试 EncounterResult 的 ToString 方法 +func TestEncounterResult_ToString(t *testing.T) { + // 设置一个时间点,用于测试 + testTime := time.Now() + + // 创建一个 EncounterResult 实例 + testResult := EncounterResult{ + Title: "测试笔记", + Content: "这是测试笔记的内容", + UpdatedAt: testTime, + } + + // 调用 ToString 方法 + resultString := testResult.ToString() + + t.Log("resultString:", resultString) + // 构建期望的结果字符串 + expectedString := fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, testResult.Title, testResult.Content, testResult.UpdatedAt) + + // 比较实际结果和期望结果 + if resultString != expectedString { + t.Errorf("ToString() failed, expected %q, got %q", expectedString, resultString) + } +} diff --git a/test/explain_test.go b/test/explain_test.go new file mode 100644 index 0000000..c7e897a --- /dev/null +++ b/test/explain_test.go @@ -0,0 +1,46 @@ +package test + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" +) + +// EncounterResult2 结构体定义 +type EncounterResult2 struct { + Id int64 `json:"id" explain:"路遇笔记ID"` + Title string `json:"title" explain:"路遇笔记标题"` + Content string `json:"content" explain:"路遇笔记内容"` + UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"` +} + +// StructToString 使用反射将结构体的内容组织为字符串 +func StructToString(v interface{}) string { + val := reflect.ValueOf(v) + typ := val.Type() + + var result []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + tag := field.Tag.Get("explain") + value := val.Field(i).Interface() + result = append(result, fmt.Sprintf("%s:%v", tag, value)) + } + return strings.Join(result, ";") +} + +func TestExplain(t *testing.T) { + // 示例数据 + updatedAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC) + encounter := EncounterResult2{ + Id: 1, + Title: "遇见小猫", + Content: "今天在公园遇到了一只可爱的小猫。", + UpdatedAt: &updatedAt, + } + + // 调用 StructToString 函数 + t.Logf("结构体内容:", StructToString(encounter)) +} diff --git a/test/temp/explain.go b/test/temp/explain.go new file mode 100644 index 0000000..af98d65 --- /dev/null +++ b/test/temp/explain.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +// EncounterResult 结构体定义 +type EncounterResult struct { + Id int64 `json:"id"` + Title string `json:"title" explain:"路遇笔记标题"` + Content string `json:"content" explain:"路遇笔记内容"` + UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"` + NoTag string `json:"no_tag"` // 没有 explain 标签的字段 +} + +// StructToString 使用反射将结构体的内容组织为字符串,忽略没有 explain 标签的字段 +func StructToString(v interface{}) string { + val := reflect.ValueOf(v) + typ := val.Type() + + var result []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + tag := field.Tag.Get("explain") + if tag == "" { + continue // 跳过没有 explain 标签的字段 + } + value := val.Field(i).Interface() + result = append(result, fmt.Sprintf("%s:%v", tag, value)) + } + return strings.Join(result, ";") +} + +func main() { + // 示例数据 + updatedAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC) + encounter := EncounterResult{ + Id: 1, + Title: "遇见小猫", + Content: "今天在公园遇到了一只可爱的小猫。", + UpdatedAt: &updatedAt, + NoTag: "这个字段没有 explain 标签", + } + + // 调用 StructToString 函数 + fmt.Println(StructToString(encounter)) +}