pgx/pgmock/pgmock.go

502 lines
10 KiB
Go

package pgmock
import (
"errors"
"fmt"
"io"
"net"
"reflect"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// Server is a single-use mock PostgreSQL server bound to a loopback TCP
// listener. The protocol conversation itself is delegated to its Controller.
type Server struct {
	ln         net.Listener // loopback listener opened by NewServer
	controller Controller   // drives the exchange once a client connects
}

// NewServer opens a TCP listener on 127.0.0.1 with an OS-chosen port and
// returns a Server that will hand the accepted connection to controller.
// Any error from net.Listen is returned unchanged.
func NewServer(controller Controller) (*Server, error) {
	listener, listenErr := net.Listen("tcp", "127.0.0.1:")
	if listenErr != nil {
		return nil, listenErr
	}
	return &Server{ln: listener, controller: controller}, nil
}

// Addr reports the address the Server's listener is bound to; clients dial
// this address to reach the mock.
func (s *Server) Addr() net.Addr {
	return s.ln.Addr()
}
// ServeOne accepts a single client connection, stops listening for further
// connections, and runs the Server's Controller against that connection.
//
// The accepted connection is always closed before ServeOne returns. Errors
// from Accept, from constructing the protocol backend, or from the
// Controller are returned to the caller unchanged.
func (s *Server) ServeOne() error {
	conn, err := s.ln.Accept()
	if err != nil {
		return err
	}
	// The deferred Close is the single owner of conn's lifetime; closing it
	// again on an error path would only trigger a redundant second Close.
	defer conn.Close()

	// Only one connection is ever served, so stop listening right away. This
	// is best-effort: failing to close the listener must not prevent the
	// already-accepted client from being served.
	_ = s.Close()

	backend, err := pgproto3.NewBackend(conn, conn)
	if err != nil {
		return err
	}
	return s.controller.Serve(backend)
}
// Close shuts down the Server's listener, returning whatever error the
// listener's Close reports (nil on success).
func (s *Server) Close() error {
	return s.ln.Close()
}
type (
	// Controller conducts a whole mock conversation over an established
	// backend and reports the first failure, if any.
	Controller interface {
		Serve(backend *pgproto3.Backend) error
	}

	// Step is one unit of a scripted conversation, such as expecting a
	// frontend message or sending a backend message.
	Step interface {
		Step(*pgproto3.Backend) error
	}

	// Script is an ordered list of Steps. *Script implements both Controller
	// (Serve) and Step (Step), so a script can be served directly or nested
	// inside another script.
	Script struct {
		Steps []Step
	}
)
// Run executes the Script's Steps in order against backend, stopping at the
// first Step that fails and returning its error unchanged. It behaves exactly
// like Serve and delegates to it so the two entry points cannot drift apart.
func (s *Script) Run(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}
// Serve implements Controller. It plays the Script's Steps one after another
// over backend; execution halts at the first failing Step, whose error is
// propagated as-is. A nil return means every Step succeeded.
func (s *Script) Serve(backend *pgproto3.Backend) error {
	steps := s.Steps
	for i := range steps {
		if stepErr := steps[i].Step(backend); stepErr != nil {
			return stepErr
		}
	}
	return nil
}
// Step implements the Step interface so a Script can be embedded as a single
// step of another Script; the nested steps run in order via Serve.
func (s *Script) Step(backend *pgproto3.Backend) error {
	nestedErr := s.Serve(backend)
	return nestedErr
}
// expectMessageStep reads the next frontend message and checks it against
// want. With any set, a message of the same dynamic type as want is accepted
// without comparing contents; otherwise the message must be deeply equal to
// want.
type expectMessageStep struct {
	want pgproto3.FrontendMessage
	any  bool
}

// Step receives one message from backend and validates it. Receive errors are
// returned unchanged; a mismatch yields an error showing both values.
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
	got, err := backend.Receive()
	if err != nil {
		return err
	}

	sameType := reflect.TypeOf(got) == reflect.TypeOf(e.want)
	switch {
	case e.any && sameType:
		return nil
	case reflect.DeepEqual(got, e.want):
		return nil
	default:
		return fmt.Errorf("msg => %#v, e.want => %#v", got, e.want)
	}
}
// expectStartupMessageStep reads the client's StartupMessage, which is
// received with ReceiveStartupMessage rather than Receive. With any set,
// every successfully received startup message is accepted; otherwise it must
// be deeply equal to want.
type expectStartupMessageStep struct {
	want *pgproto3.StartupMessage
	any  bool
}

