mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 21:25:53 +00:00
It's possible to define a type (e.g., an enum) with the same name in two different schemas. When initializing data types after connecting, types defined in schemas other than pg_catalog or public should be qualified with their schema name to disambiguate them and to ensure that every type sharing a base name is added to the OID-to-type map. Prior to this commit, the last type scanned would "win": all other types with the same name were missing from the ConnInfo type maps, which subsequently caused any PREPARE involving columns of those missing types to fail with the error "unknown oid".
512 lines
11 KiB
Go
512 lines
11 KiB
Go
package pgmock
|
|
|
|
import (
	"io"
	"net"
	"reflect"
	"strconv"

	"github.com/pkg/errors"

	"github.com/jackc/pgx/pgproto3"
	"github.com/jackc/pgx/pgtype"
)
|
|
|
|
type Server struct {
|
|
ln net.Listener
|
|
controller Controller
|
|
}
|
|
|
|
func NewServer(controller Controller) (*Server, error) {
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
server := &Server{
|
|
ln: ln,
|
|
controller: controller,
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
func (s *Server) Addr() net.Addr {
|
|
return s.ln.Addr()
|
|
}
|
|
|
|
func (s *Server) ServeOne() error {
|
|
conn, err := s.ln.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
s.Close()
|
|
|
|
backend, err := pgproto3.NewBackend(conn, conn)
|
|
if err != nil {
|
|
conn.Close()
|
|
return err
|
|
}
|
|
|
|
return s.controller.Serve(backend)
|
|
}
|
|
|
|
func (s *Server) Close() error {
|
|
err := s.ln.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type Controller interface {
|
|
Serve(backend *pgproto3.Backend) error
|
|
}
|
|
|
|
type Step interface {
|
|
Step(*pgproto3.Backend) error
|
|
}
|
|
|
|
type Script struct {
|
|
Steps []Step
|
|
}
|
|
|
|
func (s *Script) Run(backend *pgproto3.Backend) error {
|
|
for _, step := range s.Steps {
|
|
err := step.Step(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Script) Serve(backend *pgproto3.Backend) error {
|
|
for _, step := range s.Steps {
|
|
err := step.Step(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Script) Step(backend *pgproto3.Backend) error {
|
|
return s.Serve(backend)
|
|
}
|
|
|
|
type expectMessageStep struct {
|
|
want pgproto3.FrontendMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type expectStartupMessageStep struct {
|
|
want *pgproto3.StartupMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.ReceiveStartupMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, false)
|
|
}
|
|
|
|
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, true)
|
|
}
|
|
|
|
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
|
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
|
return &expectStartupMessageStep{want: want, any: any}
|
|
}
|
|
|
|
return &expectMessageStep{want: want, any: any}
|
|
}
|
|
|
|
type sendMessageStep struct {
|
|
msg pgproto3.BackendMessage
|
|
}
|
|
|
|
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
|
return backend.Send(e.msg)
|
|
}
|
|
|
|
func SendMessage(msg pgproto3.BackendMessage) Step {
|
|
return &sendMessageStep{msg: msg}
|
|
}
|
|
|
|
type waitForCloseMessageStep struct{}
|
|
|
|
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
|
for {
|
|
msg, err := backend.Receive()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, ok := msg.(*pgproto3.Terminate); ok {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func WaitForClose() Step {
|
|
return &waitForCloseMessageStep{}
|
|
}
|
|
|
|
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
|
return []Step{
|
|
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
|
SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
|
|
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}
|
|
}
|
|
|
|
func PgxInitSteps() []Step {
|
|
steps := []Step{
|
|
ExpectMessage(&pgproto3.Parse{
|
|
Query: `select t.oid,
|
|
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
|
|
else nsp.nspname||'.'||t.typname
|
|
end
|
|
from pg_type t
|
|
left join pg_type base_type on t.typelem=base_type.oid
|
|
left join pg_namespace nsp on t.typnamespace=nsp.oid
|
|
where (
|
|
t.typtype in('b', 'p', 'r', 'e')
|
|
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
|
|
)`,
|
|
}),
|
|
ExpectMessage(&pgproto3.Describe{
|
|
ObjectType: 'S',
|
|
}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.ParseComplete{}),
|
|
SendMessage(&pgproto3.ParameterDescription{}),
|
|
SendMessage(&pgproto3.RowDescription{
|
|
Fields: []pgproto3.FieldDescription{
|
|
{Name: "oid",
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 65534,
|
|
DataTypeOID: 26,
|
|
DataTypeSize: 4,
|
|
TypeModifier: 4294967295,
|
|
Format: 0,
|
|
},
|
|
{Name: "typname",
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 1,
|
|
DataTypeOID: 19,
|
|
DataTypeSize: 64,
|
|
TypeModifier: 4294967295,
|
|
Format: 0,
|
|
},
|
|
},
|
|
}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
ExpectMessage(&pgproto3.Bind{
|
|
ResultFormatCodes: []int16{1, 1},
|
|
}),
|
|
ExpectMessage(&pgproto3.Execute{}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.BindComplete{}),
|
|
}
|
|
|
|
rowVals := []struct {
|
|
oid pgtype.OID
|
|
name string
|
|
}{
|
|
{16, "bool"},
|
|
{17, "bytea"},
|
|
{18, "char"},
|
|
{19, "name"},
|
|
{20, "int8"},
|
|
{21, "int2"},
|
|
{22, "int2vector"},
|
|
{23, "int4"},
|
|
{24, "regproc"},
|
|
{25, "text"},
|
|
{26, "oid"},
|
|
{27, "tid"},
|
|
{28, "xid"},
|
|
{29, "cid"},
|
|
{30, "oidvector"},
|
|
{114, "json"},
|
|
{142, "xml"},
|
|
{143, "_xml"},
|
|
{199, "_json"},
|
|
{194, "pg_node_tree"},
|
|
{32, "pg_ddl_command"},
|
|
{210, "smgr"},
|
|
{600, "point"},
|
|
{601, "lseg"},
|
|
{602, "path"},
|
|
{603, "box"},
|
|
{604, "polygon"},
|
|
{628, "line"},
|
|
{629, "_line"},
|
|
{700, "float4"},
|
|
{701, "float8"},
|
|
{702, "abstime"},
|
|
{703, "reltime"},
|
|
{704, "tinterval"},
|
|
{705, "unknown"},
|
|
{718, "circle"},
|
|
{719, "_circle"},
|
|
{790, "money"},
|
|
{791, "_money"},
|
|
{829, "macaddr"},
|
|
{869, "inet"},
|
|
{650, "cidr"},
|
|
{1000, "_bool"},
|
|
{1001, "_bytea"},
|
|
{1002, "_char"},
|
|
{1003, "_name"},
|
|
{1005, "_int2"},
|
|
{1006, "_int2vector"},
|
|
{1007, "_int4"},
|
|
{1008, "_regproc"},
|
|
{1009, "_text"},
|
|
{1028, "_oid"},
|
|
{1010, "_tid"},
|
|
{1011, "_xid"},
|
|
{1012, "_cid"},
|
|
{1013, "_oidvector"},
|
|
{1014, "_bpchar"},
|
|
{1015, "_varchar"},
|
|
{1016, "_int8"},
|
|
{1017, "_point"},
|
|
{1018, "_lseg"},
|
|
{1019, "_path"},
|
|
{1020, "_box"},
|
|
{1021, "_float4"},
|
|
{1022, "_float8"},
|
|
{1023, "_abstime"},
|
|
{1024, "_reltime"},
|
|
{1025, "_tinterval"},
|
|
{1027, "_polygon"},
|
|
{1033, "aclitem"},
|
|
{1034, "_aclitem"},
|
|
{1040, "_macaddr"},
|
|
{1041, "_inet"},
|
|
{651, "_cidr"},
|
|
{1263, "_cstring"},
|
|
{1042, "bpchar"},
|
|
{1043, "varchar"},
|
|
{1082, "date"},
|
|
{1083, "time"},
|
|
{1114, "timestamp"},
|
|
{1115, "_timestamp"},
|
|
{1182, "_date"},
|
|
{1183, "_time"},
|
|
{1184, "timestamptz"},
|
|
{1185, "_timestamptz"},
|
|
{1186, "interval"},
|
|
{1187, "_interval"},
|
|
{1231, "_numeric"},
|
|
{1266, "timetz"},
|
|
{1270, "_timetz"},
|
|
{1560, "bit"},
|
|
{1561, "_bit"},
|
|
{1562, "varbit"},
|
|
{1563, "_varbit"},
|
|
{1700, "numeric"},
|
|
{1790, "refcursor"},
|
|
{2201, "_refcursor"},
|
|
{2202, "regprocedure"},
|
|
{2203, "regoper"},
|
|
{2204, "regoperator"},
|
|
{2205, "regclass"},
|
|
{2206, "regtype"},
|
|
{4096, "regrole"},
|
|
{4089, "regnamespace"},
|
|
{2207, "_regprocedure"},
|
|
{2208, "_regoper"},
|
|
{2209, "_regoperator"},
|
|
{2210, "_regclass"},
|
|
{2211, "_regtype"},
|
|
{4097, "_regrole"},
|
|
{4090, "_regnamespace"},
|
|
{2950, "uuid"},
|
|
{2951, "_uuid"},
|
|
{3220, "pg_lsn"},
|
|
{3221, "_pg_lsn"},
|
|
{3614, "tsvector"},
|
|
{3642, "gtsvector"},
|
|
{3615, "tsquery"},
|
|
{3734, "regconfig"},
|
|
{3769, "regdictionary"},
|
|
{3643, "_tsvector"},
|
|
{3644, "_gtsvector"},
|
|
{3645, "_tsquery"},
|
|
{3735, "_regconfig"},
|
|
{3770, "_regdictionary"},
|
|
{3802, "jsonb"},
|
|
{3807, "_jsonb"},
|
|
{2970, "txid_snapshot"},
|
|
{2949, "_txid_snapshot"},
|
|
{3904, "int4range"},
|
|
{3905, "_int4range"},
|
|
{3906, "numrange"},
|
|
{3907, "_numrange"},
|
|
{3908, "tsrange"},
|
|
{3909, "_tsrange"},
|
|
{3910, "tstzrange"},
|
|
{3911, "_tstzrange"},
|
|
{3912, "daterange"},
|
|
{3913, "_daterange"},
|
|
{3926, "int8range"},
|
|
{3927, "_int8range"},
|
|
{2249, "record"},
|
|
{2287, "_record"},
|
|
{2275, "cstring"},
|
|
{2276, "any"},
|
|
{2277, "anyarray"},
|
|
{2278, "void"},
|
|
{2279, "trigger"},
|
|
{3838, "event_trigger"},
|
|
{2280, "language_handler"},
|
|
{2281, "internal"},
|
|
{2282, "opaque"},
|
|
{2283, "anyelement"},
|
|
{2776, "anynonarray"},
|
|
{3500, "anyenum"},
|
|
{3115, "fdw_handler"},
|
|
{325, "index_am_handler"},
|
|
{3310, "tsm_handler"},
|
|
{3831, "anyrange"},
|
|
{51367, "gbtreekey4"},
|
|
{51370, "_gbtreekey4"},
|
|
{51371, "gbtreekey8"},
|
|
{51374, "_gbtreekey8"},
|
|
{51375, "gbtreekey16"},
|
|
{51378, "_gbtreekey16"},
|
|
{51379, "gbtreekey32"},
|
|
{51382, "_gbtreekey32"},
|
|
{51383, "gbtreekey_var"},
|
|
{51386, "_gbtreekey_var"},
|
|
{51921, "hstore"},
|
|
{51926, "_hstore"},
|
|
{52005, "ghstore"},
|
|
{52008, "_ghstore"},
|
|
}
|
|
|
|
for _, rv := range rowVals {
|
|
step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
|
|
steps = append(steps, step)
|
|
}
|
|
|
|
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"}))
|
|
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
|
|
|
return steps
|
|
}
|
|
|
|
// dataRowValue pairs a column value with the format code
// (pgproto3.TextFormat or pgproto3.BinaryFormat) it should be encoded in.
// NOTE(review): not referenced in this file; confirm whether other files in
// the package still use it before removing.
type dataRowValue struct {
	Value      interface{} // value to encode
	FormatCode int16       // wire format for Value
}
|
|
|
|
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
|
|
dr, err := buildDataRow(values, formatCodes)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return dr
|
|
}
|
|
|
|
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
|
|
dr := &pgproto3.DataRow{
|
|
Values: make([][]byte, len(values)),
|
|
}
|
|
|
|
if len(formatCodes) == 1 {
|
|
for i := 1; i < len(values); i++ {
|
|
formatCodes = append(formatCodes, formatCodes[0])
|
|
}
|
|
}
|
|
|
|
for i := range values {
|
|
switch v := values[i].(type) {
|
|
case string:
|
|
values[i] = &pgtype.Text{String: v, Status: pgtype.Present}
|
|
case int16:
|
|
values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present}
|
|
case int32:
|
|
values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present}
|
|
case int64:
|
|
values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present}
|
|
}
|
|
}
|
|
|
|
for i := range values {
|
|
switch formatCodes[i] {
|
|
case pgproto3.TextFormat:
|
|
if e, ok := values[i].(pgtype.TextEncoder); ok {
|
|
buf, err := e.EncodeText(nil, nil)
|
|
if err != nil {
|
|
return nil, errors.Errorf("failed to encode values[%d]", i)
|
|
}
|
|
dr.Values[i] = buf
|
|
} else {
|
|
return nil, errors.Errorf("values[%d] does not implement TextExcoder", i)
|
|
}
|
|
|
|
case pgproto3.BinaryFormat:
|
|
if e, ok := values[i].(pgtype.BinaryEncoder); ok {
|
|
buf, err := e.EncodeBinary(nil, nil)
|
|
if err != nil {
|
|
return nil, errors.Errorf("failed to encode values[%d]", i)
|
|
}
|
|
dr.Values[i] = buf
|
|
} else {
|
|
return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i)
|
|
}
|
|
default:
|
|
return nil, errors.New("unknown FormatCode")
|
|
}
|
|
}
|
|
|
|
return dr, nil
|
|
}
|