pgx/pgmock/pgmock.go
Kelsey Francis 953e08df99 Prefix types in namespaces other than pg_catalog or public
It's possible to define a type (e.g., an enum) with the same name in two
different schemas. When initializing data types after connecting, types
defined within schemas other than pg_catalog or public should be
qualified with their schema name to disambiguate them and ensure all
types with the same base name get added to the map of OID to type.

Prior to this commit, the last type scanned would "win", and all others
with the same name would be missing from the ConnInfo type maps, which
would subsequently cause any PREPARE involving columns of those missing
types to return the error "unknown oid".
2017-09-11 11:29:42 -07:00

512 lines
11 KiB
Go

package pgmock
import (
	"io"
	"net"
	"reflect"
	"strconv"

	"github.com/jackc/pgx/pgproto3"
	"github.com/jackc/pgx/pgtype"
	"github.com/pkg/errors"
)
// Server is a mock server that listens on a local TCP socket and hands each
// accepted connection to its Controller.
type Server struct {
	ln         net.Listener // listening socket created by NewServer
	controller Controller   // drives the protocol exchange for a served connection
}
// NewServer starts listening on 127.0.0.1 with an automatically chosen TCP
// port (the address has an empty port) and returns a Server that will pass
// accepted connections to controller. Use Addr to discover the chosen port.
func NewServer(controller Controller) (*Server, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:")
	if err != nil {
		return nil, err
	}
	return &Server{ln: listener, controller: controller}, nil
}
// Addr reports the address of the server's listener, including the port that
// was chosen when the listener was created.
func (s *Server) Addr() net.Addr { return s.ln.Addr() }
// ServeOne accepts exactly one connection, closes the listener so no further
// connections are accepted, and runs the server's Controller against that
// connection. The accepted connection is closed when ServeOne returns.
//
// It returns any error from Accept, from creating the protocol backend, or
// from the Controller.
func (s *Server) ServeOne() error {
	conn, err := s.ln.Accept()
	if err != nil {
		return err
	}
	defer conn.Close()

	// Only a single connection is ever served, so stop listening right away.
	// A failure to close the listener does not affect the connection already
	// accepted, so it is deliberately ignored here.
	_ = s.Close()

	backend, err := pgproto3.NewBackend(conn, conn)
	if err != nil {
		// conn is closed by the deferred Close above; closing it here too
		// would close it twice.
		return err
	}

	return s.controller.Serve(backend)
}
// Close closes the server's listener and returns the listener's error, if any.
func (s *Server) Close() error {
	return s.ln.Close()
}
type (
	// Controller handles the server side of a connection once Server.ServeOne
	// has wrapped it in a protocol backend.
	Controller interface {
		Serve(backend *pgproto3.Backend) error
	}

	// Step is a single action performed against a backend, such as expecting
	// a frontend message or sending a backend message.
	Step interface {
		Step(*pgproto3.Backend) error
	}
)
// Script runs a fixed sequence of Steps against a backend, stopping at the
// first Step that fails. It satisfies both Controller (via Serve) and Step
// (via Step), so a Script can be served directly or nested in another Script.
type Script struct {
	Steps []Step
}

// Run executes every step of the script in order; it is equivalent to Serve.
func (s *Script) Run(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}

// Serve executes every step in order and returns the first error encountered,
// or nil if all steps succeed.
func (s *Script) Serve(backend *pgproto3.Backend) error {
	for _, step := range s.Steps {
		if err := step.Step(backend); err != nil {
			return err
		}
	}
	return nil
}

// Step runs the whole script as a single step of an enclosing script.
func (s *Script) Step(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}
// expectMessageStep receives the next frontend message and compares it with
// want. When any is set, a message of the same dynamic type as want is
// accepted without comparing its contents; otherwise it must deep-equal want.
type expectMessageStep struct {
	want pgproto3.FrontendMessage
	any  bool
}

// Step receives one message and reports a mismatch as an error.
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
	got, err := backend.Receive()
	if err != nil {
		return err
	}

	switch {
	case e.any && reflect.TypeOf(got) == reflect.TypeOf(e.want):
		return nil
	case reflect.DeepEqual(got, e.want):
		return nil
	default:
		return errors.Errorf("msg => %#v, e.want => %#v", got, e.want)
	}
}
// expectStartupMessageStep receives the connection's startup message. When any
// is set, whatever startup message arrives is accepted; otherwise it must
// deep-equal want.
type expectStartupMessageStep struct {
	want *pgproto3.StartupMessage
	any  bool
}

