// Package pgmock provides a scriptable mock PostgreSQL server for tests: a
// Server accepts a single connection and hands it to a Controller, and a
// Script drives that connection through a sequence of expected frontend
// messages and canned backend responses.
package pgmock

import (
	"errors"
	"fmt"
	"io"
	"net"
	"reflect"

	"github.com/jackc/pgx/pgproto3"
	"github.com/jackc/pgx/pgtype"
)

// Server is a mock PostgreSQL server listening on a local TCP port. It serves
// at most one connection, via ServeOne.
type Server struct {
	ln         net.Listener
	controller Controller
}

// NewServer starts listening on 127.0.0.1 on an OS-assigned port (the address
// has an empty port) and returns a Server that will hand its connection to
// controller. Use Addr to discover the chosen address.
func NewServer(controller Controller) (*Server, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:")
	if err != nil {
		return nil, err
	}

	return &Server{ln: ln, controller: controller}, nil
}

// Addr returns the address the server is listening on.
func (s *Server) Addr() net.Addr {
	return s.ln.Addr()
}

// ServeOne accepts a single connection, then closes the listener so no further
// connections are accepted. It then runs the controller against a pgproto3
// backend wrapping that connection. The connection is closed when ServeOne
// returns, and the controller's error (if any) is returned.
func (s *Server) ServeOne() error {
	conn, err := s.ln.Accept()
	if err != nil {
		return err
	}
	defer conn.Close()

	// Only one connection is ever served, so stop listening now. Closing the
	// listener does not affect the already-accepted conn, so a failure here
	// is deliberately ignored.
	_ = s.Close()

	backend, err := pgproto3.NewBackend(conn, conn)
	if err != nil {
		// conn is closed by the deferred Close above.
		return err
	}

	return s.controller.Serve(backend)
}

// Close stops the server's listener.
func (s *Server) Close() error {
	return s.ln.Close()
}

// Controller drives one accepted connection through its backend.
type Controller interface {
	Serve(backend *pgproto3.Backend) error
}

// Step is one action in a Script, such as expecting or sending a message.
type Step interface {
	Step(*pgproto3.Backend) error
}

// Script is an ordered list of Steps. It can serve a connection as a
// Controller, and it can also be nested inside another Script as a Step.
type Script struct {
	Steps []Step
}

// Run executes the script's steps in order against backend. It stops at and
// returns the first error. Run is equivalent to Serve.
func (s *Script) Run(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}

// Serve executes the script's steps in order against backend. It stops at and
// returns the first error. Serve implements Controller.
func (s *Script) Serve(backend *pgproto3.Backend) error {
	for _, step := range s.Steps {
		if err := step.Step(backend); err != nil {
			return err
		}
	}
	return nil
}

// Step runs the whole script as a single step, which lets scripts be
// composed. Step implements the Step interface.
func (s *Script) Step(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}

// expectMessageStep receives one regular frontend message and compares it to
// want. If any is set, a message of the same dynamic type as want is accepted
// regardless of its contents.
type expectMessageStep struct {
	want pgproto3.FrontendMessage
	any  bool
}

// Step receives the next frontend message and checks it against e.want. The
// check is a type match when e.any is set, and reflect.DeepEqual otherwise.
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
	msg, err := backend.Receive()
	if err != nil {
		return err
	}

	if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
		return nil
	}

	if !reflect.DeepEqual(msg, e.want) {
		return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
	}

	return nil
}

// expectStartupMessageStep receives the connection's startup message and
// compares it to want. If any is set, every startup message is accepted.
type expectStartupMessageStep struct {
	want *pgproto3.StartupMessage
	any  bool
}

// Step receives a startup message and checks it against e.want. The check is
// skipped when e.any is set, and uses reflect.DeepEqual otherwise.
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
	msg, err := backend.ReceiveStartupMessage()
	if err != nil {
		return err
	}

	if e.any {
		return nil
	}

	if !reflect.DeepEqual(msg, e.want) {
		return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
	}

	return nil
}

