// Source: gogs/internal/database/schemadoc/main.go (197 lines, 4.7 KiB, Go).

package main
import (
"fmt"
"log"
"os"
"sort"
"strings"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"
"gopkg.in/DATA-DOG/go-sqlmock.v2"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gogs.io/gogs/internal/database"
)
//go:generate go run main.go ../../../docs/dev/database_schema.md
func main() {
w, err := os.Create(os.Args[1])
if err != nil {
log.Fatalf("Failed to create file: %v", err)
}
defer func() { _ = w.Close() }()
conn, _, err := sqlmock.New()
if err != nil {
log.Fatalf("Failed to get mock connection: %v", err)
}
defer func() { _ = conn.Close() }()
dialectors := []gorm.Dialector{
postgres.New(postgres.Config{
Conn: conn,
}),
mysql.New(mysql.Config{
Conn: conn,
SkipInitializeWithVersion: true,
}),
sqlite.Open(""),
}
collected := make([][]*tableInfo, 0, len(dialectors))
for i, dialector := range dialectors {
tableInfos, err := generate(dialector)
if err != nil {
log.Fatalf("Failed to get table info of %d: %v", i, err)
}
collected = append(collected, tableInfos)
}
for i, ti := range collected[0] {
_, _ = w.WriteString(`# Table "` + ti.Name + `"`)
_, _ = w.WriteString("\n\n")
_, _ = w.WriteString("```\n")
table := tablewriter.NewWriter(w)
table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
table.SetBorder(false)
for j, f := range ti.Fields {
table.Append([]string{
f.Name, f.Column,
strings.ToUpper(f.Type), // PostgreSQL
strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3
})
}
table.Render()
_, _ = w.WriteString("\n")
_, _ = w.WriteString("Primary keys: ")
_, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
_, _ = w.WriteString("\n")
if len(ti.Indexes) > 0 {
_, _ = w.WriteString("Indexes: \n")
for _, index := range ti.Indexes {
_, _ = w.WriteString(fmt.Sprintf("\t%q", index.Name))
if index.Class != "" {
_, _ = w.WriteString(fmt.Sprintf(" %s", index.Class))
}
if index.Type != "" {
_, _ = w.WriteString(fmt.Sprintf(", %s", index.Type))
}
if len(index.Fields) > 0 {
fields := make([]string, len(index.Fields))
for i := range index.Fields {
fields[i] = index.Fields[i].DBName
}
_, _ = w.WriteString(fmt.Sprintf(" (%s)", strings.Join(fields, ", ")))
}
_, _ = w.WriteString("\n")
}
}
_, _ = w.WriteString("```\n\n")
}
}
type (
	// tableField describes one table column as a single dialect sees it.
	tableField struct {
		// Go field name, database column name, and the SQL type rendered for
		// the dialect (possibly suffixed with " UNIQUE").
		Name, Column, Type string
	}

	// tableInfo is the dialect-specific description of one table: its name,
	// its columns in schema order, its primary key column names, and its
	// indexes.
	tableInfo struct {
		Name        string
		Fields      []*tableField
		PrimaryKeys []string
		Indexes     []schema.Index
	}
)
// generate collects schema information for every model in database.Tables,
// as the given dialect would create it.
//
// GORM is opened in dry-run mode with automatic ping disabled and the
// singular-table naming strategy. Each returned tableInfo lists:
//   - the model's columns with their full data type, suffixed with " UNIQUE"
//     when the field carries a UNIQUE tag;
//   - its primary key column names;
//   - its indexes, sorted by name.
//
// An error is returned if the database cannot be opened, if the dialect's
// migrator lacks the methods used here, or if gathering a table fails.
//
// This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
	conn, err := gorm.Open(dialector,
		&gorm.Config{
			SkipDefaultTransaction: true,
			NamingStrategy: schema.NamingStrategy{
				SingularTable: true,
			},
			DryRun:               true,
			DisableAutomaticPing: true,
		},
	)
	if err != nil {
		return nil, errors.Wrap(err, "open database")
	}

	// RunWithValue and FullDataTypeOf are not part of the gorm.Migrator
	// interface. Check for them explicitly so that an unsupported dialect
	// yields a descriptive error instead of a panic.
	migrator := conn.Migrator()
	m, ok := migrator.(interface {
		RunWithValue(value any, fc func(*gorm.Statement) error) error
		FullDataTypeOf(*schema.Field) clause.Expr
	})
	if !ok {
		return nil, errors.Errorf("migrator %T does not support RunWithValue and FullDataTypeOf", migrator)
	}

	tableInfos := make([]*tableInfo, 0, len(database.Tables))
	for _, table := range database.Tables {
		err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
			fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
			for _, field := range stmt.Schema.Fields {
				// Skip fields that have no database column.
				if field.DBName == "" {
					continue
				}
				typ := m.FullDataTypeOf(field).SQL
				if _, unique := field.TagSettings["UNIQUE"]; unique {
					typ += " UNIQUE"
				}
				fields = append(fields, &tableField{
					Name:   field.Name,
					Column: field.DBName,
					Type:   typ,
				})
			}

			primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
			for _, field := range stmt.Schema.PrimaryFields {
				primaryKeys = append(primaryKeys, field.DBName)
			}

			parsed := stmt.Schema.ParseIndexes()
			indexes := make([]schema.Index, 0, len(parsed))
			for _, index := range parsed {
				indexes = append(indexes, index)
			}
			// Sort by name so the generated document is stable regardless of
			// the order in which ParseIndexes yields the indexes.
			sort.Slice(indexes, func(i, j int) bool {
				return indexes[i].Name < indexes[j].Name
			})

			tableInfos = append(tableInfos, &tableInfo{
				Name:        stmt.Table,
				Fields:      fields,
				PrimaryKeys: primaryKeys,
				Indexes:     indexes,
			})
			return nil
		})
		if err != nil {
			return nil, errors.Wrap(err, "gather table information")
		}
	}
	return tableInfos, nil
}