feat(api): 新增 RAG 聊天模式和优化 ES 查询功能
- 新增 RAG 聊天模式常量和前端字段设定 - 修改 Encounters Create 方法中的 ES 同步逻辑 - 更新 Rag ChatSSE 和 ChatWebSocket 方法,支持新的聊天模式 - 重构 NlpWebSocketResult 创建函数,使用新增的常量 - 新增 Encounter 的 TopK 方法,用于 ES 向量搜索 - 更新 DocResult 结构,实现 DocInterface 接口 - 修改 prompts.yml,增加 Diary 模式的提示模板
This commit is contained in:
parent
679d30dc7b
commit
81cd287109
@ -9,3 +9,10 @@ const (
|
|||||||
RagChatModeDiary string = "Diary" // 查询路遇资料等
|
RagChatModeDiary string = "Diary" // 查询路遇资料等
|
||||||
RagChatModeDetect string = "Detect" // 辅助 catface 的辨认功能;
|
RagChatModeDetect string = "Detect" // 辅助 catface 的辨认功能;
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 前端的字段设定
|
||||||
|
const (
|
||||||
|
AiMessageTypeText string = "text"
|
||||||
|
AiMessageTypeDoc string = "doc"
|
||||||
|
AiMessageTypeToken string = "token"
|
||||||
|
)
|
||||||
|
@ -74,7 +74,7 @@ func (e *Encounters) Create(context *gin.Context) {
|
|||||||
go model.CreateEncounterAnimalLinkFactory("").Insert(encounter.Id, animals_id)
|
go model.CreateEncounterAnimalLinkFactory("").Insert(encounter.Id, animals_id)
|
||||||
|
|
||||||
// 3. ES speed // TODO 这里如何实现 不同 DB 之间的 “事务” 概念。
|
// 3. ES speed // TODO 这里如何实现 不同 DB 之间的 “事务” 概念。
|
||||||
if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 1 {
|
if level := int(context.GetFloat64(consts.ValidatorPrefix + "level")); level > 0 { // TEST 暂时全部数据都同步到 ES,不做 level 过滤。
|
||||||
go model_es.CreateEncounterESFactory(&encounter).InsertDocument()
|
go model_es.CreateEncounterESFactory(&encounter).InsertDocument()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,6 +71,11 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
query := context.Query("query")
|
query := context.Query("query")
|
||||||
token := context.Query("token")
|
token := context.Query("token")
|
||||||
|
|
||||||
|
mode := context.Query("mode")
|
||||||
|
if mode == "" {
|
||||||
|
mode = consts.RagChatModeKnowledge
|
||||||
|
}
|
||||||
|
|
||||||
// 0-1. 测试 python
|
// 0-1. 测试 python
|
||||||
if !micro_service.TestLinkPythonService() {
|
if !micro_service.TestLinkPythonService() {
|
||||||
code := errcode.ErrPythonService
|
code := errcode.ErrPythonService
|
||||||
@ -98,7 +103,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. ES TopK
|
// 2. ES TopK
|
||||||
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
docs, err := model_es.CreateDocESFactory().TopK(embedding, 2)
|
||||||
if err != nil || len(docs) == 0 {
|
if err != nil || len(docs) == 0 {
|
||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
@ -112,7 +117,7 @@ func (r *Rag) ChatSSE(context *gin.Context) {
|
|||||||
|
|
||||||
// 3. LLM answer
|
// 3. LLM answer
|
||||||
go func() {
|
go func() {
|
||||||
err := nlp.ChatRAG(docs[0].Content, query, ch, client)
|
err := nlp.ChatRAG(docs[0].Content, query, mode, ch, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -199,7 +204,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
// 2. ES TopK // INFO 这里需要特化选取不同知识库的文档;目前是依靠显式的路由。
|
||||||
docs, err := curd.CreateDocCurdFactory().TopK(embedding, 1)
|
docs, err := curd.TopK(mode, embedding, 1)
|
||||||
if err != nil || len(docs) == 0 {
|
if err != nil || len(docs) == 0 {
|
||||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||||
|
|
||||||
@ -214,14 +219,14 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
// STAGE websocket 的 defer 关闭函数,但是需要 ES 拿到的 doc—id
|
||||||
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
defer func() { // UPDATE 临时"持久化"方案,之后考虑结合 jwt 维护的 token 处理。
|
||||||
// 0. 传递参考资料的信息
|
// 0. 传递参考资料的信息
|
||||||
docMsg := model.CreateNlpWebSocketResult(docs[0].Type, docs)
|
docMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeDoc, docs) // TIP 断言
|
||||||
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
err := ws.WriteMessage(websocket.TextMessage, docMsg.JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send doc message via WebSocket", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 传递 token 信息; // UPDATE 临时方案
|
// 1. 传递 token 信息; // UPDATE 临时方案
|
||||||
tokenMsg := model.CreateNlpWebSocketResult("token", token)
|
tokenMsg := model.CreateNlpWebSocketResult(consts.AiMessageTypeToken, token)
|
||||||
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
err = ws.WriteMessage(websocket.TextMessage, tokenMsg.JsonMarshal())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
variable.ZapLog.Error("Failed to send token message via WebSocket", zap.Error(err))
|
||||||
@ -234,7 +239,7 @@ func (r *Rag) ChatWebSocket(context *gin.Context) {
|
|||||||
ch := make(chan string) // TIP 建立通道。
|
ch := make(chan string) // TIP 建立通道。
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := nlp.ChatRAG(docs[0].Content, query, mode, ch, clientInfo.Client)
|
err := nlp.ChatRAG(docs[0].ToString(), query, mode, ch, clientInfo.Client) // TIP 接口
|
||||||
if err != nil {
|
if err != nil {
|
||||||
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
variable.ZapLog.Error("ChatKnoledgeRAG error", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"catface/app/global/consts"
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
|
func CreateNlpWebSocketResult(t string, data any) *NlpWebSocketResult {
|
||||||
if t == "" {
|
if t == "" {
|
||||||
t = "chat"
|
t = consts.AiMessageTypeText
|
||||||
}
|
}
|
||||||
|
|
||||||
return &NlpWebSocketResult{
|
return &NlpWebSocketResult{
|
||||||
|
@ -142,32 +142,32 @@ func (e *Encounter) UpdateDocument(client *elasticsearch.Client, encounter *Enco
|
|||||||
*/
|
*/
|
||||||
func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) {
|
func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) {
|
||||||
body := fmt.Sprintf(`{
|
body := fmt.Sprintf(`{
|
||||||
"size": %d,
|
"size": %d,
|
||||||
"query": {
|
"query": {
|
||||||
"bool": {
|
"bool": {
|
||||||
"should": [
|
"should": [
|
||||||
{"match": {"tags": "%s"}},
|
{"match": {"tags": "%s"}},
|
||||||
{"match": {"content": "%s"}},
|
{"match": {"content": "%s"}},
|
||||||
{"match": {"title": "%s"}}
|
{"match": {"title": "%s"}}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"highlight": {
|
"highlight": {
|
||||||
"pre_tags": ["%v"],
|
"pre_tags": ["%v"],
|
||||||
"post_tags": ["%v"],
|
"post_tags": ["%v"],
|
||||||
"fields": {
|
"fields": {
|
||||||
"title": {},
|
"title": {},
|
||||||
"content": {
|
"content": {
|
||||||
"fragment_size" : 15
|
"fragment_size" : 15
|
||||||
},
|
},
|
||||||
"tags": {
|
"tags": {
|
||||||
"pre_tags": [""],
|
"pre_tags": [""],
|
||||||
"post_tags": [""]
|
"post_tags": [""]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"_source": ["id", "title", "content", "tags"]
|
"_source": ["id", "title", "content", "tags"]
|
||||||
}`, num, query, query, query, consts.PreTags, consts.PostTags)
|
}`, num, query, query, query, consts.PreTags, consts.PostTags)
|
||||||
|
|
||||||
hits, err := model_handler.SearchRequest(body, e.IndexName())
|
hits, err := model_handler.SearchRequest(body, e.IndexName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -188,3 +188,46 @@ func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter,
|
|||||||
|
|
||||||
return encounters, nil
|
return encounters, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Encounter) TopK(embedding []float64, k int) ([]Encounter, error) {
|
||||||
|
// 同理 Doc
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"query_vector": embedding,
|
||||||
|
}
|
||||||
|
paramsJSON, err := json.Marshal(params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body := fmt.Sprintf(`{
|
||||||
|
"size": %d,
|
||||||
|
"query": {
|
||||||
|
"script_score": {
|
||||||
|
"query": {"match_all": {}},
|
||||||
|
"script": {
|
||||||
|
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||||
|
"params": %s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"_source":["id"]
|
||||||
|
}`, k, string(paramsJSON))
|
||||||
|
|
||||||
|
hits, err := model_handler.SearchRequest(body, e.IndexName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var encounters []Encounter
|
||||||
|
for _, hit := range hits {
|
||||||
|
hitMap := hit.(map[string]interface{})
|
||||||
|
source := hitMap["_source"].(map[string]interface{})
|
||||||
|
var encounter Encounter
|
||||||
|
if err := data_bind.ShouldBindFormMapToModel(source, &encounter); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
encounters = append(encounters, encounter)
|
||||||
|
}
|
||||||
|
return encounters, nil
|
||||||
|
}
|
||||||
|
13
app/model_res/base_model.go
Normal file
13
app/model_res/base_model.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package model_res
|
||||||
|
|
||||||
|
type DocInterface interface {
|
||||||
|
ToString() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DocBase struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DocBase) ToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
@ -6,10 +6,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BUG 存在 依賴循環
|
// INFO 由于直接放到 model 中会导致循环引用,所以放到 model_res 中
|
||||||
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
||||||
return &DocResult{
|
return &DocResult{
|
||||||
Type: "doc",
|
DocBase: DocBase{Type: "doc"},
|
||||||
Id: doc.Id,
|
Id: doc.Id,
|
||||||
Name: doc.Name,
|
Name: doc.Name,
|
||||||
Content: doc_es.Content,
|
Content: doc_es.Content,
|
||||||
@ -18,9 +18,22 @@ func NewDocResult(doc *model.Doc, doc_es *model_es.Doc) *DocResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DocResult struct {
|
type DocResult struct {
|
||||||
Type string `json:"type"`
|
DocBase
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
UpdatedAt *time.Time `json:"updated_at"`
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetType implements DocInterface.
|
||||||
|
func (d DocResult) GetType() string {
|
||||||
|
panic("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description: 实现 DocInterface 接口,输出作为 LLM 的参考内容。
|
||||||
|
* @return {*}
|
||||||
|
*/
|
||||||
|
func (d DocResult) ToString() string {
|
||||||
|
return d.Content
|
||||||
|
}
|
||||||
|
29
app/model_res/encounter.go
Normal file
29
app/model_res/encounter.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package model_res
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/model"
|
||||||
|
"catface/app/model_es"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewEncounterResult(encounter *model.Encounter, encounter_es *model_es.Encounter) *EncounterResult {
|
||||||
|
return &EncounterResult{
|
||||||
|
DocBase: DocBase{Type: "encounter"},
|
||||||
|
Id: encounter.Id,
|
||||||
|
Title: encounter.Title,
|
||||||
|
Content: encounter.Content,
|
||||||
|
UpdatedAt: encounter.UpdatedAt}
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncounterResult struct {
|
||||||
|
DocBase
|
||||||
|
Id int64 `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e EncounterResult) ToString() string {
|
||||||
|
return fmt.Sprintf(`路遇笔记标题:%s;路遇笔记内容:%s;`, e.Title, e.Content)
|
||||||
|
}
|
37
app/service/rag/curd/base_curd.go
Normal file
37
app/service/rag/curd/base_curd.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package curd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/global/consts"
|
||||||
|
"catface/app/model_res"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TopK(mode string, embedding []float64, k int) (temp []model_res.DocInterface, err error) {
|
||||||
|
switch mode {
|
||||||
|
case consts.RagChatModeKnowledge:
|
||||||
|
results, err := CreateDocCurdFactory().TopK(embedding, k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("TopK: 获取知识库TopK失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
temp = append(temp, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
case consts.RagChatModeDiary:
|
||||||
|
results, err := CreateEncounterCurdFactory().TopK(embedding, k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("TopK: 获取路遇笔记TopK失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
temp = append(temp, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if mode == "" {
|
||||||
|
err = fmt.Errorf("TopK: mode不能为空")
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("TopK: 不支持的mode: %s", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return temp, err
|
||||||
|
}
|
42
app/service/rag/curd/encounter_curd.go
Normal file
42
app/service/rag/curd/encounter_curd.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package curd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"catface/app/model"
|
||||||
|
"catface/app/model_es"
|
||||||
|
"catface/app/model_res"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateEncounterCurdFactory() *EncounterCurd {
|
||||||
|
return &EncounterCurd{
|
||||||
|
enc: model.CreateEncounterFactory(""),
|
||||||
|
enc_es: model_es.CreateEncounterESFactory(nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncounterCurd struct {
|
||||||
|
enc *model.Encounter
|
||||||
|
enc_es *model_es.Encounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EncounterCurd) TopK(embedding []float64, k int) (temp []model_res.EncounterResult, err error) {
|
||||||
|
// ES: TopK
|
||||||
|
encounters_es, err := e.enc_es.TopK(embedding, k)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MySQL 补充信息
|
||||||
|
var ids []int64
|
||||||
|
for _, encounter := range encounters_es {
|
||||||
|
ids = append(ids, encounter.Id)
|
||||||
|
}
|
||||||
|
encounters := e.enc.ShowByIDs(ids, "id", "title", "content", "updated_at")
|
||||||
|
for _, encounter := range encounters {
|
||||||
|
for _, encounter_es := range encounters_es {
|
||||||
|
if encounter.Id == encounter_es.Id {
|
||||||
|
temp = append(temp, *model_res.NewEncounterResult(&encounter, &encounter_es))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -22,4 +22,11 @@ Prompt:
|
|||||||
···
|
···
|
||||||
{context}
|
{context}
|
||||||
···
|
···
|
||||||
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
||||||
|
Diary: "使用以知识库中找到的猫猫路遇日记来回答用户的问题,如果无法回答,请回答知识库中未找到符合的资料,我不知道。
|
||||||
|
问题: {question}
|
||||||
|
可参考的路遇日记:
|
||||||
|
···
|
||||||
|
{context}
|
||||||
|
···
|
||||||
|
如果给定的知识库无法让你做出回答,请回答知识库中未找到符合的资料,我不知道。"
|
Loading…
x
Reference in New Issue
Block a user