// ExpectMessage returns a Step that receives one message and fails unless it
// is reflect.DeepEqual to want. A *pgproto3.StartupMessage is received with
// ReceiveStartupMessage; any other message is received with Receive.
func ExpectMessage(want pgproto3.FrontendMessage) Step {
	return expectMessage(want, false)
}

// ExpectAnyMessage returns a Step that receives one message and accepts it if
// it has the same concrete type as want, ignoring its contents. When want is a
// *pgproto3.StartupMessage, any startup message is accepted.
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
	return expectMessage(want, true)
}

// expectMessage picks the startup-message or regular-message step for want.
func expectMessage(want pgproto3.FrontendMessage, matchAny bool) Step {
	// Startup messages have no type byte on the wire, so they need the
	// dedicated receive path.
	if startup, ok := want.(*pgproto3.StartupMessage); ok {
		return &expectStartupMessageStep{want: startup, any: matchAny}
	}

	return &expectMessageStep{want: want, any: matchAny}
}

// sendMessageStep sends a fixed backend message.
type sendMessageStep struct {
	msg pgproto3.BackendMessage
}

// Step sends e.msg through backend.
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
	return backend.Send(e.msg)
}

// SendMessage returns a Step that sends msg to the client.
func SendMessage(msg pgproto3.BackendMessage) Step {
	return &sendMessageStep{msg: msg}
}

// waitForCloseMessageStep drains client messages until the client leaves.
type waitForCloseMessageStep struct{}

// Step receives and discards messages until a Terminate message arrives or the
// connection reaches io.EOF. Both cases count as success; any other receive
// error is returned.
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
	for {
		msg, err := backend.Receive()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		if _, ok := msg.(*pgproto3.Terminate); ok {
			return nil
		}
	}
}

// WaitForClose returns a Step that waits for the client to send Terminate or
// close the connection, ignoring any other messages it sends meanwhile.
func WaitForClose() Step {
	return &waitForCloseMessageStep{}
}

// AcceptUnauthenticatedConnRequestSteps returns steps that accept any startup
// message and reply with AuthenticationOk, a zero BackendKeyData, and
// ReadyForQuery in the idle ('I') state.
func AcceptUnauthenticatedConnRequestSteps() []Step {
	return []Step{
		ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
		SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
		SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
	}
}

