finish simple RAG

This commit is contained in:
Havoc412 2024-11-16 02:38:34 +08:00
parent b296677596
commit d330b6b74c
23 changed files with 281 additions and 16 deletions

View File

@ -7,6 +7,7 @@ const (
ErrEncounter ErrEncounter
ErrNlp ErrNlp
ErrKnowledge ErrKnowledge
ErrSubService
) )
const ( const (

View File

@ -14,12 +14,14 @@ func init() {
EnocunterMsgInit(ErrMsg) EnocunterMsgInit(ErrMsg)
NlpMsgInit(ErrMsg) NlpMsgInit(ErrMsg)
KnowledgeMsgInit(ErrMsg) KnowledgeMsgInit(ErrMsg)
SubServiceMsgInit(ErrMsg)
// INGO // INGO
ErrMsgForUser = make(msg) ErrMsgForUser = make(msg)
AnimalMsgUserInit(ErrMsgForUser) AnimalMsgUserInit(ErrMsgForUser)
EncounterMsgUserInit(ErrMsgForUser) EncounterMsgUserInit(ErrMsgForUser)
KnowledgeMsgUserInit(ErrMsgForUser) KnowledgeMsgUserInit(ErrMsgForUser)
NlpMsgUserInit(ErrMsgForUser)
} }
func GeneralMsgInit(m msg) { func GeneralMsgInit(m msg) {

View File

@ -2,8 +2,15 @@ package errcode
const ( const (
ErrNoContent = ErrNlp + iota ErrNoContent = ErrNlp + iota
ErrNoDocFound
) )
func NlpMsgInit(m msg) { func NlpMsgInit(m msg) {
m[ErrNoContent] = "内容为空" m[ErrNoContent] = "内容为空"
m[ErrNoDocFound] = "没有找到相关文档"
}
func NlpMsgUserInit(m msg) {
m[ErrNoContent] = "请输入内容"
m[ErrNoDocFound] = "没有找到相关文档"
} }

View File

@ -0,0 +1,9 @@
package errcode
const (
ErrPythonService = ErrSubService + iota
)
func SubServiceMsgInit(m msg) {
m[ErrPythonService] = "python微服务异常"
}

View File

@ -2,7 +2,7 @@ package web
import ( import (
"catface/app/global/consts" "catface/app/global/consts"
"catface/app/utils/nlp" "catface/app/service/nlp"
"catface/app/utils/response" "catface/app/utils/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"

View File

@ -0,0 +1,49 @@
package web
import (
"catface/app/global/consts"
"catface/app/global/errcode"
"catface/app/global/variable"
"catface/app/model_es"
"catface/app/service/nlp"
"catface/app/utils/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type Rag struct {
}
func (r *Rag) Chat(context *gin.Context) {
// 1. query embedding
query := context.GetString(consts.ValidatorPrefix + "query")
embedding, ok := nlp.GetEmbedding(query)
if !ok {
code := errcode.ErrPythonService
response.Fail(context, code, errcode.ErrMsg[code], "")
return
}
// 2. ES TopK
docs, err := model_es.CreateDocESFactory().TopK(embedding, 1)
if err != nil || len(docs) == 0 {
variable.ZapLog.Error("ES TopK error", zap.Error(err))
code := errcode.ErrNoDocFound
response.Fail(context, code, errcode.ErrMsg[code], errcode.ErrMsgForUser[code])
}
// 3. LLM answer
if answer, err := nlp.ChatKnoledgeRAG(docs[0].Content, query); err == nil {
response.Success(context, consts.CurdStatusOkMsg, gin.H{
"answer": answer,
})
} else {
response.Fail(context, consts.CurdStatusOkCode, consts.CurdStatusOkMsg, "")
}
}
func (r *Rag) HelpDetectCat(context *gin.Context) {
// TODO
}

View File

@ -12,6 +12,7 @@ import (
"catface/app/http/validator/web/encounter_like" "catface/app/http/validator/web/encounter_like"
"catface/app/http/validator/web/knowledge" "catface/app/http/validator/web/knowledge"
"catface/app/http/validator/web/nlp" "catface/app/http/validator/web/nlp"
"catface/app/http/validator/web/rag"
"catface/app/http/validator/web/search" "catface/app/http/validator/web/search"
"catface/app/http/validator/web/users" "catface/app/http/validator/web/users"
) )
@ -85,6 +86,10 @@ func WebRegisterValidator() {
key = consts.ValidatorPrefix + "NlpTitle" key = consts.ValidatorPrefix + "NlpTitle"
containers.Set(key, nlp.Title{}) containers.Set(key, nlp.Title{})
// TAG RAG
key = consts.ValidatorPrefix + "RagDefaultChat"
containers.Set(key, rag.Chat{})
// TAG Search // TAG Search
key = consts.ValidatorPrefix + "SearchAll" key = consts.ValidatorPrefix + "SearchAll"
containers.Set(key, search.SearchAll{}) containers.Set(key, search.SearchAll{})

View File

@ -0,0 +1,31 @@
package rag
import (
"catface/app/global/consts"
"catface/app/http/controller/web"
"catface/app/http/validator/core/data_transfer"
"catface/app/utils/response"
"github.com/gin-gonic/gin"
)
// INFO 虽然起名为 Chat但是默认就会去查询 知识库,也就是不作为一般的 LLM-chat 来使用。
type Chat struct {
Query string `form:"query" json:"query" binding:"required"`
// TODO 这里还需要处理一下历史记录?
}
func (c Chat) CheckParams(context *gin.Context) {
if err := context.ShouldBind(&c); err != nil {
response.ValidatorError(context, err)
return
}
extraAddBindDataContext := data_transfer.DataAddContext(c, consts.ValidatorPrefix, context)
if extraAddBindDataContext == nil {
response.ErrorSystem(context, "RAG CHAT 表单验证器json化失败", "")
} else {
(&web.Rag{}).Chat(extraAddBindDataContext)
}
}

View File

@ -1,5 +1,12 @@
package model_es package model_es
import (
"catface/app/utils/data_bind"
"catface/app/utils/model_handler"
"encoding/json"
"fmt"
)
// INFO @brief 这个文件就是处理 ES 存储文档特征向量的集中处理; // INFO @brief 这个文件就是处理 ES 存储文档特征向量的集中处理;
func CreateDocESFactory() *Doc { func CreateDocESFactory() *Doc {
@ -16,6 +23,51 @@ func (d *Doc) IndexName() string {
return "catface_docs" return "catface_docs"
} }
func (d *Doc) TopK(embedding []float64, k int) ([]Doc, error) {
// 将 embedding 数组转换为 JSON 格式
params := map[string]interface{}{
"query_vector": embedding,
}
paramsJSON, err := json.Marshal(params)
if err != nil {
return nil, err
}
// 构建请求体
body := fmt.Sprintf(`{
"size": %d,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": %s
}
}
},
"_source": ["content"]
}`, k, string(paramsJSON))
hits, err := model_handler.SearchRequest(body, d.IndexName())
if err != nil {
return nil, err
}
var docs []Doc
for _, hit := range hits {
hitMap := hit.(map[string]interface{})
source := hitMap["_source"].(map[string]interface{})
var doc Doc
if err := data_bind.ShouldBindFormMapToModel(source, &doc); err != nil {
continue
}
docs = append(docs, doc)
}
return docs, nil
}
// UPDATE 因为 chunck 还是 Python 来处理会比较方便,所以 Go 这边主要还是处理查询相关的操作。 // UPDATE 因为 chunck 还是 Python 来处理会比较方便,所以 Go 这边主要还是处理查询相关的操作。
// func (d *Doc) InsertDocument() error { // func (d *Doc) InsertDocument() error {
// return nil // return nil

View File

@ -0,0 +1,34 @@
package nlp
import (
"catface/app/global/variable"
"catface/app/utils/micro_service"
"context"
"github.com/carlmjohnson/requests"
)
type EmbeddingRes struct {
Status int `json:"status"`
Message string `json:"message"`
Embedding []float64 `json:"embedding"`
}
func GetEmbedding(text string) ([]float64, bool) {
body := map[string]interface{}{
"text": text,
}
var res EmbeddingRes
err := requests.URL(micro_service.FetchPythonServiceUrl("rag/bge_embedding")).
BodyJSON(&body).
ToJSON(&res).
Fetch(context.Background())
if err != nil {
variable.ZapLog.Error("获取嵌入向量失败: " + err.Error())
}
if res.Status != 200 {
return nil, false
} else {
return res.Embedding, true
}
}

32
app/service/nlp/func.go Normal file
View File

@ -0,0 +1,32 @@
package nlp
import (
"catface/app/global/variable"
"catface/app/service/nlp/glm"
"fmt"
"strings"
)
func GenerateTitle(content string) string {
message := variable.PromptsYml.GetString("Prompt.Title") + content
title, _ := glm.Chat(message)
return title
}
// ChatKnoledgeRAG 使用 RAG 模型进行知识问答
func ChatKnoledgeRAG(doc, query string) (string, error) {
// 读取配置文件中的 KnoledgeRAG 模板
promptTemplate := variable.PromptsYml.GetString("Prompt.KnoledgeRAG")
// 替换模板中的占位符
message := strings.Replace(promptTemplate, "{question}", query, -1)
message = strings.Replace(message, "{context}", doc, -1)
// 调用聊天接口
response, err := glm.Chat(message)
if err != nil {
return "", fmt.Errorf("调用聊天接口失败: %w", err)
}
return response, nil
}

View File

@ -0,0 +1,20 @@
package micro_service
import (
"catface/app/global/variable"
"fmt"
"strings"
)
func FetchPythonServiceUrl(url string) string {
// 检查 url 是否以 / 开头,如果是则去掉开头的 /
if strings.HasPrefix(url, "/") {
url = url[1:]
}
return fmt.Sprintf(`http://%s:%v/%s/%s`,
variable.ConfigYml.GetString("PythonService.Host"),
variable.ConfigYml.GetString("PythonService.Port"),
variable.ConfigYml.GetString("PythonService.TopUrl"),
url)
}

View File

@ -1,12 +0,0 @@
package nlp
import (
"catface/app/global/variable"
"catface/app/utils/nlp/glm"
)
func GenerateTitle(content string) string {
message := variable.PromptsYml.GetString("Prompt.Title") + content
title, _ := glm.Chat(message)
return title
}

View File

@ -163,7 +163,6 @@ Glm:
ApiKey: "0cf510ebc01599dba2a593069c1bdfbc.nQBQ4skP8xBh7ijU" ApiKey: "0cf510ebc01599dba2a593069c1bdfbc.nQBQ4skP8xBh7ijU"
DefaultModel: "glm-4-flash" DefaultModel: "glm-4-flash"
# qiNiu 云存储配置 # qiNiu 云存储配置
QiNiu: QiNiu:
AccessKey: "bI1MpHUBA9OCg4uSJkuJRmScuCJfOlbePe8fCENo" AccessKey: "bI1MpHUBA9OCg4uSJkuJRmScuCJfOlbePe8fCENo"
@ -173,4 +172,9 @@ ElasticSearch:
Addr: "http://localhost:9200" Addr: "http://localhost:9200"
UserName: "elastic" UserName: "elastic"
Password: "" Password: ""
PythonService:
Host: "localhost"
Port: 8000
TopUrl: "api"
# HttpOkCode: 200 # 这个就还是直接硬编码了...

View File

@ -1,2 +1,9 @@
Prompt: Prompt:
Title: "请根据以下长文本生成一个合适的标题不需要书名号长度10字内" Title: "请根据以下长文本生成一个合适的标题不需要书名号长度10字内"
KnoledgeRAG: "使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答知识库中没有这个内容,你不知道。"

1
go.mod
View File

@ -57,6 +57,7 @@ require (
github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
github.com/bytedance/sonic v1.12.3 // indirect github.com/bytedance/sonic v1.12.3 // indirect
github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect
github.com/carlmjohnson/requests v0.24.2
github.com/casbin/casbin/v2 v2.100.0 // indirects github.com/casbin/casbin/v2 v2.100.0 // indirects
github.com/casbin/gorm-adapter/v3 v3.28.0 github.com/casbin/gorm-adapter/v3 v3.28.0
github.com/casbin/govaluate v1.2.0 // indirect github.com/casbin/govaluate v1.2.0 // indirect

2
go.sum
View File

@ -30,6 +30,8 @@ github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKz
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/carlmjohnson/requests v0.24.2 h1:JDakhAmTIKL/qL/1P7Kkc2INGBJIkIFP6xUeUmPzLso=
github.com/carlmjohnson/requests v0.24.2/go.mod h1:duYA/jDnyZ6f3xbcF5PpZ9N8clgopubP2nK5i6MVMhU=
github.com/casbin/casbin/v2 v2.100.0 h1:aeugSNjjHfCrgA22nHkVvw2xsscboHv5r0a13ljQKGQ= github.com/casbin/casbin/v2 v2.100.0 h1:aeugSNjjHfCrgA22nHkVvw2xsscboHv5r0a13ljQKGQ=
github.com/casbin/casbin/v2 v2.100.0/go.mod h1:LO7YPez4dX3LgoTCqSQAleQDo0S0BeZBDxYnPUl95Ng= github.com/casbin/casbin/v2 v2.100.0/go.mod h1:LO7YPez4dX3LgoTCqSQAleQDo0S0BeZBDxYnPUl95Ng=
github.com/casbin/gorm-adapter/v3 v3.28.0 h1:ORF8prF6SfaipdgT1fud+r1Tp5J0uul8QaKJHqCPY/o= github.com/casbin/gorm-adapter/v3 v3.28.0 h1:ORF8prF6SfaipdgT1fud+r1Tp5J0uul8QaKJHqCPY/o=

View File

@ -150,6 +150,11 @@ func InitWebRouter() *gin.Engine {
nlp.POST("title", validatorFactory.Create(consts.ValidatorPrefix+"NlpTitle")) nlp.POST("title", validatorFactory.Create(consts.ValidatorPrefix+"NlpTitle"))
} }
rag := backend.Group("rag")
{
rag.POST("default_talk", validatorFactory.Create(consts.ValidatorPrefix+"RagDefaultChat"))
}
search := backend.Group("search") search := backend.Group("search")
{ {
search.GET("", validatorFactory.Create(consts.ValidatorPrefix+"SearchAll")) search.GET("", validatorFactory.Create(consts.ValidatorPrefix+"SearchAll"))

View File

@ -0,0 +1,16 @@
package nlp
import (
"catface/app/service/nlp"
_ "catface/bootstrap"
"fmt"
"testing"
)
func TestEmbeddingApi(t *testing.T) {
res, ok := nlp.GetEmbedding("一段测试文本。")
if !ok {
t.Error("获取嵌入向量失败")
}
fmt.Println(res)
}