feat(api): 新增 RAG 聊天模式和优化 ES 查询功能
- 新增 RAG 聊天模式常量和前端字段设定 - 修改 Encounters Create 方法中的 ES 同步逻辑 - 更新 Rag ChatSSE 和 ChatWebSocket 方法,支持新的聊天模式 - 重构 NlpWebSocketResult 创建函数,使用新增的常量 - 新增 Encounter 的 TopK 方法,用于 ES 向量搜索 - 更新 DocResult 结构,实现 DocInterface 接口 - 修改 prompts.yml,增加 Diary 模式的提示模板
This commit is contained in:
		
							parent
							
								
									679d30dc7b
								
							
						
					
					
						commit
						81cd287109
					
				@ -9,3 +9,10 @@ const (
 | 
			
		||||
	RagChatModeDiary     string = "Diary"  // 查询路遇资料等
 | 
			
		||||
	RagChatModeDetect    string = "Detect" // 辅助 catface 的辨认功能;
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 前端的字段设定
 | 
			
		||||
const (
 | 
			
		||||
	AiMessageTypeText  string = "text"
 | 
			
		||||
	AiMessageTypeDoc   string = "doc"
 | 
			
		||||
	AiMessageTypeToken string = "token"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -74,7 +74,7 @@ func (e *Encounters) Create(context *gin.Context) {
 | 
			
		||||
		go model.CreateEncounterAnimalLinkFactory("").Insert(encounter.Id, animals_id)
 | 
			
		||||
 | 
			
		||||
		// 3. ES speed // TODO 这里如何实现 不同 DB 之间的 “事务” 概念。
 | 
			
		||||
		if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 1 {
 | 
			
		||||
		if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 0 { // TEST 暂时全部数据都同步到 ES,不做 level 过滤。
 | 
			
		||||
			go model_es.CreateEncounterESFactory(&encounter).InsertDocument()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -71,6 +71,11 @@ func (r *Rag) ChatSSE(context *gin.Context) {
 | 
			
		||||
	query := context.Query("query")
 | 
			
		||||
	token := context.Query("token")
 | 
			
		||||
 | 
			
		||||
	mode := context.Query("mode")
 | 
			
		||||
	if mode == "" {
 | 
			
		||||
		mode = consts.RagChatModeKnowledge
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 0-1. 测试 python
 | 
			
		||||
	if !micro_service.TestLinkPythonService() {
 | 
			
		||||
		code := errcode.ErrPythonService
 | 
			
		||||
@ -98,7 +103,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 2. ES TopK
 | 
			
		||||
	docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
 | 
			
		||||
	docs, err := model_es.CreateDocESFactory().TopK(embedding, 2)
 | 
			
		||||
	if err != nil || len(docs) == 0 {
 | 
			
		||||
		variable.ZapLog.Error("ES TopK error", zap.Error(err))
 | 
			
		||||
 | 
			
		||||
@ -112,7 +117,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// 3. LLM answer
 | 
			
		||||
	go func() {
 | 
			
		||||
		err := nlp.ChatRAG(docs[0].Content, query, ch, client)
 | 
			
		||||
		err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
@ -199,7 +204,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
 | 
			
		||||
	docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
 | 
			
		||||
	docs, err := curd.TopK(mode, embedding, 1)
 | 
			
		||||
	if err != nil || len(docs) == 0 {
 | 
			
		||||
		variable.ZapLog.Error("ES TopK error", zap.Error(err))
 | 
			
		||||
 | 
			
		||||
@ -214,14 +219,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
 | 
			
		||||
	// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
 | 
			
		||||
	defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
 | 
			
		||||
		// 0. 传递参考资料的信息
 | 
			
		||||
		docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
 | 
			
		||||
		docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言
 | 
			
		||||
		err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 1. 传递 token 信息; // UPDATE 临时方案
 | 
			
		||||
		tokenMsg := model.CreateNlpWebSocketResult("token", token)
 | 
			
		||||
		tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
 | 
			
		||||
		err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
 | 
			
		||||
@ -234,7 +239,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
 | 
			
		||||
	ch := make(chan string)                               // TIP 建立通道。
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		err := nlp.ChatRAG(docs[0].Content, query, mode, ch, clientInfo.Client)
 | 
			
		||||
		err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,13 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "encoding/json"
 | 
			
		||||
import (
 | 
			
		||||
	"catface/app/global/consts"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
 | 
			
		||||
	if t == "" {
 | 
			
		||||
		t = "chat"
 | 
			
		||||
		t = consts.AiMessageTypeText
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &NlpWebSocketResult{
 | 
			
		||||
 | 
			
		||||
@ -142,32 +142,32 @@ func (e *Encounter) UpdateDocument(client *elasticsearch.Client, encounter *Enco
 | 
			
		||||
 */
 | 
			
		||||
func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) {
 | 
			
		||||
	body := fmt.Sprintf(`{
 | 
			
		||||
  "size": %d, 
 | 
			
		||||
  "query": {
 | 
			
		||||
    "bool": {
 | 
			
		||||
      "should": [
 | 
			
		||||
        {"match": {"tags": "%s"}},
 | 
			
		||||
        {"match": {"content": "%s"}},
 | 
			
		||||
        {"match": {"title": "%s"}}
 | 
			
		||||
      ]
 | 
			
		||||
    }
 | 
			
		||||
  },
 | 
			
		||||
  "highlight": {
 | 
			
		||||
    "pre_tags": ["%v"],
 | 
			
		||||
    "post_tags": ["%v"],
 | 
			
		||||
    "fields": {
 | 
			
		||||
      "title": {},
 | 
			
		||||
      "content": {
 | 
			
		||||
        "fragment_size" : 15
 | 
			
		||||
      },
 | 
			
		||||
      "tags": {
 | 
			
		||||
        "pre_tags": [""],
 | 
			
		||||
        "post_tags": [""]
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  },
 | 
			
		||||
  "_source": ["id", "title", "content", "tags"]
 | 
			
		||||
}`, num, query, query, query, consts.PreTags, consts.PostTags)
 | 
			
		||||
		"size": %d, 
 | 
			
		||||
		"query": {
 | 
			
		||||
			"bool": {
 | 
			
		||||
			"should": [
 | 
			
		||||
				{"match": {"tags": "%s"}},
 | 
			
		||||
				{"match": {"content": "%s"}},
 | 
			
		||||
				{"match": {"title": "%s"}}
 | 
			
		||||
			]
 | 
			
		||||
			}
 | 
			
		||||
		},
 | 
			
		||||
		"highlight": {
 | 
			
		||||
			"pre_tags": ["%v"],
 | 
			
		||||
			"post_tags": ["%v"],
 | 
			
		||||
			"fields": {
 | 
			
		||||
			"title": {},
 | 
			
		||||
			"content": {
 | 
			
		||||
				"fragment_size" : 15
 | 
			
		||||
			},
 | 
			
		||||
			"tags": {
 | 
			
		||||
				"pre_tags": [""],
 | 
			
		||||
				"post_tags": [""]
 | 
			
		||||
			}
 | 
			
		||||
			}
 | 
			
		||||
		},
 | 
			
		||||
		"_source": ["id", "title", "content", "tags"]
 | 
			
		||||
	}`, num, query, query, query, consts.PreTags, consts.PostTags)
 | 
			
		||||
 | 
			
		||||
	hits, err := model_handler.SearchRequest(body, e.IndexName())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@ -188,3 +188,46 @@ func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter,
 | 
			
		||||
 | 
			
		||||
	return encounters, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *Encounter) TopK(embedding []float64, k int) ([]Encounter, error) {
 | 
			
		||||
	// 同理 Doc
 | 
			
		||||
	params := map[string]interface{}{
 | 
			
		||||
		"query_vector": embedding,
 | 
			
		||||
	}
 | 
			
		||||
	paramsJSON, err := json.Marshal(params)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := fmt.Sprintf(`{
 | 
			
		||||
		"size": %d,
 | 
			
		||||
		"query": {
 | 
			
		||||
			"script_score": {
 | 
			
		||||
				"query": {"match_all": {}},
 | 
			
		||||
				"script": {
 | 
			
		||||
					"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
 | 
			
		||||
					"params": %s
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		},
 | 
			
		||||
		"_source":["id"]
 | 
			
		||||
	}`, k, string(paramsJSON))
 | 
			
		||||
 | 
			
		||||
	hits, err := model_handler.SearchRequest(body, e.IndexName())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var encounters []Encounter
 | 
			
		||||
	for _, hit := range hits {
 | 
			
		||||
		hitMap := hit.(map[string]interface{})
 | 
			
		||||
		source := hitMap["_source"].(map[string]interface{})
 | 
			
		||||
		var encounter Encounter
 | 
			
		||||
		if err := data_bind.ShouldBindFormMapToModel(source, &encounter); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		encounters = append(encounters, encounter)
 | 
			
		||||
	}
 | 
			
		||||
	return encounters, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								app/model_res/base_model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								app/model_res/base_model.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
package model_res
 | 
			
		||||
 | 
			
		||||
type DocInterface interface {
 | 
			
		||||
	ToString() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DocBase struct {
 | 
			
		||||
	Type string `json:"type"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d DocBase) ToString() string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
@ -6,10 +6,10 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// BUG 存在 依賴循環
 | 
			
		||||
// INFO 由于直接放到 model 中会导致循环引用,所以放到 model_res 中
 | 
			
		||||
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
 | 
			
		||||
	return &DocResult{
 | 
			
		||||
		Type:      "doc",
 | 
			
		||||
		DocBase:   DocBase{Type: "doc"},
 | 
			
		||||
		Id:        doc.Id,
 | 
			
		||||
		Name:      doc.Name,
 | 
			
		||||
		Content:   doc_es.Content,
 | 
			
		||||
@ -18,9 +18,22 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DocResult struct {
 | 
			
		||||
	Type      string     `json:"type"`
 | 
			
		||||
	DocBase
 | 
			
		||||
	Id        int64      `json:"id"`
 | 
			
		||||
	Name      string     `json:"name"`
 | 
			
		||||
	Content   string     `json:"content"`
 | 
			
		||||
	UpdatedAt *time.Time `json:"updated_at"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetType implements DocInterface.
 | 
			
		||||
func (d DocResult) GetType() string {
 | 
			
		||||
	panic("unimplemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @description: 实现 DocInterface 接口,输出作为 LLM 的参考内容。
 | 
			
		||||
 * @return {*}
 | 
			
		||||
 */
 | 
			
		||||
func (d DocResult) ToString() string {
 | 
			
		||||
	return d.Content
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										29
									
								
								app/model_res/encounter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								app/model_res/encounter.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,29 @@
 | 
			
		||||
package model_res
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"catface/app/model"
 | 
			
		||||
	"catface/app/model_es"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encounter) *EncounterResult {
 | 
			
		||||
	return &EncounterResult{
 | 
			
		||||
		DocBase:   DocBase{Type: "encounter"},
 | 
			
		||||
		Id:        encounter.Id,
 | 
			
		||||
		Title:     encounter.Title,
 | 
			
		||||
		Content:   encounter.Content,
 | 
			
		||||
		UpdatedAt: encounter.UpdatedAt}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EncounterResult struct {
 | 
			
		||||
	DocBase
 | 
			
		||||
	Id        int64      `json:"id"`
 | 
			
		||||
	Title     string     `json:"title"`
 | 
			
		||||
	Content   string     `json:"content"`
 | 
			
		||||
	UpdatedAt *time.Time `json:"updated_at"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e EncounterResult) ToString() string {
 | 
			
		||||
	return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;`, e.Title, e.Content)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										37
									
								
								app/service/rag/curd/base_curd.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								app/service/rag/curd/base_curd.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
package curd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"catface/app/global/consts"
 | 
			
		||||
	"catface/app/model_res"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) {
 | 
			
		||||
	switch mode {
 | 
			
		||||
	case consts.RagChatModeKnowledge:
 | 
			
		||||
		results, err := CreateDocCurdFactory().TopK(embedding, k)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		for _, result := range results {
 | 
			
		||||
			temp = append(temp, result)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case consts.RagChatModeDiary:
 | 
			
		||||
		results, err := CreateEncounterCurdFactory().TopK(embedding, k)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		for _, result := range results {
 | 
			
		||||
			temp = append(temp, result)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		if mode == "" {
 | 
			
		||||
			err = fmt.Errorf("TopK: mode不能为空")
 | 
			
		||||
		} else {
 | 
			
		||||
			err = fmt.Errorf("TopK: 不支持的mode: %s", mode)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return temp, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								app/service/rag/curd/encounter_curd.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								app/service/rag/curd/encounter_curd.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
			
		||||
package curd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"catface/app/model"
 | 
			
		||||
	"catface/app/model_es"
 | 
			
		||||
	"catface/app/model_res"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CreateEncounterCurdFactory() *EncounterCurd {
 | 
			
		||||
	return &EncounterCurd{
 | 
			
		||||
		enc:    model.CreateEncounterFactory(""),
 | 
			
		||||
		enc_es: model_es.CreateEncounterESFactory(nil),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EncounterCurd struct {
 | 
			
		||||
	enc    *model.Encounter
 | 
			
		||||
	enc_es *model_es.Encounter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *EncounterCurd) TopK(embedding []float64, k int) (temp []model_res.EncounterResult, err error) {
 | 
			
		||||
	// ES: TopK
 | 
			
		||||
	encounters_es, err := e.enc_es.TopK(embedding, k)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// MySQL 补充信息
 | 
			
		||||
	var ids []int64
 | 
			
		||||
	for _, encounter := range encounters_es {
 | 
			
		||||
		ids = append(ids, encounter.Id)
 | 
			
		||||
	}
 | 
			
		||||
	encounters := e.enc.ShowByIDs(ids, "id", "title", "content", "updated_at")
 | 
			
		||||
	for _, encounter := range encounters {
 | 
			
		||||
		for _, encounter_es := range encounters_es {
 | 
			
		||||
			if encounter.Id == encounter_es.Id {
 | 
			
		||||
				temp = append(temp, *model_res.NewEncounterResult(&encounter, &encounter_es))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -22,4 +22,11 @@ Prompt:
 | 
			
		||||
      ···
 | 
			
		||||
      {context}
 | 
			
		||||
      ···
 | 
			
		||||
      如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。" 
 | 
			
		||||
      如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
 | 
			
		||||
    Diary: "使用以知识库中找到的猫猫路遇日记来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
 | 
			
		||||
      问题: {question}
 | 
			
		||||
      可参考的路遇日记:
 | 
			
		||||
      ···
 | 
			
		||||
      {context}
 | 
			
		||||
      ···
 | 
			
		||||
      如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user