// PgxInitSteps returns the steps that answer pgx's connection-initialization
// query against pg_type. The query is expected as Parse/Describe/Sync and then
// Bind/Execute/Sync. The reply is one binary (oid, typname) DataRow per entry
// in the canned table below, followed by CommandComplete and ReadyForQuery.
func PgxInitSteps() []Step {
	steps := []Step{
		ExpectMessage(&pgproto3.Parse{
			Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)",
		}),
		ExpectMessage(&pgproto3.Describe{
			ObjectType: 'S',
		}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.ParseComplete{}),
		SendMessage(&pgproto3.ParameterDescription{}),
		SendMessage(&pgproto3.RowDescription{
			Fields: []pgproto3.FieldDescription{
				{Name: "oid",
					TableOID:             1247,
					TableAttributeNumber: 65534,
					DataTypeOID:          26,
					DataTypeSize:         4,
					TypeModifier:         4294967295,
					Format:               0,
				},
				{Name: "typname",
					TableOID:             1247,
					TableAttributeNumber: 1,
					DataTypeOID:          19,
					DataTypeSize:         64,
					TypeModifier:         4294967295,
					Format:               0,
				},
			},
		}),
		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
		ExpectMessage(&pgproto3.Bind{
			ResultFormatCodes: []int16{1, 1},
		}),
		ExpectMessage(&pgproto3.Execute{}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.BindComplete{}),
	}

	rowVals := []struct {
		oid  pgtype.Oid
		name string
	}{
		{16, "bool"}, {17, "bytea"}, {18, "char"}, {19, "name"}, {20, "int8"},
		{21, "int2"}, {22, "int2vector"}, {23, "int4"}, {24, "regproc"}, {25, "text"},
		{26, "oid"}, {27, "tid"}, {28, "xid"}, {29, "cid"}, {30, "oidvector"},
		{114, "json"}, {142, "xml"}, {143, "_xml"}, {199, "_json"}, {194, "pg_node_tree"},
		{32, "pg_ddl_command"}, {210, "smgr"}, {600, "point"}, {601, "lseg"}, {602, "path"},
		{603, "box"}, {604, "polygon"}, {628, "line"}, {629, "_line"}, {700, "float4"},
		{701, "float8"}, {702, "abstime"}, {703, "reltime"}, {704, "tinterval"}, {705, "unknown"},
		{718, "circle"}, {719, "_circle"}, {790, "money"}, {791, "_money"}, {829, "macaddr"},
		{869, "inet"}, {650, "cidr"}, {1000, "_bool"}, {1001, "_bytea"}, {1002, "_char"},
		{1003, "_name"}, {1005, "_int2"}, {1006, "_int2vector"}, {1007, "_int4"}, {1008, "_regproc"},
		{1009, "_text"}, {1028, "_oid"}, {1010, "_tid"}, {1011, "_xid"}, {1012, "_cid"},
		{1013, "_oidvector"}, {1014, "_bpchar"}, {1015, "_varchar"}, {1016, "_int8"}, {1017, "_point"},
		{1018, "_lseg"}, {1019, "_path"}, {1020, "_box"}, {1021, "_float4"}, {1022, "_float8"},
		{1023, "_abstime"}, {1024, "_reltime"}, {1025, "_tinterval"}, {1027, "_polygon"}, {1033, "aclitem"},
		{1034, "_aclitem"}, {1040, "_macaddr"}, {1041, "_inet"}, {651, "_cidr"}, {1263, "_cstring"},
		{1042, "bpchar"}, {1043, "varchar"}, {1082, "date"}, {1083, "time"}, {1114, "timestamp"},
		{1115, "_timestamp"}, {1182, "_date"}, {1183, "_time"}, {1184, "timestamptz"}, {1185, "_timestamptz"},
		{1186, "interval"}, {1187, "_interval"}, {1231, "_numeric"}, {1266, "timetz"}, {1270, "_timetz"},
		{1560, "bit"}, {1561, "_bit"}, {1562, "varbit"}, {1563, "_varbit"}, {1700, "numeric"},
		{1790, "refcursor"}, {2201, "_refcursor"}, {2202, "regprocedure"}, {2203, "regoper"}, {2204, "regoperator"},
		{2205, "regclass"}, {2206, "regtype"}, {4096, "regrole"}, {4089, "regnamespace"}, {2207, "_regprocedure"},
		{2208, "_regoper"}, {2209, "_regoperator"}, {2210, "_regclass"}, {2211, "_regtype"}, {4097, "_regrole"},
		{4090, "_regnamespace"}, {2950, "uuid"}, {2951, "_uuid"}, {3220, "pg_lsn"}, {3221, "_pg_lsn"},
		{3614, "tsvector"}, {3642, "gtsvector"}, {3615, "tsquery"}, {3734, "regconfig"}, {3769, "regdictionary"},
		{3643, "_tsvector"}, {3644, "_gtsvector"}, {3645, "_tsquery"}, {3735, "_regconfig"}, {3770, "_regdictionary"},
		{3802, "jsonb"}, {3807, "_jsonb"}, {2970, "txid_snapshot"}, {2949, "_txid_snapshot"}, {3904, "int4range"},
		{3905, "_int4range"}, {3906, "numrange"}, {3907, "_numrange"}, {3908, "tsrange"}, {3909, "_tsrange"},
		{3910, "tstzrange"}, {3911, "_tstzrange"}, {3912, "daterange"}, {3913, "_daterange"}, {3926, "int8range"},
		{3927, "_int8range"}, {2249, "record"}, {2287, "_record"}, {2275, "cstring"}, {2276, "any"},
		{2277, "anyarray"}, {2278, "void"}, {2279, "trigger"}, {3838, "event_trigger"}, {2280, "language_handler"},
		{2281, "internal"}, {2282, "opaque"}, {2283, "anyelement"}, {2776, "anynonarray"}, {3500, "anyenum"},
		{3115, "fdw_handler"}, {325, "index_am_handler"}, {3310, "tsm_handler"}, {3831, "anyrange"}, {51367, "gbtreekey4"},
		{51370, "_gbtreekey4"}, {51371, "gbtreekey8"}, {51374, "_gbtreekey8"}, {51375, "gbtreekey16"}, {51378, "_gbtreekey16"},
		{51379, "gbtreekey32"}, {51382, "_gbtreekey32"}, {51383, "gbtreekey_var"}, {51386, "_gbtreekey_var"}, {51921, "hstore"},
		{51926, "_hstore"}, {52005, "ghstore"}, {52008, "_ghstore"},
	}

	for _, rv := range rowVals {
		step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
		steps = append(steps, step)
	}

	// The command tag must report the number of DataRows actually sent. A
	// hard-coded count drifts out of sync whenever rowVals is edited.
	steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: fmt.Sprintf("SELECT %d", len(rowVals))}))
	steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))

	return steps
}

