2024-11-12 15:53:47 +08:00
|
|
|
|
package model_es
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
2024-11-14 21:00:24 +08:00
|
|
|
|
"catface/app/global/consts"
|
2024-11-12 15:53:47 +08:00
|
|
|
|
"catface/app/global/variable"
|
|
|
|
|
"catface/app/model"
|
2024-11-19 11:27:17 +08:00
|
|
|
|
"catface/app/service/nlp"
|
2024-11-14 04:26:12 +08:00
|
|
|
|
"catface/app/utils/data_bind"
|
|
|
|
|
"catface/app/utils/model_handler"
|
2024-11-12 15:53:47 +08:00
|
|
|
|
"context"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"fmt"
|
|
|
|
|
|
|
|
|
|
"github.com/elastic/go-elasticsearch/v8"
|
|
|
|
|
"github.com/elastic/go-elasticsearch/v8/esapi"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func CreateEncounterESFactory(encounter *model.Encounter) *Encounter {
|
2024-11-12 16:39:06 +08:00
|
|
|
|
if encounter == nil { // UPDATE 这样写好丑。
|
|
|
|
|
return &Encounter{}
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-12 15:53:47 +08:00
|
|
|
|
// 我把数值绑定到了工厂创建当中。
|
|
|
|
|
return &Encounter{
|
|
|
|
|
Id: encounter.Id,
|
|
|
|
|
Title: encounter.Title,
|
|
|
|
|
Content: encounter.Content,
|
2024-11-13 18:56:22 +08:00
|
|
|
|
Tags: encounter.TagsList, // TODO 暂时没有对此字段的查询。
|
2024-11-12 15:53:47 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Encounter is the Elasticsearch document for an encounter. It holds only the
// data that is written to, or read back from, the encounters index.
//
// Embedding is the vector scored against by vector search. It is serialized
// as "embedding" like the other indexed fields.
//
// TagsHighlight is tagged json:"-", so json.Marshal never writes it into the
// index. The bind:"tags_highlight" tag is presumably the key data_bind uses to
// fill it from merged search highlights when binding query hits (TODO: confirm
// in data_bind).
type Encounter struct {
	Id        int64     `json:"id"`
	Title     string    `json:"title"`
	Content   string    `json:"content"`
	Tags      []string  `json:"tags"`
	Embedding []float64 `json:"embedding"`

	TagsHighlight []string `json:"-" bind:"tags_highlight"`
}
|
|
|
|
|
|
|
|
|
|
func (e *Encounter) IndexName() string {
|
|
|
|
|
return "catface_encounters"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (e *Encounter) InsertDocument() error {
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
2024-11-19 11:27:17 +08:00
|
|
|
|
var ok bool
|
|
|
|
|
if e.Embedding, ok = nlp.GetEmbedding([]string{e.Title, e.Content}); !ok {
|
|
|
|
|
return fmt.Errorf("nlp embedding service error")
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-12 15:53:47 +08:00
|
|
|
|
// 将结构体转换为 JSON 字符串
|
|
|
|
|
data, err := json.Marshal(e)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
req := esapi.IndexRequest{
|
|
|
|
|
Index: e.IndexName(),
|
|
|
|
|
DocumentID: fmt.Sprintf("%d", e.Id),
|
|
|
|
|
Body: bytes.NewReader(data),
|
|
|
|
|
Refresh: "true",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
|
|
|
|
res, err := req.Do(ctx, variable.ElasticClient)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
defer res.Body.Close()
|
|
|
|
|
|
|
|
|
|
if res.IsError() {
|
|
|
|
|
var e map[string]interface{}
|
|
|
|
|
if err := json.NewDecoder(res.Body).Decode(&e); err != nil {
|
|
|
|
|
return fmt.Errorf("error parsing the response body: %s", err)
|
|
|
|
|
} else {
|
|
|
|
|
return fmt.Errorf("[%s] %s: %s",
|
|
|
|
|
res.Status(),
|
|
|
|
|
e["error"].(map[string]interface{})["type"],
|
|
|
|
|
e["error"].(map[string]interface{})["reason"],
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-12 16:39:06 +08:00
|
|
|
|
// TODO 改正,仿 Insert
|
2024-11-12 15:53:47 +08:00
|
|
|
|
func (e *Encounter) UpdateDocument(client *elasticsearch.Client, encounter *Encounter) error {
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
|
|
|
|
// 将结构体转换为 JSON 字符串
|
|
|
|
|
data, err := json.Marshal(map[string]interface{}{
|
|
|
|
|
"doc": encounter,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
req := esapi.UpdateRequest{
|
|
|
|
|
Index: encounter.IndexName(),
|
|
|
|
|
DocumentID: fmt.Sprintf("%d", encounter.Id),
|
|
|
|
|
Body: bytes.NewReader(data),
|
|
|
|
|
Refresh: "true",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
|
|
|
|
res, err := req.Do(ctx, client)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
defer res.Body.Close()
|
|
|
|
|
|
|
|
|
|
if res.IsError() {
|
|
|
|
|
var e map[string]interface{}
|
|
|
|
|
if err := json.NewDecoder(res.Body).Decode(&e); err != nil {
|
|
|
|
|
return fmt.Errorf("error parsing the response body: %s", err)
|
|
|
|
|
} else {
|
|
|
|
|
return fmt.Errorf("[%s] %s: %s",
|
|
|
|
|
res.Status(),
|
|
|
|
|
e["error"].(map[string]interface{})["type"],
|
|
|
|
|
e["error"].(map[string]interface{})["reason"],
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @description: 粗略地包含各种关键词匹配,
|
|
|
|
|
* @param {*elasticsearch.Client} client
|
|
|
|
|
* @param {string} query
|
|
|
|
|
* @return {*} 对应 Encounter 的 id,然后交给 MySQL 来查询详细的信息?
|
|
|
|
|
*/
|
2024-11-14 04:26:12 +08:00
|
|
|
|
func (e *Encounter) QueryDocumentsMatchAll(query string, num int) ([]Encounter, error) {
|
|
|
|
|
body := fmt.Sprintf(`{
|
2024-11-20 17:32:10 +08:00
|
|
|
|
"size": %d,
|
|
|
|
|
"query": {
|
|
|
|
|
"bool": {
|
|
|
|
|
"should": [
|
|
|
|
|
{"match": {"tags": "%s"}},
|
|
|
|
|
{"match": {"content": "%s"}},
|
|
|
|
|
{"match": {"title": "%s"}}
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"highlight": {
|
|
|
|
|
"pre_tags": ["%v"],
|
|
|
|
|
"post_tags": ["%v"],
|
|
|
|
|
"fields": {
|
|
|
|
|
"title": {},
|
|
|
|
|
"content": {
|
|
|
|
|
"fragment_size" : 15
|
|
|
|
|
},
|
|
|
|
|
"tags": {
|
|
|
|
|
"pre_tags": [""],
|
|
|
|
|
"post_tags": [""]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"_source": ["id", "title", "content", "tags"]
|
|
|
|
|
}`, num, query, query, query, consts.PreTags, consts.PostTags)
|
2024-11-14 04:26:12 +08:00
|
|
|
|
|
|
|
|
|
hits, err := model_handler.SearchRequest(body, e.IndexName())
|
2024-11-12 15:53:47 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-14 04:26:12 +08:00
|
|
|
|
var encounters []Encounter
|
2024-11-12 15:53:47 +08:00
|
|
|
|
for _, hit := range hits {
|
2024-11-14 04:26:12 +08:00
|
|
|
|
data := model_handler.MergeSouceWithHighlight(hit.(map[string]interface{}))
|
2024-11-12 15:53:47 +08:00
|
|
|
|
|
2024-11-14 04:26:12 +08:00
|
|
|
|
var encounter Encounter
|
|
|
|
|
if err := data_bind.ShouldBindFormMapToModel(data, &encounter); err != nil {
|
|
|
|
|
continue
|
2024-11-14 00:39:42 +08:00
|
|
|
|
}
|
2024-11-14 04:26:12 +08:00
|
|
|
|
|
2024-11-14 00:39:42 +08:00
|
|
|
|
encounters = append(encounters, encounter)
|
|
|
|
|
}
|
2024-11-12 15:53:47 +08:00
|
|
|
|
|
2024-11-14 04:26:12 +08:00
|
|
|
|
return encounters, nil
|
2024-11-12 15:53:47 +08:00
|
|
|
|
}
|
2024-11-20 17:32:10 +08:00
|
|
|
|
|
|
|
|
|
func (e *Encounter) TopK(embedding []float64, k int) ([]Encounter, error) {
|
|
|
|
|
// 同理 Doc
|
|
|
|
|
params := map[string]interface{}{
|
|
|
|
|
"query_vector": embedding,
|
|
|
|
|
}
|
|
|
|
|
paramsJSON, err := json.Marshal(params)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
body := fmt.Sprintf(`{
|
|
|
|
|
"size": %d,
|
|
|
|
|
"query": {
|
|
|
|
|
"script_score": {
|
|
|
|
|
"query": {"match_all": {}},
|
|
|
|
|
"script": {
|
|
|
|
|
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
|
|
|
|
"params": %s
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"_source":["id"]
|
|
|
|
|
}`, k, string(paramsJSON))
|
|
|
|
|
|
|
|
|
|
hits, err := model_handler.SearchRequest(body, e.IndexName())
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var encounters []Encounter
|
|
|
|
|
for _, hit := range hits {
|
|
|
|
|
hitMap := hit.(map[string]interface{})
|
|
|
|
|
source := hitMap["_source"].(map[string]interface{})
|
|
|
|
|
var encounter Encounter
|
|
|
|
|
if err := data_bind.ShouldBindFormMapToModel(source, &encounter); err != nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
encounters = append(encounters, encounter)
|
|
|
|
|
}
|
|
|
|
|
return encounters, nil
|
|
|
|
|
}
|