249 lines
6.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
"time"
. "catface/AutoMigrateMySQL/config"
"gorm.io/datatypes"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// camelBoundary matches a lowercase letter or digit immediately followed by an
// uppercase letter: the positions where convertToSnakeCase inserts "_".
// It is compiled once at package initialization instead of on every call.
var camelBoundary = regexp.MustCompile("([a-z0-9])([A-Z])")

// convertToSnakeCase derives a table name from a Go struct name, meant to
// imitate the naming AutoMigrate would use: an underscore is inserted at each
// lower/digit->upper boundary and the result is lowercased,
// e.g. "UserInfo" -> "user_info", "userID" -> "user_id".
//
// NOTE(review): runs of capitals are not split ("HTTPServer" -> "httpserver");
// GORM's default NamingStrategy presumably yields "http_server" — verify.
// Changing this would rename existing tables, so it is left as is.
func convertToSnakeCase(name string) string {
	return strings.ToLower(camelBoundary.ReplaceAllString(name, "${1}_${2}"))
}
// basicFieldTypes maps the predeclared Go type names accepted in a table
// definition to the reflect.Type used for the synthesized struct field.
// byte and rune are aliases, so they resolve to uint8 and int32.
var basicFieldTypes = map[string]reflect.Type{
	"string":  reflect.TypeOf(""),
	"bool":    reflect.TypeOf(true),
	"int":     reflect.TypeOf(0),
	"int8":    reflect.TypeOf(int8(0)),
	"int16":   reflect.TypeOf(int16(0)),
	"int32":   reflect.TypeOf(int32(0)),
	"int64":   reflect.TypeOf(int64(0)),
	"uint":    reflect.TypeOf(uint(0)),
	"uint8":   reflect.TypeOf(uint8(0)),
	"uint16":  reflect.TypeOf(uint16(0)),
	"uint32":  reflect.TypeOf(uint32(0)),
	"uint64":  reflect.TypeOf(uint64(0)),
	"float32": reflect.TypeOf(float32(0)),
	"float64": reflect.TypeOf(float64(0)),
	"byte":    reflect.TypeOf(byte(0)),
	"rune":    reflect.TypeOf(rune(0)),
}

// getFieldType resolves the AST type expression of a struct field to a
// reflect.Type. Supported forms:
//   - the predeclared names listed in basicFieldTypes;
//   - []T (and [N]T, which is also mapped to []T: the length is not
//     inspected) for a supported element type T;
//   - time.Time and datatypes.JSON, matched by the package identifier as
//     written in the source, so an aliased import is not recognized;
//   - *T for a supported T.
//
// It returns nil for anything else; callers treat nil as "skip this field".
func getFieldType(expr ast.Expr) reflect.Type {
	switch t := expr.(type) {
	case *ast.Ident:
		return basicFieldTypes[t.Name] // nil for names not in the table
	case *ast.ArrayType:
		if elem := getFieldType(t.Elt); elem != nil {
			return reflect.SliceOf(elem)
		}
	case *ast.SelectorExpr:
		pkg, ok := t.X.(*ast.Ident)
		if !ok {
			break
		}
		switch {
		case pkg.Name == "time" && t.Sel.Name == "Time":
			return reflect.TypeOf(time.Time{})
		case pkg.Name == "datatypes" && t.Sel.Name == "JSON":
			return reflect.TypeOf(datatypes.JSON{})
		}
	case *ast.StarExpr:
		// reflect.PointerTo panics when given a nil Type, so a pointer to an
		// unsupported type must be reported as unsupported, not wrapped.
		if elem := getFieldType(t.X); elem != nil {
			return reflect.PointerTo(elem)
		}
	}
	return nil
}
// 用于保存结构体信息的map
var structs = make(map[string]reflect.Type)
// getStruct walks the top-level declarations of a parsed file. For every
// struct type declaration with at least one field, it synthesizes an
// equivalent type with reflect.StructOf and records it in structs under the
// declared name.
//
// Field handling:
//   - named fields are kept when getFieldType supports their type and the
//     name is exported (reflect.StructOf panics on unexported names);
//   - embedded fields naming a struct already in structs are flattened into
//     the new type;
//   - anything else is logged and skipped.
//
// A struct left with no usable fields is logged and not recorded.
func getStruct(fDecls []ast.Decl) {
	for _, decl := range fDecls {
		// Only `type ...` declarations can introduce structs.
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		// One declaration may group several specs: type (A struct{}; B struct{}).
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok || len(structType.Fields.List) == 0 {
				continue // not a struct, or an empty one: nothing to migrate
			}
			structName := typeSpec.Name.Name
			fields := structFields(structName, structType)
			if len(fields) == 0 {
				log.Printf("Struct %s has no supported fields, skipping\n", structName)
				continue
			}
			structs[structName] = reflect.StructOf(fields)
			fmt.Println("get struct: ", structName)
		}
	}
}

