🎨 refactor(rag): 重构 RAG 模型相关代码
- 重构了 rag_controller.go 中的逻辑,使用新的 DocumentHub 结构 - 修改了 encounter.go 中的 Encounter 结构,增加了 explain 标签 - 重写了 rag_websocket.go 中的逻辑,使用新的 DocumentHub 结构 - 新增了 curd_es/encounter_es_curd.go 文件,实现了 Encounter 的 CURD 操作 - 更新了 nlp/func.go 中的 ChatRAG 函数,使用新的 DocumentHub 结构 - 新增了 curd/docs_hub.go 文件,实现了 DocumentHub 的 TopK 方法 - 新增了 utils/data_explain/data_explain_rag.go 文件,实现了结构体到解释字符串的转换
This commit is contained in:
parent
81cd287109
commit
ae7edb5e8d
@ -4,8 +4,7 @@ import (
|
|||||||
"catface/app/global/consts"
|
"catface/app/global/consts"
|
||||||
"catface/app/global/errcode"
|
"catface/app/global/errcode"
|
||||||
"catface/app/global/variable"
|
"catface/app/global/variable"
|
||||||
"catface/app/model"
|
"catface/app/model_res"
|
||||||
"catface/app/model_es"
|
|
||||||
"catface/app/service/nlp"
|
"catface/app/service/nlp"
|
||||||
"catface/app/service/rag/curd"
|
"catface/app/service/rag/curd"
|
||||||
"catface/app/utils/llm_factory"
|
"catface/app/utils/llm_factory"
|
||||||
@ -103,8 +102,8 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. ES TopK
|
// 2. ES TopK
|
||||||
docs, err := model_es.CreateDocESFactory().TopK(embedding, 2)
|
dochub, err := curd.TopK(mode, embedding, 1)
|
||||||
if err != nil || len(docs) == 0 {
|
if err != nil || dochub.Length() == 0 {
|
||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
code := errcode.ErrNoDocFound
|
code := errcode.ErrNoDocFound
|
||||||
@ -117,7 +116,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
|
|
||||||
// 3. LLM answer
|
// 3. LLM answer
|
||||||
go func() {
|
go func() {
|
||||||
err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client)
|
err := nlp.ChatRAG(query, mode, dochub, ch, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -172,7 +171,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
// 0-2. 测试 Python 微服务是否启动
|
// 0-2. 测试 Python 微服务是否启动
|
||||||
if !micro_service.TestLinkPythonService() {
|
if !micro_service.TestLinkPythonService() {
|
||||||
code := errcode.ErrPythonServierDown
|
code := errcode.ErrPythonServierDown
|
||||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -183,7 +182,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
|
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
|
||||||
if ercode != 0 {
|
if ercode != 0 {
|
||||||
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
|
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
|
||||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -196,7 +195,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
|
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
|
||||||
if !ok {
|
if !ok {
|
||||||
code := errcode.ErrPythonServierDown
|
code := errcode.ErrPythonServierDown
|
||||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -204,12 +203,12 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
||||||
docs, err := curd.TopK(mode, embedding, 1)
|
dochub, err := curd.TopK(mode, embedding, 1)
|
||||||
if err != nil || len(docs) == 0 {
|
if err != nil || dochub.Length() == 0 {
|
||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
code := errcode.ErrNoDocFound
|
code := errcode.ErrNoDocFound
|
||||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -219,14 +218,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
||||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||||
// 0. 传递参考资料的信息
|
// 0. 传递参考资料的信息
|
||||||
docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言
|
docMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, dochub.Docs) // TIP 断言
|
||||||
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 传递 token 信息; // UPDATE 临时方案
|
// 1. 传递 token 信息; // UPDATE 临时方案
|
||||||
tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
|
tokenMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
|
||||||
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||||
@ -239,7 +238,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
ch := make(chan string) // TIP 建立通道。
|
ch := make(chan string) // TIP 建立通道。
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口
|
err := nlp.ChatRAG(query, mode, dochub, ch, clientInfo.Client) // TIP 接口
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -253,7 +252,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
|
||||||
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", c).JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -264,5 +263,5 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
||||||
// TODO 也许也可以同样掉上面那个接口了。
|
// TODO 也许也可以同样用上面那个接口了。
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"catface/app/global/consts"
|
"catface/app/global/consts"
|
||||||
"catface/app/global/variable"
|
"catface/app/global/variable"
|
||||||
"catface/app/model"
|
"catface/app/model"
|
||||||
"catface/app/service/nlp"
|
|
||||||
"catface/app/utils/data_bind"
|
"catface/app/utils/data_bind"
|
||||||
"catface/app/utils/model_handler"
|
"catface/app/utils/model_handler"
|
||||||
"context"
|
"context"
|
||||||
@ -33,13 +32,12 @@ func CreateEncounterESFactory(encounter *model.Encounter) *Encounter {
|
|||||||
// INFO 存储能够作为索引存在的数据。
|
// INFO 存储能够作为索引存在的数据。
|
||||||
type Encounter struct {
|
type Encounter struct {
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title" explain:"路遇笔记标题"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content" explain:"内容"`
|
||||||
Tags []string `json:"tags"`
|
Tags []string `json:"tags" explain:"标签"`
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
|
|
||||||
// TagsHighlight []string `json:"tags_highlight"` // TODO 如何 insert 时忽略,query 时绑定。
|
TagsHighlight []string `json:"-" bind:"tags_highlight"`
|
||||||
TagsHighlight []string `json:"-" bind:"tags_highlight"` // TODO 如何 insert 时忽略,query 时绑定。
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Encounter) IndexName() string {
|
func (e *Encounter) IndexName() string {
|
||||||
@ -49,11 +47,6 @@ func (e *Encounter) IndexName() string {
|
|||||||
func (e *Encounter) InsertDocument() error {
|
func (e *Encounter) InsertDocument() error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
var ok bool
|
|
||||||
if e.Embedding, ok = nlp.GetEmbedding([]string{e.Title, e.Content}); !ok {
|
|
||||||
return fmt.Errorf("nlp embedding service error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 将结构体转换为 JSON 字符串
|
// 将结构体转换为 JSON 字符串
|
||||||
data, err := json.Marshal(e)
|
data, err := json.Marshal(e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -20,9 +20,9 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
|||||||
type DocResult struct {
|
type DocResult struct {
|
||||||
DocBase
|
DocBase
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name" explain:"文档名称"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content" explain:"文档内容"`
|
||||||
UpdatedAt *time.Time `json:"updated_at"`
|
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetType implements DocInterface.
|
// GetType implements DocInterface.
|
||||||
|
@ -19,11 +19,11 @@ func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encou
|
|||||||
type EncounterResult struct {
|
type EncounterResult struct {
|
||||||
DocBase
|
DocBase
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title" explain:"路遇笔记标题"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content" explain:"内容"`
|
||||||
UpdatedAt *time.Time `json:"updated_at"`
|
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e EncounterResult) ToString() string {
|
func (e EncounterResult) ToString() string {
|
||||||
return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;`, e.Title, e.Content)
|
return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, e.Title, e.Content, e.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package model
|
package model_res
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"catface/app/global/consts"
|
"catface/app/global/consts"
|
29
app/service/encounter/curd_es/encounter_es_curd.go
Normal file
29
app/service/encounter/curd_es/encounter_es_curd.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package curd_es
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/model"
|
||||||
|
"catface/app/model_es"
|
||||||
|
"catface/app/service/nlp"
|
||||||
|
"catface/app/utils/data_explain"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateEncounterESCurdFactory(encounter *model.Encounter) *EncounterESCurd {
|
||||||
|
return &EncounterESCurd{
|
||||||
|
model_es.CreateEncounterESFactory(encounter),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncounterESCurd struct {
|
||||||
|
encounter_es *model_es.Encounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EncounterESCurd) InsertDocument() error {
|
||||||
|
var ok bool
|
||||||
|
explian := data_explain.GenerateExplainStringForEmbedding(e.encounter_es)
|
||||||
|
if e.encounter_es.Embedding, ok = nlp.GetEmbeddingOneString(explian); !ok {
|
||||||
|
return fmt.Errorf("nlp embedding service error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.encounter_es.InsertDocument()
|
||||||
|
}
|
@ -32,3 +32,7 @@ func GetEmbedding(text []string) ([]float64, bool) {
|
|||||||
return res.Embedding, true
|
return res.Embedding, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetEmbeddingOneString(str string) ([]float64, bool) {
|
||||||
|
return GetEmbedding([]string{str})
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package nlp
|
|||||||
import (
|
import (
|
||||||
"catface/app/global/variable"
|
"catface/app/global/variable"
|
||||||
"catface/app/service/nlp/glm"
|
"catface/app/service/nlp/glm"
|
||||||
|
"catface/app/service/rag/curd"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -16,13 +17,13 @@ func GenerateTitle(content string, client *zhipu.ChatCompletionService) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
|
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
|
||||||
func ChatRAG(doc, query, mode string, ch chan<- string, client *zhipu.ChatCompletionService) error {
|
func ChatRAG(query, mode string, dochub curd.DocumentHub, ch chan<- string, client *zhipu.ChatCompletionService) error {
|
||||||
// 读取配置文件中的 KnoledgeRAG 模板
|
// 读取配置文件中的 KnoledgeRAG 模板
|
||||||
promptTemplate := variable.PromptsYml.GetString("Prompt.RAG." + mode)
|
promptTemplate := variable.PromptsYml.GetString("Prompt.RAG." + mode)
|
||||||
|
|
||||||
// 替换模板中的占位符
|
// 替换模板中的占位符
|
||||||
message := strings.Replace(promptTemplate, "{question}", query, -1)
|
message := strings.Replace(promptTemplate, "{question}", query, -1)
|
||||||
message = strings.Replace(message, "{context}", doc, -1)
|
message = strings.Replace(message, "{context}", dochub.Explain4LLM(), -1)
|
||||||
|
|
||||||
// 调用聊天接口
|
// 调用聊天接口
|
||||||
// err := glm.ChatStream(message, ch)
|
// err := glm.ChatStream(message, ch)
|
||||||
|
@ -1,37 +0,0 @@
|
|||||||
package curd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"catface/app/global/consts"
|
|
||||||
"catface/app/model_res"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) {
|
|
||||||
switch mode {
|
|
||||||
case consts.RagChatModeKnowledge:
|
|
||||||
results, err := CreateDocCurdFactory().TopK(embedding, k)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", err)
|
|
||||||
}
|
|
||||||
for _, result := range results {
|
|
||||||
temp = append(temp, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
case consts.RagChatModeDiary:
|
|
||||||
results, err := CreateEncounterCurdFactory().TopK(embedding, k)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", err)
|
|
||||||
}
|
|
||||||
for _, result := range results {
|
|
||||||
temp = append(temp, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
if mode == "" {
|
|
||||||
err = fmt.Errorf("TopK: mode不能为空")
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("TopK: 不支持的mode: %s", mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return temp, err
|
|
||||||
}
|
|
61
app/service/rag/curd/docs_hub.go
Normal file
61
app/service/rag/curd/docs_hub.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package curd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/global/consts"
|
||||||
|
"catface/app/utils/data_explain"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description: 作为搜索到的文档的集合,目前都是单一的文档类型;// TODO 如何更好的支持多文档的 TopK ?
|
||||||
|
* 相当于 DocumentHub 的构造函数。
|
||||||
|
* @param {string} mode
|
||||||
|
* @param {[]float64} embedding
|
||||||
|
* @param {int} k
|
||||||
|
* @return {*}
|
||||||
|
*/
|
||||||
|
func TopK(mode string, embedding []float64, k int) (dochub DocumentHub, err error) {
|
||||||
|
switch mode {
|
||||||
|
case consts.RagChatModeKnowledge:
|
||||||
|
results, errTemp := CreateDocCurdFactory().TopK(embedding, k)
|
||||||
|
if errTemp != nil {
|
||||||
|
err = fmt.Errorf("TopK: 获取知识库TopK失败: %w", errTemp)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
dochub.Docs = append(dochub.Docs, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
case consts.RagChatModeDiary:
|
||||||
|
results, errTemp := CreateEncounterCurdFactory().TopK(embedding, k)
|
||||||
|
if errTemp != nil {
|
||||||
|
err = fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", errTemp)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
dochub.Docs = append(dochub.Docs, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if mode == "" {
|
||||||
|
err = fmt.Errorf("TopK: mode不能为空")
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("TopK: 不支持的mode: %s", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type DocumentHub struct {
|
||||||
|
Docs []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DocumentHub) Length() int {
|
||||||
|
return len(d.Docs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DocumentHub) Explain4LLM() (explain string) {
|
||||||
|
for id, doc := range d.Docs {
|
||||||
|
explain += strconv.Itoa(id) + "." + data_explain.GenerateExplainStringForEmbedding(doc) + "\n"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
29
app/utils/data_explain/data_explain_rag.go
Normal file
29
app/utils/data_explain/data_explain_rag.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package data_explain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description: 集成 Struct -> Explain for RAG;
|
||||||
|
* @param {interface{}} v
|
||||||
|
* @return {*}
|
||||||
|
*/
|
||||||
|
func GenerateExplainStringForEmbedding(v interface{}) string {
|
||||||
|
val := reflect.ValueOf(v)
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
tag := field.Tag.Get("explain")
|
||||||
|
if tag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value := val.Field(i).Interface()
|
||||||
|
result = append(result, fmt.Sprintf("%s:%v", tag, value))
|
||||||
|
}
|
||||||
|
return strings.Join(result, ";")
|
||||||
|
}
|
44
test/encounter_test.go
Normal file
44
test/encounter_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 假设的 EncounterResult 结构体
|
||||||
|
type EncounterResult struct {
|
||||||
|
Title string
|
||||||
|
Content string
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToString 方法
|
||||||
|
func (e EncounterResult) ToString() string {
|
||||||
|
return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, e.Title, e.Content, e.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试 EncounterResult 的 ToString 方法
|
||||||
|
func TestEncounterResult_ToString(t *testing.T) {
|
||||||
|
// 设置一个时间点,用于测试
|
||||||
|
testTime := time.Now()
|
||||||
|
|
||||||
|
// 创建一个 EncounterResult 实例
|
||||||
|
testResult := EncounterResult{
|
||||||
|
Title: "测试笔记",
|
||||||
|
Content: "这是测试笔记的内容",
|
||||||
|
UpdatedAt: testTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 ToString 方法
|
||||||
|
resultString := testResult.ToString()
|
||||||
|
|
||||||
|
t.Log("resultString:", resultString)
|
||||||
|
// 构建期望的结果字符串
|
||||||
|
expectedString := fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;最后更新时间:%v`, testResult.Title, testResult.Content, testResult.UpdatedAt)
|
||||||
|
|
||||||
|
// 比较实际结果和期望结果
|
||||||
|
if resultString != expectedString {
|
||||||
|
t.Errorf("ToString() failed, expected %q, got %q", expectedString, resultString)
|
||||||
|
}
|
||||||
|
}
|
46
test/explain_test.go
Normal file
46
test/explain_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncounterResult2 结构体定义
|
||||||
|
type EncounterResult2 struct {
|
||||||
|
Id int64 `json:"id" explain:"路遇笔记ID"`
|
||||||
|
Title string `json:"title" explain:"路遇笔记标题"`
|
||||||
|
Content string `json:"content" explain:"路遇笔记内容"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructToString 使用反射将结构体的内容组织为字符串
|
||||||
|
func StructToString(v interface{}) string {
|
||||||
|
val := reflect.ValueOf(v)
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
tag := field.Tag.Get("explain")
|
||||||
|
value := val.Field(i).Interface()
|
||||||
|
result = append(result, fmt.Sprintf("%s:%v", tag, value))
|
||||||
|
}
|
||||||
|
return strings.Join(result, ";")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExplain(t *testing.T) {
|
||||||
|
// 示例数据
|
||||||
|
updatedAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
encounter := EncounterResult2{
|
||||||
|
Id: 1,
|
||||||
|
Title: "遇见小猫",
|
||||||
|
Content: "今天在公园遇到了一只可爱的小猫。",
|
||||||
|
UpdatedAt: &updatedAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 StructToString 函数
|
||||||
|
t.Logf("结构体内容:", StructToString(encounter))
|
||||||
|
}
|
50
test/temp/explain.go
Normal file
50
test/temp/explain.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncounterResult 结构体定义
|
||||||
|
type EncounterResult struct {
|
||||||
|
Id int64 `json:"id"`
|
||||||
|
Title string `json:"title" explain:"路遇笔记标题"`
|
||||||
|
Content string `json:"content" explain:"路遇笔记内容"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
|
||||||
|
NoTag string `json:"no_tag"` // 没有 explain 标签的字段
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructToString 使用反射将结构体的内容组织为字符串,忽略没有 explain 标签的字段
|
||||||
|
func StructToString(v interface{}) string {
|
||||||
|
val := reflect.ValueOf(v)
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
tag := field.Tag.Get("explain")
|
||||||
|
if tag == "" {
|
||||||
|
continue // 跳过没有 explain 标签的字段
|
||||||
|
}
|
||||||
|
value := val.Field(i).Interface()
|
||||||
|
result = append(result, fmt.Sprintf("%s:%v", tag, value))
|
||||||
|
}
|
||||||
|
return strings.Join(result, ";")
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 示例数据
|
||||||
|
updatedAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
encounter := EncounterResult{
|
||||||
|
Id: 1,
|
||||||
|
Title: "遇见小猫",
|
||||||
|
Content: "今天在公园遇到了一只可爱的小猫。",
|
||||||
|
UpdatedAt: &updatedAt,
|
||||||
|
NoTag: "这个字段没有 explain 标签",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 StructToString 函数
|
||||||
|
fmt.Println(StructToString(encounter))
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user