Havoc412 ae7edb5e8d 🎨 refactor(rag): 重构 RAG 模型相关代码
- 重构了 rag_controller.go 中的逻辑,使用新的 DocumentHub 结构
- 修改了 encounter.go 中的 Encounter 结构,增加了 explain 标签
- 重写了 rag_websocket.go 中的逻辑,使用新的 DocumentHub 结构
- 新增了 curd_es/encounter_es_curd.go 文件,实现了 Encounter 的 CURD 操作
- 更新了 nlp/func.go 中的 ChatRAG 函数,使用新的 DocumentHub 结构
- 新增了 curd/docs_hub.go 文件,实现了 DocumentHub 的 TopK 方法
- 新增了 utils/data_explain/data_explain_rag.go 文件,实现了结构体到解释字符串的转换
2024-11-20 19:30:11 +08:00

62 lines
1.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package curd
import (
"catface/app/global/consts"
"catface/app/utils/data_explain"
"fmt"
"strconv"
)
/**
* @description: 作为搜索到的文档的集合,目前都是单一的文档类型;// TODO 如何更好的支持多文档的 TopK
* 相当于 DocumentHub 的构造函数。
* @param {string} mode
* @param {[]float64} embedding
* @param {int} k
* @return {*}
*/
func TopK(mode string, embedding []float64, k int) (dochub DocumentHub, err error) {
switch mode {
case consts.RagChatModeKnowledge:
results, errTemp := CreateDocCurdFactory().TopK(embedding, k)
if errTemp != nil {
err = fmt.Errorf("TopK: 获取知识库TopK失败: %w", errTemp)
}
for _, result := range results {
dochub.Docs = append(dochub.Docs, result)
}
case consts.RagChatModeDiary:
results, errTemp := CreateEncounterCurdFactory().TopK(embedding, k)
if errTemp != nil {
err = fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", errTemp)
}
for _, result := range results {
dochub.Docs = append(dochub.Docs, result)
}
default:
if mode == "" {
err = fmt.Errorf("TopK: mode不能为空")
} else {
err = fmt.Errorf("TopK: 不支持的mode: %s", mode)
}
}
return
}
type DocumentHub struct {
Docs []interface{}
}
func (d *DocumentHub) Length() int {
return len(d.Docs)
}
func (d *DocumentHub) Explain4LLM() (explain string) {
for id, doc := range d.Docs {
explain += strconv.Itoa(id) + "." + data_explain.GenerateExplainStringForEmbedding(doc) + "\n"
}
return
}