// Step receives the startup message from backend and validates it. Receive
// errors are returned unchanged; a mismatch yields an error showing both
// values.
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
	got, err := backend.ReceiveStartupMessage()
	switch {
	case err != nil:
		return err
	case e.any, reflect.DeepEqual(got, e.want):
		return nil
	}
	return fmt.Errorf("msg => %#v, e.want => %#v", got, e.want)
}
// ExpectMessage returns a Step requiring the next client message to be deeply
// equal to want. A *pgproto3.StartupMessage is read with the startup-message
// reader instead of the regular one.
func ExpectMessage(want pgproto3.FrontendMessage) Step {
	// false: compare full message contents.
	return expectMessage(want, false)
}

// ExpectAnyMessage returns a Step requiring only that the next client message
// have the same type as want; contents are not compared. When want is a
// *pgproto3.StartupMessage, any received startup message is accepted.
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
	// true: only the message type has to match.
	return expectMessage(want, true)
}

// expectMessage selects the concrete Step for want: startup messages go
// through expectStartupMessageStep, everything else through
// expectMessageStep.
func expectMessage(want pgproto3.FrontendMessage, matchTypeOnly bool) Step {
	startup, isStartup := want.(*pgproto3.StartupMessage)
	if !isStartup {
		return &expectMessageStep{want: want, any: matchTypeOnly}
	}
	return &expectStartupMessageStep{want: startup, any: matchTypeOnly}
}
// sendMessageStep writes one pre-built backend message to the client.
type sendMessageStep struct {
	msg pgproto3.BackendMessage
}

// Step sends the stored message through backend and returns any send error.
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
	sendErr := backend.Send(e.msg)
	return sendErr
}

// SendMessage returns a Step that transmits msg to the client when executed.
func SendMessage(msg pgproto3.BackendMessage) Step {
	step := &sendMessageStep{}
	step.msg = msg
	return step
}
// waitForCloseMessageStep drains client messages until the client ends the
// session, either by sending Terminate or by closing the connection.
type waitForCloseMessageStep struct{}

// Step discards incoming messages until a *pgproto3.Terminate arrives or
// Receive reports io.EOF; both count as success. Any other receive error is
// returned unchanged.
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
	for {
		msg, err := backend.Receive()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}
		if _, terminated := msg.(*pgproto3.Terminate); terminated {
			return nil
		}
	}
}

