🎨 refactor(rag): 重构 RAG 模型相关代码

- 重构了 rag_controller.go 中的逻辑,使用新的 DocumentHub 结构
- 修改了 encounter.go 中的 Encounter 结构,增加了 explain 标签
- 重写了 rag_websocket.go 中的逻辑,使用新的 DocumentHub 结构
- 新增了 curd_es/encounter_es_curd.go 文件,实现了 Encounter 的 CURD 操作
- 更新了 nlp/func.go 中的 ChatRAG 函数,使用新的 DocumentHub 结构
- 新增了 curd/docs_hub.go 文件,实现了 DocumentHub 的 TopK 方法
- 新增了 utils/data_explain/data_explain_rag.go 文件,实现了结构体到解释字符串的转换
This commit is contained in:
Havoc412 2024-11-20 19:30:11 +08:00
parent 81cd287109
commit ae7edb5e8d
14 changed files with 293 additions and 74 deletions

View File

@ -4,8 +4,7 @@ import (
"catface/app/global/consts"
"catface/app/global/errcode"
"catface/app/global/variable"
"catface/app/model"
"catface/app/model_es"
"catface/app/model_res"
"catface/app/service/nlp"
"catface/app/service/rag/curd"
"catface/app/utils/llm_factory"
@ -103,8 +102,8 @@ func (r *Rag) ChatSSE(context *gin.Context) {
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 2)
if err != nil || len(docs) == 0 {
dochub, err := curd.TopK(mode, embedding, 1)
if err != nil || dochub.Length() == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
@ -117,7 +116,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
// 3. LLM answer
go func() {
err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client)
err := nlp.ChatRAG(query, mode, dochub, ch, client)
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
@ -172,7 +171,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
// 0-2. 测试 Python 微服务是否启动
if !micro_service.TestLinkPythonService() {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
@ -183,7 +182,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
clientInfo, ercode := variable.GlmClientHub.GetOneGlmClientInfo(token, llm_factory.GlmModeKnowledgeHub)
if ercode != 0 {
variable.ZapLog.Error("GetOneGlmClient error", zap.Error(err))
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[ercode]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
@ -196,7 +195,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
embedding, ok := nlp.GetEmbedding(clientInfo.UserQuerys)
if !ok {
code := errcode.ErrPythonServierDown
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
@ -204,12 +203,12 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
}
// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
docs, err := curd.TopK(mode, embedding, 1)
if err != nil || len(docs) == 0 {
dochub, err := curd.TopK(mode, embedding, 1)
if err != nil || dochub.Length() == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", errcode.ErrMsgForUser[code]).JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send error message via WebSocket", zap.Error(err))
}
@ -219,14 +218,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
// 0. 传递参考资料的信息
docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言
docMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, dochub.Docs) // TIP 断言
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
}
// 1. 传递 token 信息; // UPDATE 临时方案
tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
tokenMsg := model_res.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
if err != nil {
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
@ -239,7 +238,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
ch := make(chan string) // TIP 建立通道。
go func() {
err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口
err := nlp.ChatRAG(query, mode, dochub, ch, clientInfo.Client) // TIP 接口
if err != nil {
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
}
@ -253,7 +252,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
return
}
// variable.ZapLog.Info("ChatKnoledgeRAG", zap.String("c", c))
err := ws.WriteMessage(websocket.TextMessage, model.CreateNlpWebSocketResult("", c).JsonMarshal())
err := ws.WriteMessage(websocket.TextMessage, model_res.CreateNlpWebSocketResult("", c).JsonMarshal())
if err != nil {
return
}
@ -264,5 +263,5 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
}
func (r *Rag) HelpDetectCat(context *gin.Context) {
// TODO 也许也可以同样上面那个接口了。
// TODO 也许也可以同样上面那个接口了。
}

View File

@ -5,7 +5,6 @@ import (
"catface/app/global/consts"
"catface/app/global/variable"
"catface/app/model"
"catface/app/service/nlp"
"catface/app/utils/data_bind"
"catface/app/utils/model_handler"
"context"
@ -33,13 +32,12 @@ func CreateEncounterESFactory(encounter *model.Encounter) *Encounter {
// INFO 存储能够作为索引存在的数据。
type Encounter struct {
Id int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Tags []string `json:"tags"`
Title string `json:"title" explain:"路遇笔记标题"`
Content string `json:"content" explain:"内容"`
Tags []string `json:"tags" explain:"标签"`
Embedding []float64 `json:"embedding"`
// TagsHighlight []string `json:"tags_highlight"` // TODO 如何 insert 时忽略query 时绑定。
TagsHighlight []string `json:"-" bind:"tags_highlight"` // TODO 如何 insert 时忽略query 时绑定。
TagsHighlight []string `json:"-" bind:"tags_highlight"`
}
func (e *Encounter) IndexName() string {
@ -49,11 +47,6 @@ func (e *Encounter) IndexName() string {
func (e *Encounter) InsertDocument() error {
ctx := context.Background()
var ok bool
if e.Embedding, ok = nlp.GetEmbedding([]string{e.Title, e.Content}); !ok {
return fmt.Errorf("nlp embedding service error")
}
// 将结构体转换为 JSON 字符串
data, err := json.Marshal(e)
if err != nil {

View File

@ -20,9 +20,9 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
type DocResult struct {
DocBase
Id int64 `json:"id"`
Name string `json:"name"`
Content string `json:"content"`
UpdatedAt *time.Time `json:"updated_at"`
Name string `json:"name" explain:"文档名称"`
Content string `json:"content" explain:"文档内容"`
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
}
// GetType implements DocInterface.

View File

@ -19,11 +19,11 @@ func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encou
type EncounterResult struct {
DocBase
Id int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
UpdatedAt *time.Time `json:"updated_at"`
Title string `json:"title" explain:"路遇笔记标题"`
Content string `json:"content" explain:"内容"`
UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
}
func (e EncounterResult) ToString() string {
return fmt.Sprintf(`路遇笔记标题:%s路遇笔记内容%s`, e.Title, e.Content)
return fmt.Sprintf(`路遇笔记标题:%s路遇笔记内容%s最后更新时间:%v`, e.Title, e.Content, e.UpdatedAt)
}

View File

@ -1,4 +1,4 @@
package model
package model_res
import (
"catface/app/global/consts"

View File

@ -0,0 +1,29 @@
package curd_es
import (
"catface/app/model"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/utils/data_explain"
"fmt"
)
// CreateEncounterESCurdFactory builds an EncounterESCurd around the
// Elasticsearch document that model_es.CreateEncounterESFactory derives from
// the given encounter.
func CreateEncounterESCurdFactory(encounter *model.Encounter) *EncounterESCurd {
	curd := new(EncounterESCurd)
	curd.encounter_es = model_es.CreateEncounterESFactory(encounter)
	return curd
}
// EncounterESCurd wraps a model_es.Encounter so that service-level steps can
// run before the document is handed to the model_es layer.
type EncounterESCurd struct {
	encounter_es *model_es.Encounter // wrapped ES document the methods operate on
}
// InsertDocument fills in the embedding of the wrapped encounter document and
// then inserts it through model_es.Encounter.InsertDocument.
//
// The text that gets embedded is the explain string generated from the
// document's `explain`-tagged fields. An error is returned when there is no
// wrapped document, when the embedding service reports failure, or when the
// underlying insert fails.
func (e *EncounterESCurd) InsertDocument() error {
	if e == nil || e.encounter_es == nil {
		return fmt.Errorf("encounter es document is nil")
	}

	// Hand the struct value, not the pointer, to the reflection-based helper:
	// calling NumField on a reflect.Value of pointer kind panics.
	explain := data_explain.GenerateExplainStringForEmbedding(*e.encounter_es)

	var ok bool
	if e.encounter_es.Embedding, ok = nlp.GetEmbeddingOneString(explain); !ok {
		return fmt.Errorf("nlp embedding service error")
	}
	return e.encounter_es.InsertDocument()
}

View File

@ -32,3 +32,7 @@ func GetEmbedding(text []string) ([]float64, bool) {
return res.Embedding, true
}
}
// GetEmbeddingOneString is a convenience wrapper around GetEmbedding for a
// single piece of text: it wraps str in a one-element slice and forwards the
// result unchanged.
func GetEmbeddingOneString(str string) ([]float64, bool) {
	single := []string{str}
	embedding, ok := GetEmbedding(single)
	return embedding, ok
}

View File

@ -3,6 +3,7 @@ package nlp
import (
"catface/app/global/variable"
"catface/app/service/nlp/glm"
"catface/app/service/rag/curd"
"fmt"
"strings"
@ -16,13 +17,13 @@ func GenerateTitle(content string, client *zhipu.ChatCompletionService) string {
}
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
func ChatRAG(doc, query, mode string, ch chan<- string, client *zhipu.ChatCompletionService) error {
func ChatRAG(query, mode string, dochub curd.DocumentHub, ch chan<- string, client *zhipu.ChatCompletionService) error {
// 读取配置文件中的 KnoledgeRAG 模板
promptTemplate := variable.PromptsYml.GetString("Prompt.RAG." + mode)
// 替换模板中的占位符
message := strings.Replace(promptTemplate, "{question}", query, -1)
message = strings.Replace(message, "{context}", doc, -1)
message = strings.Replace(message, "{context}", dochub.Explain4LLM(), -1)
// 调用聊天接口
// err := glm.ChatStream(message, ch)

View File

@ -1,37 +0,0 @@
package curd
import (
"catface/app/global/consts"
"catface/app/model_res"
"fmt"
)
// TopK looks up the k entries closest to embedding in the store selected by
// mode and returns them as model_res.DocInterface values. A failing search,
// an empty mode or an unsupported mode produces an error.
func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) {
	switch mode {
	case consts.RagChatModeKnowledge:
		hits, searchErr := CreateDocCurdFactory().TopK(embedding, k)
		if searchErr != nil {
			return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", searchErr)
		}
		for _, hit := range hits {
			temp = append(temp, hit)
		}
		return temp, nil
	case consts.RagChatModeDiary:
		hits, searchErr := CreateEncounterCurdFactory().TopK(embedding, k)
		if searchErr != nil {
			return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", searchErr)
		}
		for _, hit := range hits {
			temp = append(temp, hit)
		}
		return temp, nil
	default:
		// Distinguish a missing mode from one that is simply unknown.
		if mode != "" {
			return temp, fmt.Errorf("TopK: 不支持的mode: %s", mode)
		}
		return temp, fmt.Errorf("TopK: mode不能为空")
	}
}

View File

@ -0,0 +1,61 @@
package curd
import (
"catface/app/global/consts"
"catface/app/utils/data_explain"
"fmt"
"strconv"
)
// TopK searches the store selected by mode for the k entries closest to
// embedding and collects the hits into a DocumentHub. It serves as the
// DocumentHub constructor.
//
// All hits currently come from a single document type.
// TODO: find a better way to support TopK across several document types.
//
// A non-nil error is returned when the underlying search fails or when mode
// is empty or unsupported; the returned hub is empty in that case.
func TopK(mode string, embedding []float64, k int) (dochub DocumentHub, err error) {
	switch mode {
	case consts.RagChatModeKnowledge:
		results, errTemp := CreateDocCurdFactory().TopK(embedding, k)
		if errTemp != nil {
			// Stop here so callers never receive hits alongside an error.
			return DocumentHub{}, fmt.Errorf("TopK: 获取知识库TopK失败: %w", errTemp)
		}
		for _, result := range results {
			dochub.Docs = append(dochub.Docs, result)
		}
	case consts.RagChatModeDiary:
		results, errTemp := CreateEncounterCurdFactory().TopK(embedding, k)
		if errTemp != nil {
			return DocumentHub{}, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", errTemp)
		}
		for _, result := range results {
			dochub.Docs = append(dochub.Docs, result)
		}
	default:
		if mode == "" {
			return DocumentHub{}, fmt.Errorf("TopK: mode不能为空")
		}
		return DocumentHub{}, fmt.Errorf("TopK: 不支持的mode: %s", mode)
	}
	return dochub, nil
}
// DocumentHub is the collection of documents returned by a TopK search.
type DocumentHub struct {
	Docs []interface{} // search hits, in the order TopK appended them
}
// Length reports how many documents the hub currently holds.
func (d *DocumentHub) Length() int {
	count := len(d.Docs)
	return count
}
// Explain4LLM renders the hub as text for the LLM prompt: one line per
// document, made of the document's zero-based position, a dot, the explain
// string generated for that document, and a trailing newline.
func (d *DocumentHub) Explain4LLM() (explain string) {
	for idx := 0; idx < len(d.Docs); idx++ {
		line := strconv.Itoa(idx) + "." + data_explain.GenerateExplainStringForEmbedding(d.Docs[idx]) + "\n"
		explain = explain + line
	}
	return explain
}

View File

@ -0,0 +1,29 @@
package data_explain
import (
"fmt"
"reflect"
"strings"
)
// GenerateExplainStringForEmbedding turns a struct into the explain string
// used for RAG.
//
// For every exported field that carries a non-empty `explain` tag, the tag
// text is followed directly by the field value formatted with %v; the pieces
// are concatenated in field order without a separator. Fields without the tag
// are skipped.
//
// v may be a struct or a (possibly nested) pointer to one. The empty string
// is returned for nil, nil pointers and non-struct values.
func GenerateExplainStringForEmbedding(v interface{}) string {
	val := reflect.ValueOf(v)

	// Unwrap pointers and interfaces: NumField is only defined for struct
	// kinds and panics on anything else.
	for val.IsValid() && (val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface) {
		if val.IsNil() {
			return ""
		}
		val = val.Elem()
	}
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return ""
	}

	typ := val.Type()
	var result []string
	for i := 0; i < val.NumField(); i++ {
		field := typ.Field(i)
		tag := field.Tag.Get("explain")
		// PkgPath is non-empty only for unexported fields, whose values
		// cannot be read through Interface() without panicking.
		if tag == "" || field.PkgPath != "" {
			continue
		}
		result = append(result, fmt.Sprintf("%s%v", tag, val.Field(i).Interface()))
	}
	return strings.Join(result, "")
}

44
test/encounter_test.go Normal file
View File

@ -0,0 +1,44 @@
package test
import (
"fmt"
"testing"
"time"
)
// EncounterResult is a stand-in for the real result type, declared locally
// for this test.
type EncounterResult struct {
	Title     string
	Content   string
	UpdatedAt time.Time
}

// ToString renders the title, content and last-update time as one string.
func (e EncounterResult) ToString() string {
	const layout = `路遇笔记标题:%s路遇笔记内容%s最后更新时间%v`
	return fmt.Sprintf(layout, e.Title, e.Content, e.UpdatedAt)
}
// TestEncounterResult_ToString checks that ToString produces the expected
// layout for a sample EncounterResult.
func TestEncounterResult_ToString(t *testing.T) {
	// Sample value stamped with the current time.
	sample := EncounterResult{
		Title:     "测试笔记",
		Content:   "这是测试笔记的内容",
		UpdatedAt: time.Now(),
	}

	got := sample.ToString()
	t.Log("resultString:", got)

	// Expected text, built from the same fields with the reference layout.
	want := fmt.Sprintf(`路遇笔记标题:%s路遇笔记内容%s最后更新时间%v`, sample.Title, sample.Content, sample.UpdatedAt)
	if got == want {
		return
	}
	t.Errorf("ToString() failed, expected %q, got %q", want, got)
}

46
test/explain_test.go Normal file
View File

@ -0,0 +1,46 @@
package test
import (
"fmt"
"reflect"
"strings"
"testing"
"time"
)
// EncounterResult2 is a sample struct for the explain-tag experiment. Each
// `explain` tag holds the label that StructToString places in front of the
// field's value.
type EncounterResult2 struct {
	Id int64 `json:"id" explain:"路遇笔记ID"`
	Title string `json:"title" explain:"路遇笔记标题"`
	Content string `json:"content" explain:"路遇笔记内容"`
	UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
}
// StructToString uses reflection to flatten a struct value into one string:
// for each field in declaration order it writes the field's `explain` tag
// (possibly empty) immediately followed by the field value in %v form.
func StructToString(v interface{}) string {
	rv := reflect.ValueOf(v)
	rt := rv.Type()

	var sb strings.Builder
	for idx, total := 0, rv.NumField(); idx < total; idx++ {
		sb.WriteString(rt.Field(idx).Tag.Get("explain"))
		fmt.Fprint(&sb, rv.Field(idx).Interface())
	}
	return sb.String()
}
// TestExplain logs the explain string that StructToString builds for a sample
// EncounterResult2 value.
func TestExplain(t *testing.T) {
	// Sample data.
	updatedAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC)
	encounter := EncounterResult2{
		Id:        1,
		Title:     "遇见小猫",
		Content:   "今天在公园遇到了一只可爱的小猫。",
		UpdatedAt: &updatedAt,
	}

	// Logf needs a verb for its argument; without one, vet's printf check
	// rejects the call and the logged text ends in "%!(EXTRA ...)".
	t.Logf("结构体内容:%s", StructToString(encounter))
}

50
test/temp/explain.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"fmt"
"reflect"
"strings"
"time"
)
// EncounterResult is a sample struct for the explain-tag experiment. Fields
// with an `explain` tag are included by StructToString; the others are
// skipped.
type EncounterResult struct {
	Id int64 `json:"id"`
	Title string `json:"title" explain:"路遇笔记标题"`
	Content string `json:"content" explain:"路遇笔记内容"`
	UpdatedAt *time.Time `json:"updated_at" explain:"最后更新时间"`
	NoTag string `json:"no_tag"` // field without an explain tag
}
// StructToString uses reflection to flatten a struct value into one string.
// Every field with a non-empty `explain` tag contributes the tag text
// immediately followed by the field value in %v form; fields without the tag
// are ignored.
func StructToString(v interface{}) string {
	rv := reflect.ValueOf(v)
	rt := rv.Type()

	var parts []string
	for idx, total := 0, rv.NumField(); idx < total; idx++ {
		if label := rt.Field(idx).Tag.Get("explain"); label != "" {
			parts = append(parts, label+fmt.Sprint(rv.Field(idx).Interface()))
		}
	}
	return strings.Join(parts, "")
}
// main prints the explain string for a sample EncounterResult.
func main() {
	// Sample data.
	when := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC)
	sample := EncounterResult{
		Id:        1,
		Title:     "遇见小猫",
		Content:   "今天在公园遇到了一只可爱的小猫。",
		UpdatedAt: &when,
		NoTag:     "这个字段没有 explain 标签",
	}

	// Build the explain string and write it to stdout.
	explained := StructToString(sample)
	fmt.Println(explained)
}