Havoc412 81cd287109 feat(api): 新增 RAG 聊天模式和优化 ES 查询功能
- 新增 RAG 聊天模式常量和前端字段设定
- 修改 Encounters Create 方法中的 ES 同步逻辑
- 更新 Rag ChatSSE 和 ChatWebSocket 方法,支持新的聊天模式
- 重构 NlpWebSocketResult 创建函数,使用新增的常量
- 新增 Encounter 的 TopK 方法,用于 ES 向量搜索
- 更新 DocResult 结构,实现 DocInterface 接口
- 修改 prompts.yml,增加 Diary 模式的提示模板
2024-11-20 17:32:10 +08:00

234 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model_es
import (
"bytes"
"catface/app/global/consts"
"catface/app/global/variable"
"catface/app/model"
"catface/app/service/nlp"
"catface/app/utils/data_bind"
"catface/app/utils/model_handler"
"context"
"encoding/json"
"fmt"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esapi"
)
// CreateEncounterESFactory builds the Elasticsearch-side Encounter document
// from a database-side model.Encounter.
//
// A nil input yields an empty, non-nil *Encounter, so callers can use the
// result purely as a method receiver (for example to run a query).
func CreateEncounterESFactory(encounter *model.Encounter) *Encounter {
	doc := new(Encounter)
	if encounter != nil {
		// Copy the indexable fields over at construction time.
		doc.Id = encounter.Id
		doc.Title = encounter.Title
		doc.Content = encounter.Content
		doc.Tags = encounter.TagsList // TODO: nothing queries this field on its own yet.
	}
	return doc
}
// Encounter is the document shape that this package serializes to JSON and
// sends to the Elasticsearch index named by IndexName. It only carries the
// data that is meant to be indexed and searched.
type Encounter struct {
// Id is also used (formatted as a decimal string) as the document ID.
Id int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Tags []string `json:"tags"`
// Embedding is filled by nlp.GetEmbedding in InsertDocument and is the
// field scored by cosineSimilarity in TopK.
Embedding []float64 `json:"embedding"`
// TagsHighlight []string `json:"tags_highlight"` // TODO: how to ignore this on insert but bind it on query.
// TagsHighlight is excluded from JSON output (`json:"-"`), so it is never
// written on insert. NOTE(review): the `bind` tag is presumably what
// data_bind uses to fill it from query results; confirm against data_bind.
TagsHighlight []string `json:"-" bind:"tags_highlight"` // TODO: how to ignore this on insert but bind it on query.
}
// IndexName returns the name of the Elasticsearch index that holds Encounter
// documents. It does not read the receiver, so it is safe on a nil *Encounter.
func (e *Encounter) IndexName() string {
	const encounterIndex = "catface_encounters"
	return encounterIndex
}
// InsertDocument indexes the encounter into Elasticsearch.
//
// It first asks the NLP service for an embedding of the title and content and
// stores it on the receiver, then serializes the receiver to JSON and sends it
// as an index request whose document ID is the encounter's Id. The request is
// issued with refresh=true through the global variable.ElasticClient.
//
// An error is returned when the embedding service reports failure, when the
// encounter cannot be serialized, when the request cannot be sent, or when
// Elasticsearch answers with an error status.
func (e *Encounter) InsertDocument() error {
	ctx := context.Background()

	var ok bool
	if e.Embedding, ok = nlp.GetEmbedding([]string{e.Title, e.Content}); !ok {
		return fmt.Errorf("nlp embedding service error")
	}

	// Serialize the struct into the JSON request body.
	data, err := json.Marshal(e)
	if err != nil {
		return err
	}

	// Build the index request.
	req := esapi.IndexRequest{
		Index:      e.IndexName(),
		DocumentID: fmt.Sprintf("%d", e.Id),
		Body:       bytes.NewReader(data),
		Refresh:    "true",
	}

	// Send it.
	res, err := req.Do(ctx, variable.ElasticClient)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if !res.IsError() {
		return nil
	}

	var payload map[string]interface{}
	if err := json.NewDecoder(res.Body).Decode(&payload); err != nil {
		return fmt.Errorf("error parsing the response body: %w", err)
	}
	// The "error" field is not guaranteed to be an object (it can be a plain
	// string or missing), so the assertion is checked instead of panicking.
	detail, isObject := payload["error"].(map[string]interface{})
	if !isObject {
		return fmt.Errorf("[%s] %v", res.Status(), payload["error"])
	}
	return fmt.Errorf("[%s] %v: %v", res.Status(), detail["type"], detail["reason"])
}
// UpdateDocument sends a partial update ({"doc": encounter}) for the given
// encounter to Elasticsearch through client, with refresh=true. The index and
// document ID are taken from the encounter argument; the receiver is not read.
//
// It returns an error when encounter or client is nil, when serialization
// fails, when the request cannot be sent, or when Elasticsearch answers with
// an error status.
//
// TODO: rework to mirror InsertDocument.
// NOTE(review): Embedding has no omitempty, so an encounter whose Embedding is
// nil is serialized with "embedding": null in the partial doc; confirm that
// callers always populate it, or that overwriting it is intended.
func (e *Encounter) UpdateDocument(client *elasticsearch.Client, encounter *Encounter) error {
	if encounter == nil {
		return fmt.Errorf("update document: encounter is nil")
	}
	if client == nil {
		return fmt.Errorf("update document: elasticsearch client is nil")
	}
	ctx := context.Background()

	// Serialize the partial document into the JSON request body.
	data, err := json.Marshal(map[string]interface{}{
		"doc": encounter,
	})
	if err != nil {
		return err
	}

	// Build the update request.
	req := esapi.UpdateRequest{
		Index:      encounter.IndexName(),
		DocumentID: fmt.Sprintf("%d", encounter.Id),
		Body:       bytes.NewReader(data),
		Refresh:    "true",
	}

	// Send it.
	res, err := req.Do(ctx, client)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if !res.IsError() {
		return nil
	}

	var payload map[string]interface{}
	if err := json.NewDecoder(res.Body).Decode(&payload); err != nil {
		return fmt.Errorf("error parsing the response body: %w", err)
	}
	// The "error" field is not guaranteed to be an object (it can be a plain
	// string or missing), so the assertion is checked instead of panicking.
	detail, isObject := payload["error"].(map[string]interface{})
	if !isObject {
		return fmt.Errorf("[%s] %v", res.Status(), payload["error"])
	}
	return fmt.Errorf("[%s] %v: %v", res.Status(), detail["type"], detail["reason"])
}
// QueryDocumentsMatchAll runs a loose keyword search over the encounter index:
// a hit may match the query in tags, content, or title (bool/should).
//
// Parameters:
//   - query: the free-text search string.
//   - num:   the maximum number of hits to request ("size").
//
// Title and content matches are highlighted with consts.PreTags and
// consts.PostTags (content fragments are 15 characters); tags are highlighted
// with empty markers. Only id, title, content and tags are requested from
// _source. Each hit is passed through model_handler.MergeSouceWithHighlight
// and bound into an Encounter; hits that are malformed or fail to bind are
// skipped. The slice is nil when nothing could be bound.
//
// NOTE(review): the original comment suggested callers take the returned ids
// and load full details from MySQL; confirm against the call sites.
func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) {
	// Every string spliced into the request body is JSON-encoded first, so
	// quotes, backslashes or control characters in the user's query cannot
	// break or alter the request. json.Marshal emits the surrounding quotes
	// (and escapes '<', '>' as \u003c, \u003e, which Elasticsearch decodes
	// back), so the template uses bare %s placeholders.
	jsonString := func(s string) (string, error) {
		encoded, err := json.Marshal(s)
		if err != nil {
			return "", fmt.Errorf("encoding search parameter: %w", err)
		}
		return string(encoded), nil
	}
	queryJSON, err := jsonString(query)
	if err != nil {
		return nil, err
	}
	preTagJSON, err := jsonString(fmt.Sprint(consts.PreTags))
	if err != nil {
		return nil, err
	}
	postTagJSON, err := jsonString(fmt.Sprint(consts.PostTags))
	if err != nil {
		return nil, err
	}

	body := fmt.Sprintf(`{
"size": %d,
"query": {
"bool": {
"should": [
{"match": {"tags": %s}},
{"match": {"content": %s}},
{"match": {"title": %s}}
]
}
},
"highlight": {
"pre_tags": [%s],
"post_tags": [%s],
"fields": {
"title": {},
"content": {
"fragment_size" : 15
},
"tags": {
"pre_tags": [""],
"post_tags": [""]
}
}
},
"_source": ["id", "title", "content", "tags"]
}`, num, queryJSON, queryJSON, queryJSON, preTagJSON, postTagJSON)

	hits, err := model_handler.SearchRequest(body, e.IndexName())
	if err != nil {
		return nil, err
	}

	var encounters []Encounter
	for _, hit := range hits {
		hitMap, ok := hit.(map[string]interface{})
		if !ok {
			// Skip a malformed hit rather than panicking on the assertion.
			continue
		}
		data := model_handler.MergeSouceWithHighlight(hitMap)
		var encounter Encounter
		if err := data_bind.ShouldBindFormMapToModel(data, &encounter); err != nil {
			// Best effort: one hit that cannot be bound does not fail the search.
			continue
		}
		encounters = append(encounters, encounter)
	}
	return encounters, nil
}
// TopK returns up to k encounters ranked by vector similarity to embedding.
//
// It issues a script_score query over all documents, scoring each by
// cosineSimilarity(query_vector, 'embedding') + 1.0, and requests only the
// "id" field from _source, so the returned Encounter values carry just what
// data_bind can fill from that. Hits that are malformed or fail to bind are
// skipped. The slice is nil when nothing could be bound.
func (e *Encounter) TopK(embedding []float64, k int) ([]Encounter, error) {
	// Same approach as the Doc model: pass the vector as a script parameter.
	params := map[string]interface{}{
		"query_vector": embedding,
	}
	paramsJSON, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}

	body := fmt.Sprintf(`{
"size": %d,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": %s
}
}
},
"_source":["id"]
}`, k, string(paramsJSON))

	hits, err := model_handler.SearchRequest(body, e.IndexName())
	if err != nil {
		return nil, err
	}

	var encounters []Encounter
	for _, hit := range hits {
		// Checked assertions: a hit without the expected shape is skipped
		// instead of panicking.
		hitMap, ok := hit.(map[string]interface{})
		if !ok {
			continue
		}
		source, ok := hitMap["_source"].(map[string]interface{})
		if !ok {
			continue
		}
		var encounter Encounter
		if err := data_bind.ShouldBindFormMapToModel(source, &encounter); err != nil {
			// Best effort: one hit that cannot be bound does not fail the search.
			continue
		}
		encounters = append(encounters, encounter)
	}
	return encounters, nil
}