// Step reads the startup message and reports a mismatch as an error.
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
	got, err := backend.ReceiveStartupMessage()
	switch {
	case err != nil:
		return err
	case e.any, reflect.DeepEqual(got, e.want):
		return nil
	}
	return errors.Errorf("msg => %#v, e.want => %#v", got, e.want)
}
// ExpectMessage returns a Step that receives the next frontend message and
// requires it to deep-equal want.
func ExpectMessage(want pgproto3.FrontendMessage) Step {
	matchTypeOnly := false
	return expectMessage(want, matchTypeOnly)
}

// ExpectAnyMessage returns a Step that receives the next frontend message and
// accepts it if it has the same type as want, ignoring its contents. When want
// is a StartupMessage, any startup message is accepted.
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
	matchTypeOnly := true
	return expectMessage(want, matchTypeOnly)
}

// expectMessage builds the Step behind ExpectMessage and ExpectAnyMessage.
// Startup messages get a dedicated Step because they are read with
// ReceiveStartupMessage rather than Receive.
func expectMessage(want pgproto3.FrontendMessage, matchTypeOnly bool) Step {
	startup, isStartup := want.(*pgproto3.StartupMessage)
	if isStartup {
		return &expectStartupMessageStep{want: startup, any: matchTypeOnly}
	}
	return &expectMessageStep{want: want, any: matchTypeOnly}
}
// sendMessageStep sends a single, fixed backend message.
type sendMessageStep struct {
	msg pgproto3.BackendMessage
}

// Step sends the stored message through the backend.
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
	msg := e.msg
	return backend.Send(msg)
}

// SendMessage returns a Step that sends msg to the frontend.
func SendMessage(msg pgproto3.BackendMessage) Step {
	step := &sendMessageStep{}
	step.msg = msg
	return step
}
// waitForCloseMessageStep drains frontend messages until the client sends
// Terminate or the stream ends with io.EOF.
type waitForCloseMessageStep struct{}

// Step discards incoming messages until Terminate or io.EOF; any other receive
// error is returned.
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
	for {
		msg, err := backend.Receive()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		if _, isTerminate := msg.(*pgproto3.Terminate); isTerminate {
			return nil
		}
	}
}

