diff --git a/app/global/errcode/code.go b/app/global/errcode/code.go index da314fe..522463f 100644 --- a/app/global/errcode/code.go +++ b/app/global/errcode/code.go @@ -7,6 +7,7 @@ const ( ErrEncounter ErrNlp ErrKnowledge + ErrSubService ) const ( diff --git a/app/global/errcode/msg.go b/app/global/errcode/msg.go index f2d2bdb..be49f11 100644 --- a/app/global/errcode/msg.go +++ b/app/global/errcode/msg.go @@ -14,12 +14,14 @@ func init() { EnocunterMsgInit(ErrMsg) NlpMsgInit(ErrMsg) KnowledgeMsgInit(ErrMsg) + SubServiceMsgInit(ErrMsg) // INGO ErrMsgForUser = make(msg) AnimalMsgUserInit(ErrMsgForUser) EncounterMsgUserInit(ErrMsgForUser) KnowledgeMsgUserInit(ErrMsgForUser) + NlpMsgUserInit(ErrMsgForUser) } func GeneralMsgInit(m msg) { diff --git a/app/global/errcode/nlp.go b/app/global/errcode/nlp.go index 8735779..4005075 100644 --- a/app/global/errcode/nlp.go +++ b/app/global/errcode/nlp.go @@ -2,8 +2,15 @@ package errcode const ( ErrNoContent = ErrNlp + iota + ErrNoDocFound ) func NlpMsgInit(m msg) { m[ErrNoContent] = "内容为空" + m[ErrNoDocFound] = "没有找到相关文档" +} + +func NlpMsgUserInit(m msg) { + m[ErrNoContent] = "请输入内容" + m[ErrNoDocFound] = "没有找到相关文档" } diff --git a/app/global/errcode/subService.go b/app/global/errcode/subService.go new file mode 100644 index 0000000..023b6d8 --- /dev/null +++ b/app/global/errcode/subService.go @@ -0,0 +1,9 @@ +package errcode + +const ( + ErrPythonService = ErrSubService + iota +) + +func SubServiceMsgInit(m msg) { + m[ErrPythonService] = "python微服务异常" +} diff --git a/app/http/controller/web/nlp_controller.go b/app/http/controller/web/nlp_controller.go index 7714520..9d5f90f 100644 --- a/app/http/controller/web/nlp_controller.go +++ b/app/http/controller/web/nlp_controller.go @@ -2,7 +2,7 @@ package web import ( "catface/app/global/consts" - "catface/app/utils/nlp" + "catface/app/service/nlp" "catface/app/utils/response" "github.com/gin-gonic/gin" diff --git a/app/http/controller/web/rag_controller.go b/app/http/controller/web/rag_controller.go new file mode 100644 index 0000000..065e7a8 --- /dev/null +++ b/app/http/controller/web/rag_controller.go @@ -0,0 +1,49 @@ +package web + +import ( + "catface/app/global/consts" + "catface/app/global/errcode" + "catface/app/global/variable" + "catface/app/model_es" + "catface/app/service/nlp" + "catface/app/utils/response" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type Rag struct { +} + +func (r *Rag) Chat(context *gin.Context) { + // 1. query embedding + query := context.GetString(consts.ValidatorPrefix + "query") + embedding, ok := nlp.GetEmbedding(query) + if !ok { + code := errcode.ErrPythonService + response.Fail(context, code, errcode.ErrMsg[code], "") + return + } + + // 2. ES TopK + docs, err := model_es.CreateDocESFactory().TopK(embedding, 1) + if err != nil || len(docs) == 0 { + variable.ZapLog.Error("ES TopK error", zap.Error(err)) + + code := errcode.ErrNoDocFound + response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code]) + } + + // 3. LLM answer + if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil { + response.Success(context, consts.CurdStatusOkMsg, gin.H{ + "answer": answer, + }) + } else { + response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "") + } +} + +func (r *Rag) HelpDetectCat(context *gin.Context) { + // TODO +} diff --git a/app/http/validator/common/register_validator/web_register_validator.go b/app/http/validator/common/register_validator/web_register_validator.go index cebad60..ba8a828 100644 --- a/app/http/validator/common/register_validator/web_register_validator.go +++ b/app/http/validator/common/register_validator/web_register_validator.go @@ -12,6 +12,7 @@ import ( "catface/app/http/validator/web/encounter_like" "catface/app/http/validator/web/knowledge" "catface/app/http/validator/web/nlp" + "catface/app/http/validator/web/rag" "catface/app/http/validator/web/search" "catface/app/http/validator/web/users" ) @@ -85,6 +86,10 @@ func WebRegisterValidator() { key = consts.ValidatorPrefix + "NlpTitle" containers.Set(key, nlp.Title{}) + // TAG RAG + key = consts.ValidatorPrefix + "RagDefaultChat" + containers.Set(key, rag.Chat{}) + // TAG Search key = consts.ValidatorPrefix + "SearchAll" containers.Set(key, search.SearchAll{}) diff --git a/app/http/validator/web/rag/chat.go b/app/http/validator/web/rag/chat.go new file mode 100644 index 0000000..2b9adbd --- /dev/null +++ b/app/http/validator/web/rag/chat.go @@ -0,0 +1,31 @@ +package rag + +import ( + "catface/app/global/consts" + "catface/app/http/controller/web" + "catface/app/http/validator/core/data_transfer" + "catface/app/utils/response" + + "github.com/gin-gonic/gin" +) + +// INFO 虽然起名为 Chat,但是默认就会去查询 知识库,也就是不作为一般的 LLM-chat 来使用。 +type Chat struct { + Query string `form:"query" json:"query" binding:"required"` + // TODO 这里还需要处理一下历史记录? +} + +func (c Chat) CheckParams(context *gin.Context) { + if err := context.ShouldBind(&c); err != nil { + response.ValidatorError(context, err) + return + } + + extraAddBindDataContext := data_transfer.DataAddContext(c, consts.ValidatorPrefix, context) + if extraAddBindDataContext == nil { + response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "") + } else { + (&web.Rag{}).Chat(extraAddBindDataContext) + } + +} diff --git a/app/model_es/doc.go b/app/model_es/doc.go index 058d92e..4f92693 100644 --- a/app/model_es/doc.go +++ b/app/model_es/doc.go @@ -1,5 +1,12 @@ package model_es +import ( + "catface/app/utils/data_bind" + "catface/app/utils/model_handler" + "encoding/json" + "fmt" +) + // INFO @brief 这个文件就是处理 ES 存储文档特征向量的集中处理; func CreateDocESFactory() *Doc { @@ -16,6 +23,51 @@ func (d *Doc) IndexName() string { return "catface_docs" } +func (d *Doc) TopK(embedding []float64, k int) ([]Doc, error) { + // 将 embedding 数组转换为 JSON 格式 + params := map[string]interface{}{ + "query_vector": embedding, + } + paramsJSON, err := json.Marshal(params) + if err != nil { + return nil, err + } + + // 构建请求体 + body := fmt.Sprintf(`{ + "size": %d, + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", + "params": %s + } + } + }, + "_source": ["content"] + }`, k, string(paramsJSON)) + + hits, err := model_handler.SearchRequest(body, d.IndexName()) + if err != nil { + return nil, err + } + + var docs []Doc + for _, hit := range hits { + hitMap := hit.(map[string]interface{}) + source := hitMap["_source"].(map[string]interface{}) + var doc Doc + if err := data_bind.ShouldBindFormMapToModel(source, &doc); err != nil { + continue + } + + docs = append(docs, doc) + } + + return docs, nil +} + // UPDATE 因为 chunck 还是 Python 来处理会比较方便,所以 Go 这边主要还是处理查询相关的操作。 // func (d *Doc) InsertDocument() error { // return nil diff --git a/app/service/nlp/embedding.go b/app/service/nlp/embedding.go new file mode 100644 index 0000000..5922243 --- /dev/null +++ b/app/service/nlp/embedding.go @@ -0,0 +1,34 @@ +package nlp + +import ( + "catface/app/global/variable" + "catface/app/utils/micro_service" + "context" + + "github.com/carlmjohnson/requests" +) + +type EmbeddingRes struct { + Status int `json:"status"` + Message string `json:"message"` + Embedding []float64 `json:"embedding"` +} + +func GetEmbedding(text string) ([]float64, bool) { + body := map[string]interface{}{ + "text": text, + } + var res EmbeddingRes + err := requests.URL(micro_service.FetchPythonServiceUrl("rag/bge_embedding")). + BodyJSON(&body). + ToJSON(&res). + Fetch(context.Background()) + if err != nil { + variable.ZapLog.Error("获取嵌入向量失败: " + err.Error()) + } + if res.Status != 200 { + return nil, false + } else { + return res.Embedding, true + } +} diff --git a/app/service/nlp/func.go b/app/service/nlp/func.go new file mode 100644 index 0000000..6be43c1 --- /dev/null +++ b/app/service/nlp/func.go @@ -0,0 +1,32 @@ +package nlp + +import ( + "catface/app/global/variable" + "catface/app/service/nlp/glm" + "fmt" + "strings" +) + +func GenerateTitle(content string) string { + message := variable.PromptsYml.GetString("Prompt.Title") + content + title, _ := glm.Chat(message) + return title +} + +// ChatKnoledgeRAG 使用 RAG 模型进行知识问答 +func ChatKnoledgeRAG(doc, query string) (string, error) { + // 读取配置文件中的 KnoledgeRAG 模板 + promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG") + + // 替换模板中的占位符 + message := strings.Replace(promptTemplate, "{question}", query, -1) + message = strings.Replace(message, "{context}", doc, -1) + + // 调用聊天接口 + response, err := glm.Chat(message) + if err != nil { + return "", fmt.Errorf("调用聊天接口失败: %w", err) + } + + return response, nil +} diff --git a/app/utils/nlp/glm/glm.go b/app/service/nlp/glm/glm.go similarity index 100% rename from app/utils/nlp/glm/glm.go rename to app/service/nlp/glm/glm.go diff --git a/app/utils/oss/constant.go b/app/service/oss/constant.go similarity index 100% rename from app/utils/oss/constant.go rename to app/service/oss/constant.go diff --git a/app/utils/oss/qi_niu.go b/app/service/oss/qi_niu.go similarity index 100% rename from app/utils/oss/qi_niu.go rename to app/service/oss/qi_niu.go diff --git a/app/utils/oss/qi_niu_test.go b/app/service/oss/qi_niu_test.go similarity index 100% rename from app/utils/oss/qi_niu_test.go rename to app/service/oss/qi_niu_test.go diff --git a/app/utils/micro_service/micro_service.go b/app/utils/micro_service/micro_service.go new file mode 100644 index 0000000..187c176 --- /dev/null +++ b/app/utils/micro_service/micro_service.go @@ -0,0 +1,20 @@ +package micro_service + +import ( + "catface/app/global/variable" + "fmt" + "strings" +) + +func FetchPythonServiceUrl(url string) string { + // 检查 url 是否以 / 开头,如果是则去掉开头的 / + if strings.HasPrefix(url, "/") { + url = url[1:] + } + + return fmt.Sprintf(`http://%s:%v/%s/%s`, + variable.ConfigYml.GetString("PythonService.Host"), + variable.ConfigYml.GetString("PythonService.Port"), + variable.ConfigYml.GetString("PythonService.TopUrl"), + url) +} diff --git a/app/utils/nlp/func.go b/app/utils/nlp/func.go deleted file mode 100644 index 753e846..0000000 --- a/app/utils/nlp/func.go +++ /dev/null @@ -1,12 +0,0 @@ -package nlp - -import ( - "catface/app/global/variable" - "catface/app/utils/nlp/glm" -) - -func GenerateTitle(content string) string { - message := variable.PromptsYml.GetString("Prompt.Title") + content - title, _ := glm.Chat(message) - return title -} diff --git a/config/config.yml b/config/config.yml index 5bfacdd..cdc6d57 100644 --- a/config/config.yml +++ b/config/config.yml @@ -163,7 +163,6 @@ Glm: ApiKey: "0cf510ebc01599dba2a593069c1bdfbc.nQBQ4skP8xBh7ijU" DefaultModel: "glm-4-flash" - # qiNiu 云存储配置 QiNiu: AccessKey: "bI1MpHUBA9OCg4uSJkuJRmScuCJfOlbePe8fCENo" @@ -173,4 +172,9 @@ ElasticSearch: Addr: "http://localhost:9200" UserName: "elastic" Password: "" - + +PythonService: + Host: "localhost" + Port: 8000 + TopUrl: "api" + # HttpOkCode: 200 # 这个就还是直接硬编码了... diff --git a/config/prompts.yml b/config/prompts.yml index 4c16188..94ca367 100644 --- a/config/prompts.yml +++ b/config/prompts.yml @@ -1,2 +1,9 @@ Prompt: - Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:" \ No newline at end of file + Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:" + KnoledgeRAG: "使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 + 问题: {question} + 可参考的上下文: + ··· + {context} + ··· + 如果给定的上下文无法让你做出回答,请回答知识库中没有这个内容,你不知道。" \ No newline at end of file diff --git a/go.mod b/go.mod index df594d3..554df46 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect github.com/bytedance/sonic v1.12.3 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect + github.com/carlmjohnson/requests v0.24.2 github.com/casbin/casbin/v2 v2.100.0 // indirects github.com/casbin/gorm-adapter/v3 v3.28.0 github.com/casbin/govaluate v1.2.0 // indirect diff --git a/go.sum b/go.sum index 1e3ecae..e99efea 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKz github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/carlmjohnson/requests v0.24.2 h1:JDakhAmTIKL/qL/1P7Kkc2INGBJIkIFP6xUeUmPzLso= +github.com/carlmjohnson/requests v0.24.2/go.mod h1:duYA/jDnyZ6f3xbcF5PpZ9N8clgopubP2nK5i6MVMhU= github.com/casbin/casbin/v2 v2.100.0 h1:aeugSNjjHfCrgA22nHkVvw2xsscboHv5r0a13ljQKGQ= github.com/casbin/casbin/v2 v2.100.0/go.mod h1:LO7YPez4dX3LgoTCqSQAleQDo0S0BeZBDxYnPUl95Ng= github.com/casbin/gorm-adapter/v3 v3.28.0 h1:ORF8prF6SfaipdgT1fud+r1Tp5J0uul8QaKJHqCPY/o= diff --git a/routers/web.go b/routers/web.go index b57195a..2cff193 100644 --- a/routers/web.go +++ b/routers/web.go @@ -150,6 +150,11 @@ func InitWebRouter() *gin.Engine { nlp.POST("title", validatorFactory.Create(consts.ValidatorPrefix+"NlpTitle")) } + rag := backend.Group("rag") + { + rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat")) + } + search := backend.Group("search") { search.GET("", validatorFactory.Create(consts.ValidatorPrefix+"SearchAll")) diff --git a/test/python/embedding_test.go b/test/python/embedding_test.go new file mode 100644 index 0000000..6303d99 --- /dev/null +++ b/test/python/embedding_test.go @@ -0,0 +1,16 @@ +package nlp + +import ( + "catface/app/service/nlp" + _ "catface/bootstrap" + "fmt" + "testing" +) + +func TestEmbeddingApi(t *testing.T) { + res, ok := nlp.GetEmbedding("一段测试文本。") + if !ok { + t.Error("获取嵌入向量失败") + } + fmt.Println(res) +}