✨ finish simple RAG
This commit is contained in:
parent
b296677596
commit
d330b6b74c
@ -7,6 +7,7 @@ const (
|
||||
ErrEncounter
|
||||
ErrNlp
|
||||
ErrKnowledge
|
||||
ErrSubService
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -14,12 +14,14 @@ func init() {
|
||||
EnocunterMsgInit(ErrMsg)
|
||||
NlpMsgInit(ErrMsg)
|
||||
KnowledgeMsgInit(ErrMsg)
|
||||
SubServiceMsgInit(ErrMsg)
|
||||
|
||||
// INGO
|
||||
ErrMsgForUser = make(msg)
|
||||
AnimalMsgUserInit(ErrMsgForUser)
|
||||
EncounterMsgUserInit(ErrMsgForUser)
|
||||
KnowledgeMsgUserInit(ErrMsgForUser)
|
||||
NlpMsgUserInit(ErrMsgForUser)
|
||||
}
|
||||
|
||||
func GeneralMsgInit(m msg) {
|
||||
|
@ -2,8 +2,15 @@ package errcode
|
||||
|
||||
const (
|
||||
ErrNoContent = ErrNlp + iota
|
||||
ErrNoDocFound
|
||||
)
|
||||
|
||||
func NlpMsgInit(m msg) {
|
||||
m[ErrNoContent] = "内容为空"
|
||||
m[ErrNoDocFound] = "没有找到相关文档"
|
||||
}
|
||||
|
||||
func NlpMsgUserInit(m msg) {
|
||||
m[ErrNoContent] = "请输入内容"
|
||||
m[ErrNoDocFound] = "没有找到相关文档"
|
||||
}
|
||||
|
9
app/global/errcode/subService.go
Normal file
9
app/global/errcode/subService.go
Normal file
@ -0,0 +1,9 @@
|
||||
package errcode
|
||||
|
||||
const (
|
||||
ErrPythonService = ErrSubService + iota
|
||||
)
|
||||
|
||||
func SubServiceMsgInit(m msg) {
|
||||
m[ErrPythonService] = "python微服务异常"
|
||||
}
|
@ -2,7 +2,7 @@ package web
|
||||
|
||||
import (
|
||||
"catface/app/global/consts"
|
||||
"catface/app/utils/nlp"
|
||||
"catface/app/service/nlp"
|
||||
"catface/app/utils/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
49
app/http/controller/web/rag_controller.go
Normal file
49
app/http/controller/web/rag_controller.go
Normal file
@ -0,0 +1,49 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"catface/app/global/consts"
|
||||
"catface/app/global/errcode"
|
||||
"catface/app/global/variable"
|
||||
"catface/app/model_es"
|
||||
"catface/app/service/nlp"
|
||||
"catface/app/utils/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Rag struct {
|
||||
}
|
||||
|
||||
func (r *Rag) Chat(context *gin.Context) {
|
||||
// 1. query embedding
|
||||
query := context.GetString(consts.ValidatorPrefix + "query")
|
||||
embedding, ok := nlp.GetEmbedding(query)
|
||||
if !ok {
|
||||
code := errcode.ErrPythonService
|
||||
response.Fail(context, code, errcode.ErrMsg[code], "")
|
||||
return
|
||||
}
|
||||
|
||||
// 2. ES TopK
|
||||
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
|
||||
if err != nil || len(docs) == 0 {
|
||||
variable.ZapLog.Error("ES TopK error", zap.Error(err))
|
||||
|
||||
code := errcode.ErrNoDocFound
|
||||
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
|
||||
}
|
||||
|
||||
// 3. LLM answer
|
||||
if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
|
||||
response.Success(context, consts.CurdStatusOkMsg, gin.H{
|
||||
"answer": answer,
|
||||
})
|
||||
} else {
|
||||
response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Rag) HelpDetectCat(context *gin.Context) {
|
||||
// TODO
|
||||
}
|
@ -12,6 +12,7 @@ import (
|
||||
"catface/app/http/validator/web/encounter_like"
|
||||
"catface/app/http/validator/web/knowledge"
|
||||
"catface/app/http/validator/web/nlp"
|
||||
"catface/app/http/validator/web/rag"
|
||||
"catface/app/http/validator/web/search"
|
||||
"catface/app/http/validator/web/users"
|
||||
)
|
||||
@ -85,6 +86,10 @@ func WebRegisterValidator() {
|
||||
key = consts.ValidatorPrefix + "NlpTitle"
|
||||
containers.Set(key, nlp.Title{})
|
||||
|
||||
// TAG RAG
|
||||
key = consts.ValidatorPrefix + "RagDefaultChat"
|
||||
containers.Set(key, rag.Chat{})
|
||||
|
||||
// TAG Search
|
||||
key = consts.ValidatorPrefix + "SearchAll"
|
||||
containers.Set(key, search.SearchAll{})
|
||||
|
31
app/http/validator/web/rag/chat.go
Normal file
31
app/http/validator/web/rag/chat.go
Normal file
@ -0,0 +1,31 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"catface/app/global/consts"
|
||||
"catface/app/http/controller/web"
|
||||
"catface/app/http/validator/core/data_transfer"
|
||||
"catface/app/utils/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// INFO 虽然起名为 Chat,但是默认就会去查询 知识库,也就是不作为一般的 LLM-chat 来使用。
|
||||
type Chat struct {
|
||||
Query string `form:"query" json:"query" binding:"required"`
|
||||
// TODO 这里还需要处理一下历史记录?
|
||||
}
|
||||
|
||||
func (c Chat) CheckParams(context *gin.Context) {
|
||||
if err := context.ShouldBind(&c); err != nil {
|
||||
response.ValidatorError(context, err)
|
||||
return
|
||||
}
|
||||
|
||||
extraAddBindDataContext := data_transfer.DataAddContext(c, consts.ValidatorPrefix, context)
|
||||
if extraAddBindDataContext == nil {
|
||||
response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
|
||||
} else {
|
||||
(&web.Rag{}).Chat(extraAddBindDataContext)
|
||||
}
|
||||
|
||||
}
|
@ -1,5 +1,12 @@
|
||||
package model_es
|
||||
|
||||
import (
|
||||
"catface/app/utils/data_bind"
|
||||
"catface/app/utils/model_handler"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// INFO @brief 这个文件就是处理 ES 存储文档特征向量的集中处理;
|
||||
|
||||
func CreateDocESFactory() *Doc {
|
||||
@ -16,6 +23,51 @@ func (d *Doc) IndexName() string {
|
||||
return "catface_docs"
|
||||
}
|
||||
|
||||
func (d *Doc) TopK(embedding []float64, k int) ([]Doc, error) {
|
||||
// 将 embedding 数组转换为 JSON 格式
|
||||
params := map[string]interface{}{
|
||||
"query_vector": embedding,
|
||||
}
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建请求体
|
||||
body := fmt.Sprintf(`{
|
||||
"size": %d,
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
||||
"params": %s
|
||||
}
|
||||
}
|
||||
},
|
||||
"_source": ["content"]
|
||||
}`, k, string(paramsJSON))
|
||||
|
||||
hits, err := model_handler.SearchRequest(body, d.IndexName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var docs []Doc
|
||||
for _, hit := range hits {
|
||||
hitMap := hit.(map[string]interface{})
|
||||
source := hitMap["_source"].(map[string]interface{})
|
||||
var doc Doc
|
||||
if err := data_bind.ShouldBindFormMapToModel(source, &doc); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// UPDATE 因为 chunck 还是 Python 来处理会比较方便,所以 Go 这边主要还是处理查询相关的操作。
|
||||
// func (d *Doc) InsertDocument() error {
|
||||
// return nil
|
||||
|
34
app/service/nlp/embedding.go
Normal file
34
app/service/nlp/embedding.go
Normal file
@ -0,0 +1,34 @@
|
||||
package nlp
|
||||
|
||||
import (
|
||||
"catface/app/global/variable"
|
||||
"catface/app/utils/micro_service"
|
||||
"context"
|
||||
|
||||
"github.com/carlmjohnson/requests"
|
||||
)
|
||||
|
||||
type EmbeddingRes struct {
|
||||
Status int `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func GetEmbedding(text string) ([]float64, bool) {
|
||||
body := map[string]interface{}{
|
||||
"text": text,
|
||||
}
|
||||
var res EmbeddingRes
|
||||
err := requests.URL(micro_service.FetchPythonServiceUrl("rag/bge_embedding")).
|
||||
BodyJSON(&body).
|
||||
ToJSON(&res).
|
||||
Fetch(context.Background())
|
||||
if err != nil {
|
||||
variable.ZapLog.Error("获取嵌入向量失败: " + err.Error())
|
||||
}
|
||||
if res.Status != 200 {
|
||||
return nil, false
|
||||
} else {
|
||||
return res.Embedding, true
|
||||
}
|
||||
}
|
32
app/service/nlp/func.go
Normal file
32
app/service/nlp/func.go
Normal file
@ -0,0 +1,32 @@
|
||||
package nlp
|
||||
|
||||
import (
|
||||
"catface/app/global/variable"
|
||||
"catface/app/service/nlp/glm"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GenerateTitle(content string) string {
|
||||
message := variable.PromptsYml.GetString("Prompt.Title") + content
|
||||
title, _ := glm.Chat(message)
|
||||
return title
|
||||
}
|
||||
|
||||
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
|
||||
func ChatKnoledgeRAG(doc, query string) (string, error) {
|
||||
// 读取配置文件中的 KnoledgeRAG 模板
|
||||
promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG")
|
||||
|
||||
// 替换模板中的占位符
|
||||
message := strings.Replace(promptTemplate, "{question}", query, -1)
|
||||
message = strings.Replace(message, "{context}", doc, -1)
|
||||
|
||||
// 调用聊天接口
|
||||
response, err := glm.Chat(message)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("调用聊天接口失败: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
20
app/utils/micro_service/micro_service.go
Normal file
20
app/utils/micro_service/micro_service.go
Normal file
@ -0,0 +1,20 @@
|
||||
package micro_service
|
||||
|
||||
import (
|
||||
"catface/app/global/variable"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func FetchPythonServiceUrl(url string) string {
|
||||
// 检查 url 是否以 / 开头,如果是则去掉开头的 /
|
||||
if strings.HasPrefix(url, "/") {
|
||||
url = url[1:]
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`http://%s:%v/%s/%s`,
|
||||
variable.ConfigYml.GetString("PythonService.Host"),
|
||||
variable.ConfigYml.GetString("PythonService.Port"),
|
||||
variable.ConfigYml.GetString("PythonService.TopUrl"),
|
||||
url)
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
package nlp
|
||||
|
||||
import (
|
||||
"catface/app/global/variable"
|
||||
"catface/app/utils/nlp/glm"
|
||||
)
|
||||
|
||||
func GenerateTitle(content string) string {
|
||||
message := variable.PromptsYml.GetString("Prompt.Title") + content
|
||||
title, _ := glm.Chat(message)
|
||||
return title
|
||||
}
|
@ -163,7 +163,6 @@ Glm:
|
||||
ApiKey: "0cf510ebc01599dba2a593069c1bdfbc.nQBQ4skP8xBh7ijU"
|
||||
DefaultModel: "glm-4-flash"
|
||||
|
||||
|
||||
# qiNiu 云存储配置
|
||||
QiNiu:
|
||||
AccessKey: "bI1MpHUBA9OCg4uSJkuJRmScuCJfOlbePe8fCENo"
|
||||
@ -173,4 +172,9 @@ ElasticSearch:
|
||||
Addr: "http://localhost:9200"
|
||||
UserName: "elastic"
|
||||
Password: ""
|
||||
|
||||
|
||||
PythonService:
|
||||
Host: "localhost"
|
||||
Port: 8000
|
||||
TopUrl: "api"
|
||||
# HttpOkCode: 200 # 这个就还是直接硬编码了...
|
||||
|
@ -1,2 +1,9 @@
|
||||
Prompt:
|
||||
Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:"
|
||||
Title: "请根据以下长文本生成一个合适的标题,不需要书名号,长度10字内:"
|
||||
KnoledgeRAG: "使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
|
||||
问题: {question}
|
||||
可参考的上下文:
|
||||
···
|
||||
{context}
|
||||
···
|
||||
如果给定的上下文无法让你做出回答,请回答知识库中没有这个内容,你不知道。"
|
1
go.mod
1
go.mod
@ -57,6 +57,7 @@ require (
|
||||
github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
|
||||
github.com/bytedance/sonic v1.12.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.0 // indirect
|
||||
github.com/carlmjohnson/requests v0.24.2
|
||||
github.com/casbin/casbin/v2 v2.100.0 // indirects
|
||||
github.com/casbin/gorm-adapter/v3 v3.28.0
|
||||
github.com/casbin/govaluate v1.2.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -30,6 +30,8 @@ github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKz
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
|
||||
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/carlmjohnson/requests v0.24.2 h1:JDakhAmTIKL/qL/1P7Kkc2INGBJIkIFP6xUeUmPzLso=
|
||||
github.com/carlmjohnson/requests v0.24.2/go.mod h1:duYA/jDnyZ6f3xbcF5PpZ9N8clgopubP2nK5i6MVMhU=
|
||||
github.com/casbin/casbin/v2 v2.100.0 h1:aeugSNjjHfCrgA22nHkVvw2xsscboHv5r0a13ljQKGQ=
|
||||
github.com/casbin/casbin/v2 v2.100.0/go.mod h1:LO7YPez4dX3LgoTCqSQAleQDo0S0BeZBDxYnPUl95Ng=
|
||||
github.com/casbin/gorm-adapter/v3 v3.28.0 h1:ORF8prF6SfaipdgT1fud+r1Tp5J0uul8QaKJHqCPY/o=
|
||||
|
@ -150,6 +150,11 @@ func InitWebRouter() *gin.Engine {
|
||||
nlp.POST("title", validatorFactory.Create(consts.ValidatorPrefix+"NlpTitle"))
|
||||
}
|
||||
|
||||
rag := backend.Group("rag")
|
||||
{
|
||||
rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
|
||||
}
|
||||
|
||||
search := backend.Group("search")
|
||||
{
|
||||
search.GET("", validatorFactory.Create(consts.ValidatorPrefix+"SearchAll"))
|
||||
|
16
test/python/embedding_test.go
Normal file
16
test/python/embedding_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package nlp
|
||||
|
||||
import (
|
||||
"catface/app/service/nlp"
|
||||
_ "catface/bootstrap"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEmbeddingApi(t *testing.T) {
|
||||
res, ok := nlp.GetEmbedding("一段测试文本。")
|
||||
if !ok {
|
||||
t.Error("获取嵌入向量失败")
|
||||
}
|
||||
fmt.Println(res)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user