// WaitForClose returns a Step that consumes client messages until the client
// terminates or disconnects.
func WaitForClose() Step {
	return &waitForCloseMessageStep{}
}
// AcceptUnauthenticatedConnRequestSteps returns the Steps of a connection
// handshake that requires no authentication. It accepts whatever
// StartupMessage the client sends, then replies with:
//   - Authentication of type AuthTypeOk,
//   - zeroed BackendKeyData,
//   - ReadyForQuery with TxStatus 'I'.
func AcceptUnauthenticatedConnRequestSteps() []Step {
	expectStartup := ExpectAnyMessage(&pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      map[string]string{},
	})
	authOK := SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk})
	keyData := SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0})
	ready := SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})
	return []Step{expectStartup, authOK, keyData, ready}
}
// PgxInitSteps returns the Steps that answer a pg_type lookup, presumably the
// one pgx issues while initializing a connection. This is inferred from the
// function name and the query text; confirm against the pgx version in use.
//
// The scripted exchange runs in two phases:
//
//  1. Prepare. The client sends Parse, Describe 'S' and Sync. The mock
//     answers with ParseComplete, an empty ParameterDescription, a two-column
//     RowDescription (oid, typname) and ReadyForQuery.
//  2. Execute. The client sends Bind (binary result formats), Execute and
//     Sync. The mock answers with BindComplete, one binary DataRow per entry
//     in rowVals, CommandComplete and ReadyForQuery.
func PgxInitSteps() []Step {
	steps := []Step{
		ExpectMessage(&pgproto3.Parse{
			Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)",
		}),
		ExpectMessage(&pgproto3.Describe{
			ObjectType: 'S',
		}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.ParseComplete{}),
		SendMessage(&pgproto3.ParameterDescription{}),
		SendMessage(&pgproto3.RowDescription{
			Fields: []pgproto3.FieldDescription{
				{
					Name:                 "oid",
					TableOID:             1247,
					TableAttributeNumber: 65534,
					DataTypeOID:          26,
					DataTypeSize:         4,
					TypeModifier:         4294967295,
					Format:               0,
				},
				{
					Name:                 "typname",
					TableOID:             1247,
					TableAttributeNumber: 1,
					DataTypeOID:          19,
					DataTypeSize:         64,
					TypeModifier:         4294967295,
					Format:               0,
				},
			},
		}),
		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
		ExpectMessage(&pgproto3.Bind{
			ResultFormatCodes: []int16{1, 1},
		}),
		ExpectMessage(&pgproto3.Execute{}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.BindComplete{}),
	}
	rowVals := []struct {
		oid  pgtype.Oid
		name string
	}{
		{16, "bool"},
		{17, "bytea"},
		{18, "char"},
		{19, "name"},
		{20, "int8"},
		{21, "int2"},
		{22, "int2vector"},
		{23, "int4"},
		{24, "regproc"},
		{25, "text"},
		{26, "oid"},
		{27, "tid"},
		{28, "xid"},
		{29, "cid"},
		{30, "oidvector"},
		{114, "json"},
		{142, "xml"},
		{143, "_xml"},
		{199, "_json"},
		{194, "pg_node_tree"},
		{32, "pg_ddl_command"},
		{210, "smgr"},
		{600, "point"},
		{601, "lseg"},
		{602, "path"},
		{603, "box"},
		{604, "polygon"},
		{628, "line"},
		{629, "_line"},
		{700, "float4"},
		{701, "float8"},
		{702, "abstime"},
		{703, "reltime"},
		{704, "tinterval"},
		{705, "unknown"},
		{718, "circle"},
		{719, "_circle"},
		{790, "money"},
		{791, "_money"},
		{829, "macaddr"},
		{869, "inet"},
		{650, "cidr"},
		{1000, "_bool"},
		{1001, "_bytea"},
		{1002, "_char"},
		{1003, "_name"},
		{1005, "_int2"},
		{1006, "_int2vector"},
		{1007, "_int4"},
		{1008, "_regproc"},
		{1009, "_text"},
		{1028, "_oid"},
		{1010, "_tid"},
		{1011, "_xid"},
		{1012, "_cid"},
		{1013, "_oidvector"},
		{1014, "_bpchar"},
		{1015, "_varchar"},
		{1016, "_int8"},
		{1017, "_point"},
		{1018, "_lseg"},
		{1019, "_path"},
		{1020, "_box"},
		{1021, "_float4"},
		{1022, "_float8"},
		{1023, "_abstime"},
		{1024, "_reltime"},
		{1025, "_tinterval"},
		{1027, "_polygon"},
		{1033, "aclitem"},
		{1034, "_aclitem"},
		{1040, "_macaddr"},
		{1041, "_inet"},
		{651, "_cidr"},
		{1263, "_cstring"},
		{1042, "bpchar"},
		{1043, "varchar"},
		{1082, "date"},
		{1083, "time"},
		{1114, "timestamp"},
		{1115, "_timestamp"},
		{1182, "_date"},
		{1183, "_time"},
		{1184, "timestamptz"},
		{1185, "_timestamptz"},
		{1186, "interval"},
		{1187, "_interval"},
		{1231, "_numeric"},
		{1266, "timetz"},
		{1270, "_timetz"},
		{1560, "bit"},
		{1561, "_bit"},
		{1562, "varbit"},
		{1563, "_varbit"},
		{1700, "numeric"},
		{1790, "refcursor"},
		{2201, "_refcursor"},
		{2202, "regprocedure"},
		{2203, "regoper"},
		{2204, "regoperator"},
		{2205, "regclass"},
		{2206, "regtype"},
		{4096, "regrole"},
		{4089, "regnamespace"},
		{2207, "_regprocedure"},
		{2208, "_regoper"},
		{2209, "_regoperator"},
		{2210, "_regclass"},
		{2211, "_regtype"},
		{4097, "_regrole"},
		{4090, "_regnamespace"},
		{2950, "uuid"},
		{2951, "_uuid"},
		{3220, "pg_lsn"},
		{3221, "_pg_lsn"},
		{3614, "tsvector"},
		{3642, "gtsvector"},
		{3615, "tsquery"},
		{3734, "regconfig"},
		{3769, "regdictionary"},
		{3643, "_tsvector"},
		{3644, "_gtsvector"},
		{3645, "_tsquery"},
		{3735, "_regconfig"},
		{3770, "_regdictionary"},
		{3802, "jsonb"},
		{3807, "_jsonb"},
		{2970, "txid_snapshot"},
		{2949, "_txid_snapshot"},
		{3904, "int4range"},
		{3905, "_int4range"},
		{3906, "numrange"},
		{3907, "_numrange"},
		{3908, "tsrange"},
		{3909, "_tsrange"},
		{3910, "tstzrange"},
		{3911, "_tstzrange"},
		{3912, "daterange"},
		{3913, "_daterange"},
		{3926, "int8range"},
		{3927, "_int8range"},
		{2249, "record"},
		{2287, "_record"},
		{2275, "cstring"},
		{2276, "any"},
		{2277, "anyarray"},
		{2278, "void"},
		{2279, "trigger"},
		{3838, "event_trigger"},
		{2280, "language_handler"},
		{2281, "internal"},
		{2282, "opaque"},
		{2283, "anyelement"},
		{2776, "anynonarray"},
		{3500, "anyenum"},
		{3115, "fdw_handler"},
		{325, "index_am_handler"},
		{3310, "tsm_handler"},
		{3831, "anyrange"},
		{51367, "gbtreekey4"},
		{51370, "_gbtreekey4"},
		{51371, "gbtreekey8"},
		{51374, "_gbtreekey8"},
		{51375, "gbtreekey16"},
		{51378, "_gbtreekey16"},
		{51379, "gbtreekey32"},
		{51382, "_gbtreekey32"},
		{51383, "gbtreekey_var"},
		{51386, "_gbtreekey_var"},
		{51921, "hstore"},
		{51926, "_hstore"},
		{52005, "ghstore"},
		{52008, "_ghstore"},
	}
	for _, rv := range rowVals {
		step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
		steps = append(steps, step)
	}
	// Derive the row count from rowVals so the CommandComplete tag always
	// matches the number of DataRows sent above (a hard-coded count drifts as
	// entries are added or removed).
	steps = append(steps,
		SendMessage(&pgproto3.CommandComplete{CommandTag: fmt.Sprintf("SELECT %d", len(rowVals))}),
		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
	)
	return steps
}
// dataRowValue appears to pair a value with the wire format code it should be
// encoded in. NOTE(review): nothing in this section references it —
// buildDataRow takes parallel values/formatCodes slices instead; confirm
// whether it is used elsewhere before changing or removing it.
type dataRowValue struct {
Value interface{}
FormatCode int16
}
// mustBuildDataRow is buildDataRow for fixed, known-good fixtures (such as
// the rows in PgxInitSteps): it returns the DataRow directly and panics with
// buildDataRow's error if encoding fails.
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
	row, buildErr := buildDataRow(values, formatCodes)
	if buildErr == nil {
		return row
	}
	panic(buildErr)
}
// buildDataRow encodes values into a DataRow message.
//
// formatCodes selects the wire format of each value (pgproto3.TextFormat or
// pgproto3.BinaryFormat). A single code is applied to every value; otherwise
// there must be at least one code per value, and any extra codes are ignored.
//
// Plain Go string, int16, int32 and int64 values are wrapped in pgtype.Text,
// Int2, Int4 and Int8 (with Status Present) before encoding. Any other value
// must itself implement pgtype.TextEncoder or pgtype.BinaryEncoder for its
// chosen format. Neither values nor formatCodes is modified.
//
// It returns an error when there are too few format codes, a value cannot be
// encoded in its format, an encoder fails, or a format code is unknown.
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
	// Expand a single shared code into a fresh slice. Appending to formatCodes
	// could write into a backing array the caller still uses.
	codes := formatCodes
	if len(formatCodes) == 1 {
		codes = make([]int16, len(values))
		for i := range codes {
			codes[i] = formatCodes[0]
		}
	}
	if len(codes) < len(values) {
		return nil, fmt.Errorf("got %d format codes for %d values", len(formatCodes), len(values))
	}

	dr := &pgproto3.DataRow{
		Values: make([][]byte, len(values)),
	}
	for i, value := range values {
		// Wrap plain Go scalars in a local so the caller's slice is not
		// overwritten with pgtype pointers.
		switch v := value.(type) {
		case string:
			value = &pgtype.Text{String: v, Status: pgtype.Present}
		case int16:
			value = &pgtype.Int2{Int: v, Status: pgtype.Present}
		case int32:
			value = &pgtype.Int4{Int: v, Status: pgtype.Present}
		case int64:
			value = &pgtype.Int8{Int: v, Status: pgtype.Present}
		}

		var (
			buf []byte
			err error
		)
		switch codes[i] {
		case pgproto3.TextFormat:
			e, ok := value.(pgtype.TextEncoder)
			if !ok {
				return nil, fmt.Errorf("values[%d] does not implement TextEncoder", i)
			}
			buf, err = e.EncodeText(nil, nil)
		case pgproto3.BinaryFormat:
			e, ok := value.(pgtype.BinaryEncoder)
			if !ok {
				return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i)
			}
			buf, err = e.EncodeBinary(nil, nil)
		default:
			return nil, errors.New("unknown FormatCode")
		}
		if err != nil {
			// Keep the encoder's error as the cause instead of discarding it.
			return nil, fmt.Errorf("failed to encode values[%d]: %w", i, err)
		}
		dr.Values[i] = buf
	}
	return dr, nil
}