// structFields converts the fields of one AST struct into reflect.StructFields.
func structFields(structName string, structType *ast.StructType) []reflect.StructField {
	fields := make([]reflect.StructField, 0, len(structType.Fields.List))
	for _, field := range structType.Fields.List {
		if len(field.Names) == 0 {
			fields = append(fields, embeddedFields(field)...)
			continue
		}
		// All names in `A, B T "tag"` share one type and one tag.
		fieldType := getFieldType(field.Type)
		tag := fieldTag(field)
		for _, ident := range field.Names {
			switch {
			case fieldType == nil:
				log.Printf("Unsupported type for field %s.%s, skipping\n", structName, ident.Name)
			case !ast.IsExported(ident.Name):
				// reflect.StructOf panics on unexported (or blank) field names.
				log.Printf("Unexported field %s.%s, skipping\n", structName, ident.Name)
			default:
				fields = append(fields, reflect.StructField{
					Name: ident.Name,
					Type: fieldType,
					Tag:  tag,
				})
			}
		}
	}
	return fields
}

// embeddedFields flattens an anonymous field that names a struct already
// recorded in structs. Other embedded forms are logged and yield no fields.
func embeddedFields(field *ast.Field) []reflect.StructField {
	ident, ok := field.Type.(*ast.Ident)
	if !ok {
		log.Printf("Unsupported embedded type for field %v\n", field.Type)
		return nil
	}
	embedType, ok := structs[ident.Name]
	if !ok {
		log.Printf("Embedded type %s not found\n", ident.Name)
		return nil
	}
	out := make([]reflect.StructField, 0, embedType.NumField())
	for i := 0; i < embedType.NumField(); i++ {
		out = append(out, embedType.Field(i))
	}
	return out
}

// fieldTag returns the decoded struct tag of field, or "" if it has none.
func fieldTag(field *ast.Field) reflect.StructTag {
	if field.Tag == nil {
		return ""
	}
	// go/ast keeps the tag as its source literal, quotes included (the
	// surrounding backquotes of `gorm:"size:64"`). reflect.StructTag needs the
	// bare content: with the quote left in, the first key parses as "`gorm"
	// and Tag.Get("gorm") misses it. The %q verb decodes both raw (`...`) and
	// interpreted ("...") string literals.
	var tag string
	if _, err := fmt.Sscanf(field.Tag.Value, "%q", &tag); err != nil {
		log.Printf("Cannot decode tag %s: %v\n", field.Tag.Value, err)
		return reflect.StructTag(field.Tag.Value)
	}
	return reflect.StructTag(tag)
}
// autoMigrate reads config.json, opens a MySQL connection with the
// credentials it contains, and runs GORM's AutoMigrate for every struct
// collected in structs, using convertToSnakeCase(name) as the table name.
// The "General" entry is skipped (presumably a shared base embedded into the
// real tables — TODO confirm). Any error terminates the program via log.Fatal*.
//
// Migration order follows Go map iteration and therefore varies between runs.
func autoMigrate() {
	config, err := LoadConfig("config.json")
	if err != nil {
		// Fatalf, not Fatalln: the message carries a format verb.
		log.Fatalf("Error loading config: %v", err)
	}

	dsn := fmt.Sprintf(
		"%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
		config.MySQL.Username,
		config.MySQL.Password,
		config.MySQL.Host,
		config.MySQL.Database,
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}

	for name, typ := range structs {
		if name == "General" {
			continue
		}
		instance := reflect.New(typ).Interface()
		tableName := convertToSnakeCase(name) // set the table name explicitly
		fmt.Printf("Created instance of %s: %+v\n", name, instance)
		// Start each migration from the root *gorm.DB. Reassigning db to the
		// result of db.Table would carry one chain's Statement into the next
		// iteration, and GORM documents a chained instance as unsafe to reuse.
		if err := db.Table(tableName).AutoMigrate(instance); err != nil {
			log.Fatalf("Failed to migrate %s: %v", name, err)
		}
	}
}
// main parses every .go file in ./table_defs, registers the struct
// declarations found there via getStruct, and then migrates them with
// autoMigrate. table_defs.go is parsed first because structs declared there
// may be embedded by the other files. The remaining files keep their
// os.ReadDir order. The program exits if the directory cannot be read or the
// root file is missing. Files that fail to parse are logged and skipped.
func main() {
	const (
		dirPath      = "./table_defs"  // directory holding the table definitions
		rootFileName = "table_defs.go" // must be parsed before the others
	)
	fset := token.NewFileSet() // records positions for all parsed files

	// mark stage-1: list the directory and move the root file to the front.
	entries, err := os.ReadDir(dirPath)
	if err != nil {
		log.Fatal(err)
	}
	rootIdx := -1
	for i := range entries {
		if entries[i].Name() == rootFileName {
			rootIdx = i
			break
		}
	}
	if rootIdx < 0 {
		log.Fatalf("File %s not found in directory %s", rootFileName, dirPath)
	}
	ordered := make([]os.DirEntry, 0, len(entries))
	ordered = append(ordered, entries[rootIdx])
	ordered = append(ordered, entries[:rootIdx]...)
	ordered = append(ordered, entries[rootIdx+1:]...)

	// mark stage-2: parse each regular .go file and register its structs.
	for _, entry := range ordered {
		path := filepath.Join(dirPath, entry.Name())
		if entry.IsDir() || filepath.Ext(path) != ".go" {
			continue
		}
		f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
		if err != nil {
			log.Printf("Error parsing file %s: %s", path, err)
			continue
		}
		getStruct(f.Decls)
		log.Printf("Parsed file: %s", path)
	}

	// mark stage-3: migrate everything that was collected.
	autoMigrate()
}