// WaitForClose returns a Step that waits for the client to end the session.
func WaitForClose() Step {
	var step waitForCloseMessageStep
	return &step
}
// AcceptUnauthenticatedConnRequestSteps returns the Steps for a handshake that
// requires no authentication. Any startup message is accepted, then the server
// sends AuthenticationOk, BackendKeyData (process ID and secret key both 0) and
// ReadyForQuery with transaction status 'I' (idle).
func AcceptUnauthenticatedConnRequestSteps() []Step {
	startup := &pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      map[string]string{},
	}
	authOK := &pgproto3.Authentication{Type: pgproto3.AuthTypeOk}
	keyData := &pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}
	ready := &pgproto3.ReadyForQuery{TxStatus: 'I'}

	return []Step{
		ExpectAnyMessage(startup),
		SendMessage(authOK),
		SendMessage(keyData),
		SendMessage(ready),
	}
}
// PgxInitSteps returns the Steps that answer the data-type query pgx issues
// right after connecting. The exchange is:
//   - expect Parse of the type query, Describe of the statement and Sync;
//   - reply with ParseComplete, ParameterDescription, RowDescription and
//     ReadyForQuery;
//   - expect Bind (binary result format for both columns), Execute and Sync;
//   - reply with BindComplete, one DataRow per entry in rowVals,
//     CommandComplete and ReadyForQuery.
//
// The Parse query must match what the client sends byte-for-byte, because
// ExpectMessage compares messages with reflect.DeepEqual.
func PgxInitSteps() []Step {
	steps := []Step{
		ExpectMessage(&pgproto3.Parse{
			Query: `select t.oid,
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
else nsp.nspname||'.'||t.typname
end
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
left join pg_namespace nsp on t.typnamespace=nsp.oid
where (
t.typtype in('b', 'p', 'r', 'e')
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
)`,
		}),
		ExpectMessage(&pgproto3.Describe{
			ObjectType: 'S',
		}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.ParseComplete{}),
		SendMessage(&pgproto3.ParameterDescription{}),
		// NOTE(review): the second column describes a plain pg_type.typname
		// column, while the query now selects an unaliased CASE expression;
		// presumably the client only consumes the values. Confirm if this
		// metadata ever matters.
		SendMessage(&pgproto3.RowDescription{
			Fields: []pgproto3.FieldDescription{
				{Name: "oid",
					TableOID:             1247,
					TableAttributeNumber: 65534,
					DataTypeOID:          26,
					DataTypeSize:         4,
					TypeModifier:         4294967295,
					Format:               0,
				},
				{Name: "typname",
					TableOID:             1247,
					TableAttributeNumber: 1,
					DataTypeOID:          19,
					DataTypeSize:         64,
					TypeModifier:         4294967295,
					Format:               0,
				},
			},
		}),
		SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
		ExpectMessage(&pgproto3.Bind{
			ResultFormatCodes: []int16{1, 1},
		}),
		ExpectMessage(&pgproto3.Execute{}),
		ExpectMessage(&pgproto3.Sync{}),
		SendMessage(&pgproto3.BindComplete{}),
	}

	// (oid, type name) pairs returned as the query result.
	rowVals := []struct {
		oid  pgtype.OID
		name string
	}{
		{16, "bool"},
		{17, "bytea"},
		{18, "char"},
		{19, "name"},
		{20, "int8"},
		{21, "int2"},
		{22, "int2vector"},
		{23, "int4"},
		{24, "regproc"},
		{25, "text"},
		{26, "oid"},
		{27, "tid"},
		{28, "xid"},
		{29, "cid"},
		{30, "oidvector"},
		{114, "json"},
		{142, "xml"},
		{143, "_xml"},
		{199, "_json"},
		{194, "pg_node_tree"},
		{32, "pg_ddl_command"},
		{210, "smgr"},
		{600, "point"},
		{601, "lseg"},
		{602, "path"},
		{603, "box"},
		{604, "polygon"},
		{628, "line"},
		{629, "_line"},
		{700, "float4"},
		{701, "float8"},
		{702, "abstime"},
		{703, "reltime"},
		{704, "tinterval"},
		{705, "unknown"},
		{718, "circle"},
		{719, "_circle"},
		{790, "money"},
		{791, "_money"},
		{829, "macaddr"},
		{869, "inet"},
		{650, "cidr"},
		{1000, "_bool"},
		{1001, "_bytea"},
		{1002, "_char"},
		{1003, "_name"},
		{1005, "_int2"},
		{1006, "_int2vector"},
		{1007, "_int4"},
		{1008, "_regproc"},
		{1009, "_text"},
		{1028, "_oid"},
		{1010, "_tid"},
		{1011, "_xid"},
		{1012, "_cid"},
		{1013, "_oidvector"},
		{1014, "_bpchar"},
		{1015, "_varchar"},
		{1016, "_int8"},
		{1017, "_point"},
		{1018, "_lseg"},
		{1019, "_path"},
		{1020, "_box"},
		{1021, "_float4"},
		{1022, "_float8"},
		{1023, "_abstime"},
		{1024, "_reltime"},
		{1025, "_tinterval"},
		{1027, "_polygon"},
		{1033, "aclitem"},
		{1034, "_aclitem"},
		{1040, "_macaddr"},
		{1041, "_inet"},
		{651, "_cidr"},
		{1263, "_cstring"},
		{1042, "bpchar"},
		{1043, "varchar"},
		{1082, "date"},
		{1083, "time"},
		{1114, "timestamp"},
		{1115, "_timestamp"},
		{1182, "_date"},
		{1183, "_time"},
		{1184, "timestamptz"},
		{1185, "_timestamptz"},
		{1186, "interval"},
		{1187, "_interval"},
		{1231, "_numeric"},
		{1266, "timetz"},
		{1270, "_timetz"},
		{1560, "bit"},
		{1561, "_bit"},
		{1562, "varbit"},
		{1563, "_varbit"},
		{1700, "numeric"},
		{1790, "refcursor"},
		{2201, "_refcursor"},
		{2202, "regprocedure"},
		{2203, "regoper"},
		{2204, "regoperator"},
		{2205, "regclass"},
		{2206, "regtype"},
		{4096, "regrole"},
		{4089, "regnamespace"},
		{2207, "_regprocedure"},
		{2208, "_regoper"},
		{2209, "_regoperator"},
		{2210, "_regclass"},
		{2211, "_regtype"},
		{4097, "_regrole"},
		{4090, "_regnamespace"},
		{2950, "uuid"},
		{2951, "_uuid"},
		{3220, "pg_lsn"},
		{3221, "_pg_lsn"},
		{3614, "tsvector"},
		{3642, "gtsvector"},
		{3615, "tsquery"},
		{3734, "regconfig"},
		{3769, "regdictionary"},
		{3643, "_tsvector"},
		{3644, "_gtsvector"},
		{3645, "_tsquery"},
		{3735, "_regconfig"},
		{3770, "_regdictionary"},
		{3802, "jsonb"},
		{3807, "_jsonb"},
		{2970, "txid_snapshot"},
		{2949, "_txid_snapshot"},
		{3904, "int4range"},
		{3905, "_int4range"},
		{3906, "numrange"},
		{3907, "_numrange"},
		{3908, "tsrange"},
		{3909, "_tsrange"},
		{3910, "tstzrange"},
		{3911, "_tstzrange"},
		{3912, "daterange"},
		{3913, "_daterange"},
		{3926, "int8range"},
		{3927, "_int8range"},
		{2249, "record"},
		{2287, "_record"},
		{2275, "cstring"},
		{2276, "any"},
		{2277, "anyarray"},
		{2278, "void"},
		{2279, "trigger"},
		{3838, "event_trigger"},
		{2280, "language_handler"},
		{2281, "internal"},
		{2282, "opaque"},
		{2283, "anyelement"},
		{2776, "anynonarray"},
		{3500, "anyenum"},
		{3115, "fdw_handler"},
		{325, "index_am_handler"},
		{3310, "tsm_handler"},
		{3831, "anyrange"},
		{51367, "gbtreekey4"},
		{51370, "_gbtreekey4"},
		{51371, "gbtreekey8"},
		{51374, "_gbtreekey8"},
		{51375, "gbtreekey16"},
		{51378, "_gbtreekey16"},
		{51379, "gbtreekey32"},
		{51382, "_gbtreekey32"},
		{51383, "gbtreekey_var"},
		{51386, "_gbtreekey_var"},
		{51921, "hstore"},
		{51926, "_hstore"},
		{52005, "ghstore"},
		{52008, "_ghstore"},
	}
	for _, rv := range rowVals {
		step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
		steps = append(steps, step)
	}

	// Derive the row count from the rows actually sent; a hard-coded count
	// silently goes stale whenever rowVals is edited.
	steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT " + strconv.Itoa(len(rowVals))}))
	steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
	return steps
}
// dataRowValue pairs a column value with the format code it should be encoded
// in.
// NOTE(review): not referenced by the code visible in this file; confirm
// whether it is still used elsewhere in the package.
type dataRowValue struct {
	Value      interface{}
	FormatCode int16
}