// dataRowValue pairs a value with the wire format it should be encoded in.
// It is not referenced by the code in this file.
type dataRowValue struct {
	Value      interface{}
	FormatCode int16
}

// mustBuildDataRow is like buildDataRow but panics on error. It is intended
// for building fixed test fixtures.
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
	dr, err := buildDataRow(values, formatCodes)
	if err != nil {
		panic(err)
	}

	return dr
}

// buildDataRow encodes values into a DataRow, one column per value.
//
// formatCodes follows the convention of Bind's result format codes:
//   - no codes means every column is text;
//   - a single code applies to every column;
//   - otherwise there must be at least one code per value, and extra codes
//     are ignored.
//
// Go string, int16, int32 and int64 values are promoted to pgtype.Text, Int2,
// Int4 and Int8 before encoding. Any other value must itself implement
// pgtype.TextEncoder or pgtype.BinaryEncoder, as its format code requires.
// Neither values nor formatCodes is modified.
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
	if len(formatCodes) > 1 && len(formatCodes) < len(values) {
		return nil, fmt.Errorf("%d format codes given for %d values", len(formatCodes), len(values))
	}

	dr := &pgproto3.DataRow{
		Values: make([][]byte, len(values)),
	}

	for i, value := range values {
		// value is a per-iteration copy, so promoting it here leaves the
		// caller's slice untouched.
		switch v := value.(type) {
		case string:
			value = &pgtype.Text{String: v, Status: pgtype.Present}
		case int16:
			value = &pgtype.Int2{Int: v, Status: pgtype.Present}
		case int32:
			value = &pgtype.Int4{Int: v, Status: pgtype.Present}
		case int64:
			value = &pgtype.Int8{Int: v, Status: pgtype.Present}
		}

		var formatCode int16 = pgproto3.TextFormat
		switch len(formatCodes) {
		case 0:
			// No codes: every column uses the default text format.
		case 1:
			formatCode = formatCodes[0]
		default:
			formatCode = formatCodes[i]
		}

		switch formatCode {
		case pgproto3.TextFormat:
			e, ok := value.(pgtype.TextEncoder)
			if !ok {
				return nil, fmt.Errorf("values[%d] does not implement TextEncoder", i)
			}
			buf, err := e.EncodeText(nil, nil)
			if err != nil {
				return nil, fmt.Errorf("failed to encode values[%d]: %v", i, err)
			}
			dr.Values[i] = buf
		case pgproto3.BinaryFormat:
			e, ok := value.(pgtype.BinaryEncoder)
			if !ok {
				return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i)
			}
			buf, err := e.EncodeBinary(nil, nil)
			if err != nil {
				return nil, fmt.Errorf("failed to encode values[%d]: %v", i, err)
			}
			dr.Values[i] = buf
		default:
			return nil, errors.New("unknown FormatCode")
		}
	}

	return dr, nil
}