From 479ebdfa1983a0964896928ce1361e06c2639d7a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 May 2017 17:56:54 -0500 Subject: [PATCH] Add basic pgmock support Primarily useful for testing pgx itself. Design is still subject to change. --- pgmock/pgmock.go | 478 ++++++++++++++++++++++++++++++++++++ pgproto3/startup_message.go | 6 +- stdlib/sql_test.go | 103 ++++++++ 3 files changed, 584 insertions(+), 3 deletions(-) create mode 100644 pgmock/pgmock.go diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go new file mode 100644 index 00000000..827fa87d --- /dev/null +++ b/pgmock/pgmock.go @@ -0,0 +1,478 @@ +package pgmock + +import ( + "errors" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type Server struct { + ln net.Listener + controller Controller +} + +func NewServer(controller Controller) (*Server, error) { + ln, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, err + } + + server := &Server{ + ln: ln, + controller: controller, + } + + return server, nil +} + +func (s *Server) Addr() net.Addr { + return s.ln.Addr() +} + +func (s *Server) ServeOne() error { + conn, err := s.ln.Accept() + if err != nil { + return err + } + + backend, err := pgproto3.NewBackend(conn, conn) + if err != nil { + conn.Close() + return err + } + + return s.controller.Serve(backend) +} + +func (s *Server) Close() error { + err := s.ln.Close() + if err != nil { + return err + } + + return nil +} + +type Controller interface { + Serve(backend *pgproto3.Backend) error +} + +type Step interface { + Step(*pgproto3.Backend) error +} + +type Script struct { + Steps []Step +} + +func (s *Script) Run(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Serve(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Step(backend *pgproto3.Backend) error { + return s.Serve(backend) +} + +type expectMessageStep struct { + want pgproto3.FrontendMessage + any bool +} + +func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.Receive() + if err != nil { + return err + } + + if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + return backend.Send(e.msg) +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} + +func PgxInitSteps() []Step { + steps := []Step{ + ExpectMessage(&pgproto3.Parse{ + Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", + }), + ExpectMessage(&pgproto3.Describe{ + ObjectType: 'S', + }), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.ParseComplete{}), + SendMessage(&pgproto3.ParameterDescription{}), + SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: "oid", + TableOID: 1247, + TableAttributeNumber: 65534, + DataTypeOID: 26, + DataTypeSize: 4, + TypeModifier: 4294967295, + Format: 0, + }, + {Name: "typname", + TableOID: 1247, + TableAttributeNumber: 1, + DataTypeOID: 19, + DataTypeSize: 64, + TypeModifier: 4294967295, + Format: 0, + }, + }, + }), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + ExpectMessage(&pgproto3.Bind{ + ParameterFormatCodes: []int16{}, + Parameters: [][]byte{}, + ResultFormatCodes: []int16{1, 1}, + }), + ExpectMessage(&pgproto3.Execute{}), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.BindComplete{}), + } + + rowVals := []struct { + oid pgtype.Oid + name string + }{ + {16, "bool"}, + {17, "bytea"}, + {18, "char"}, + {19, "name"}, + {20, "int8"}, + {21, "int2"}, + {22, "int2vector"}, + {23, "int4"}, + {24, "regproc"}, + {25, "text"}, + {26, "oid"}, + {27, "tid"}, + {28, "xid"}, + {29, "cid"}, + {30, "oidvector"}, + {114, "json"}, + {142, "xml"}, + {143, "_xml"}, + {199, "_json"}, + {194, "pg_node_tree"}, + {32, "pg_ddl_command"}, + {210, "smgr"}, + {600, "point"}, + {601, "lseg"}, + {602, "path"}, + {603, "box"}, + {604, "polygon"}, + {628, "line"}, + {629, "_line"}, + {700, "float4"}, + {701, "float8"}, + {702, "abstime"}, + {703, "reltime"}, + {704, "tinterval"}, + {705, "unknown"}, + {718, "circle"}, + {719, "_circle"}, + {790, "money"}, + {791, "_money"}, + {829, "macaddr"}, + {869, "inet"}, + {650, "cidr"}, + {1000, "_bool"}, + {1001, "_bytea"}, + {1002, "_char"}, + {1003, "_name"}, + {1005, "_int2"}, + {1006, "_int2vector"}, + {1007, "_int4"}, + {1008, "_regproc"}, + {1009, "_text"}, + {1028, "_oid"}, + {1010, "_tid"}, + {1011, "_xid"}, + {1012, "_cid"}, + {1013, "_oidvector"}, + {1014, "_bpchar"}, + {1015, "_varchar"}, + {1016, "_int8"}, + {1017, "_point"}, + {1018, "_lseg"}, + {1019, "_path"}, + {1020, "_box"}, + {1021, "_float4"}, + {1022, "_float8"}, + {1023, "_abstime"}, + {1024, "_reltime"}, + {1025, "_tinterval"}, + {1027, "_polygon"}, + {1033, "aclitem"}, + {1034, "_aclitem"}, + {1040, "_macaddr"}, + {1041, "_inet"}, + {651, "_cidr"}, + {1263, "_cstring"}, + {1042, "bpchar"}, + {1043, "varchar"}, + {1082, "date"}, + {1083, "time"}, + {1114, "timestamp"}, + {1115, "_timestamp"}, + {1182, "_date"}, + {1183, "_time"}, + {1184, "timestamptz"}, + {1185, "_timestamptz"}, + {1186, "interval"}, + {1187, "_interval"}, + {1231, "_numeric"}, + {1266, "timetz"}, + {1270, "_timetz"}, + {1560, "bit"}, + {1561, "_bit"}, + {1562, "varbit"}, + {1563, "_varbit"}, + {1700, "numeric"}, + {1790, "refcursor"}, + {2201, "_refcursor"}, + {2202, "regprocedure"}, + {2203, "regoper"}, + {2204, "regoperator"}, + {2205, "regclass"}, + {2206, "regtype"}, + {4096, "regrole"}, + {4089, "regnamespace"}, + {2207, "_regprocedure"}, + {2208, "_regoper"}, + {2209, "_regoperator"}, + {2210, "_regclass"}, + {2211, "_regtype"}, + {4097, "_regrole"}, + {4090, "_regnamespace"}, + {2950, "uuid"}, + {2951, "_uuid"}, + {3220, "pg_lsn"}, + {3221, "_pg_lsn"}, + {3614, "tsvector"}, + {3642, "gtsvector"}, + {3615, "tsquery"}, + {3734, "regconfig"}, + {3769, "regdictionary"}, + {3643, "_tsvector"}, + {3644, "_gtsvector"}, + {3645, "_tsquery"}, + {3735, "_regconfig"}, + {3770, "_regdictionary"}, + {3802, "jsonb"}, + {3807, "_jsonb"}, + {2970, "txid_snapshot"}, + {2949, "_txid_snapshot"}, + {3904, "int4range"}, + {3905, "_int4range"}, + {3906, "numrange"}, + {3907, "_numrange"}, + {3908, "tsrange"}, + {3909, "_tsrange"}, + {3910, "tstzrange"}, + {3911, "_tstzrange"}, + {3912, "daterange"}, + {3913, "_daterange"}, + {3926, "int8range"}, + {3927, "_int8range"}, + {2249, "record"}, + {2287, "_record"}, + {2275, "cstring"}, + {2276, "any"}, + {2277, "anyarray"}, + {2278, "void"}, + {2279, "trigger"}, + {3838, "event_trigger"}, + {2280, "language_handler"}, + {2281, "internal"}, + {2282, "opaque"}, + {2283, "anyelement"}, + {2776, "anynonarray"}, + {3500, "anyenum"}, + {3115, "fdw_handler"}, + {325, "index_am_handler"}, + {3310, "tsm_handler"}, + {3831, "anyrange"}, + {51367, "gbtreekey4"}, + {51370, "_gbtreekey4"}, + {51371, "gbtreekey8"}, + {51374, "_gbtreekey8"}, + {51375, "gbtreekey16"}, + {51378, "_gbtreekey16"}, + {51379, "gbtreekey32"}, + {51382, "_gbtreekey32"}, + {51383, "gbtreekey_var"}, + {51386, "_gbtreekey_var"}, + {51921, "hstore"}, + {51926, "_hstore"}, + {52005, "ghstore"}, + {52008, "_ghstore"}, + } + + for _, rv := range rowVals { + step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat})) + steps = append(steps, step) + } + + steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"})) + steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + + return steps +} + +type dataRowValue struct { + Value interface{} + FormatCode int16 +} + +func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow { + dr, err := buildDataRow(values, formatCodes) + if err != nil { + panic(err) + } + + return dr +} + +func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) { + dr := &pgproto3.DataRow{ + Values: make([][]byte, len(values)), + } + + if len(formatCodes) == 1 { + for i := 1; i < len(values); i++ { + formatCodes = append(formatCodes, formatCodes[0]) + } + } + + for i := range values { + switch v := values[i].(type) { + case string: + values[i] = &pgtype.Text{String: v, Status: pgtype.Present} + case int16: + values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present} + case int32: + values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present} + case int64: + values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present} + } + } + + for i := range values { + switch formatCodes[i] { + case pgproto3.TextFormat: + if e, ok := values[i].(pgtype.TextEncoder); ok { + buf, err := e.EncodeText(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, fmt.Errorf("values[%d] does not implement TextExcoder", i) + } + + case pgproto3.BinaryFormat: + if e, ok := values[i].(pgtype.BinaryEncoder); ok { + buf, err := e.EncodeBinary(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i) + } + default: + return nil, errors.New("unknown FormatCode") + } + } + + return dr, nil +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index ebb804fe..4847d629 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -8,7 +8,7 @@ import ( ) const ( - protocolVersionNumber = 196608 // 3.0 + ProtocolVersionNumber = 196608 // 3.0 sslRequestNumber = 80877103 ) @@ -31,8 +31,8 @@ func (dst *StartupMessage) Decode(src []byte) error { return fmt.Errorf("can't handle ssl connection request") } - if dst.ProtocolVersion != protocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 416a5a7e..4f2484d8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,9 +4,12 @@ import ( "bytes" "context" "database/sql" + "fmt" "testing" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/stdlib" ) @@ -686,6 +689,106 @@ func TestConnBeginTxReadOnly(t *testing.T) { ensureConnValid(t, db) } +func TestBeginTxContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("drop table if exists t") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + _, err = tx.Exec("create table t(id serial)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled { + t.Fatalf("err => %v, want %v", err, context.Canceled) + } + + var n int + err = db.QueryRow("select count(*) from t").Scan(&n) + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" { + t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) + } + + ensureConnValid(t, db) +} + +func acceptStandardPgxConn(backend *pgproto3.Backend) error { + script := pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + err := script.Run(backend) + if err != nil { + return err + } + + typeScript := pgmock.Script{ + Steps: pgmock.PgxInitSteps(), + } + + return typeScript.Run(backend) +} + +func TestBeginTxContextCancelWithDeadConn(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled { + t.Fatalf("err => %v, want %v", err, context.Canceled) + } + + if err := <-errChan; err != nil { + t.Fatalf("mock server err: %v", err) + } +} + func TestAcquireConn(t *testing.T) { db := openDB(t) defer closeDB(t, db)