// mustBuildDataRow is like buildDataRow but panics if the row cannot be built.
// It is meant for fixed, known-good inputs where a failure is a programming
// error.
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
	row, err := buildDataRow(values, formatCodes)
	if err == nil {
		return row
	}
	panic(err)
}
// buildDataRow encodes values into a DataRow message.
//
// formatCodes must contain either a single code, which is applied to every
// column, or at least one code per value (extra codes are ignored). Each code
// must be pgproto3.TextFormat or pgproto3.BinaryFormat.
//
// Plain Go string, int16, int32 and int64 values are converted to pgtype.Text,
// Int2, Int4 and Int8 respectively before encoding. Any other value must
// itself implement pgtype.TextEncoder or pgtype.BinaryEncoder, matching its
// column's format code.
//
// Neither values nor formatCodes is modified.
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
	// Expand a single format code to every column. Build a fresh slice rather
	// than appending to formatCodes, so the caller's backing array is never
	// written to.
	codes := formatCodes
	if len(formatCodes) == 1 && len(values) > 1 {
		codes = make([]int16, len(values))
		for i := range codes {
			codes[i] = formatCodes[0]
		}
	}
	if len(codes) < len(values) {
		return nil, errors.Errorf("%d format codes given for %d values", len(formatCodes), len(values))
	}

	dr := &pgproto3.DataRow{
		Values: make([][]byte, len(values)),
	}

	for i, value := range values {
		// Promote plain Go values to pgtype so they can be encoded. Only the
		// loop-local copy is replaced; the caller's slice is left untouched.
		switch v := value.(type) {
		case string:
			value = &pgtype.Text{String: v, Status: pgtype.Present}
		case int16:
			value = &pgtype.Int2{Int: v, Status: pgtype.Present}
		case int32:
			value = &pgtype.Int4{Int: v, Status: pgtype.Present}
		case int64:
			value = &pgtype.Int8{Int: v, Status: pgtype.Present}
		}

		switch codes[i] {
		case pgproto3.TextFormat:
			e, ok := value.(pgtype.TextEncoder)
			if !ok {
				return nil, errors.Errorf("values[%d] does not implement TextEncoder", i)
			}
			buf, err := e.EncodeText(nil, nil)
			if err != nil {
				return nil, errors.Wrapf(err, "failed to encode values[%d]", i)
			}
			dr.Values[i] = buf
		case pgproto3.BinaryFormat:
			e, ok := value.(pgtype.BinaryEncoder)
			if !ok {
				return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i)
			}
			buf, err := e.EncodeBinary(nil, nil)
			if err != nil {
				return nil, errors.Wrapf(err, "failed to encode values[%d]", i)
			}
			dr.Values[i] = buf
		default:
			return nil, errors.Errorf("values[%d]: unknown FormatCode %d", i, codes[i])
		}
	}

	return dr, nil
}