mirror of https://github.com/jackc/pgx.git
Add basic pgmock support
Primarily useful for testing pgx itself. Design is still subject to change.batch-wip
parent
413871a897
commit
479ebdfa19
|
@ -0,0 +1,478 @@
|
|||
package pgmock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ln net.Listener
|
||||
controller Controller
|
||||
}
|
||||
|
||||
func NewServer(controller Controller) (*Server, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
ln: ln,
|
||||
controller: controller,
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) Addr() net.Addr {
|
||||
return s.ln.Addr()
|
||||
}
|
||||
|
||||
func (s *Server) ServeOne() error {
|
||||
conn, err := s.ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backend, err := pgproto3.NewBackend(conn, conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return s.controller.Serve(backend)
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
err := s.ln.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Controller interface {
|
||||
Serve(backend *pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Step interface {
|
||||
Step(*pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Script struct {
|
||||
Steps []Step
|
||||
}
|
||||
|
||||
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Serve(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||
return s.Serve(backend)
|
||||
}
|
||||
|
||||
type expectMessageStep struct {
|
||||
want pgproto3.FrontendMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type expectStartupMessageStep struct {
|
||||
want *pgproto3.StartupMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, false)
|
||||
}
|
||||
|
||||
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, true)
|
||||
}
|
||||
|
||||
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||
return &expectStartupMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
return &expectMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
type sendMessageStep struct {
|
||||
msg pgproto3.BackendMessage
|
||||
}
|
||||
|
||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
return backend.Send(e.msg)
|
||||
}
|
||||
|
||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
|
||||
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
}
|
||||
}
|
||||
|
||||
func PgxInitSteps() []Step {
|
||||
steps := []Step{
|
||||
ExpectMessage(&pgproto3.Parse{
|
||||
Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)",
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Describe{
|
||||
ObjectType: 'S',
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Sync{}),
|
||||
SendMessage(&pgproto3.ParseComplete{}),
|
||||
SendMessage(&pgproto3.ParameterDescription{}),
|
||||
SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
{Name: "oid",
|
||||
TableOID: 1247,
|
||||
TableAttributeNumber: 65534,
|
||||
DataTypeOID: 26,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: 4294967295,
|
||||
Format: 0,
|
||||
},
|
||||
{Name: "typname",
|
||||
TableOID: 1247,
|
||||
TableAttributeNumber: 1,
|
||||
DataTypeOID: 19,
|
||||
DataTypeSize: 64,
|
||||
TypeModifier: 4294967295,
|
||||
Format: 0,
|
||||
},
|
||||
},
|
||||
}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
ExpectMessage(&pgproto3.Bind{
|
||||
ParameterFormatCodes: []int16{},
|
||||
Parameters: [][]byte{},
|
||||
ResultFormatCodes: []int16{1, 1},
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Execute{}),
|
||||
ExpectMessage(&pgproto3.Sync{}),
|
||||
SendMessage(&pgproto3.BindComplete{}),
|
||||
}
|
||||
|
||||
rowVals := []struct {
|
||||
oid pgtype.Oid
|
||||
name string
|
||||
}{
|
||||
{16, "bool"},
|
||||
{17, "bytea"},
|
||||
{18, "char"},
|
||||
{19, "name"},
|
||||
{20, "int8"},
|
||||
{21, "int2"},
|
||||
{22, "int2vector"},
|
||||
{23, "int4"},
|
||||
{24, "regproc"},
|
||||
{25, "text"},
|
||||
{26, "oid"},
|
||||
{27, "tid"},
|
||||
{28, "xid"},
|
||||
{29, "cid"},
|
||||
{30, "oidvector"},
|
||||
{114, "json"},
|
||||
{142, "xml"},
|
||||
{143, "_xml"},
|
||||
{199, "_json"},
|
||||
{194, "pg_node_tree"},
|
||||
{32, "pg_ddl_command"},
|
||||
{210, "smgr"},
|
||||
{600, "point"},
|
||||
{601, "lseg"},
|
||||
{602, "path"},
|
||||
{603, "box"},
|
||||
{604, "polygon"},
|
||||
{628, "line"},
|
||||
{629, "_line"},
|
||||
{700, "float4"},
|
||||
{701, "float8"},
|
||||
{702, "abstime"},
|
||||
{703, "reltime"},
|
||||
{704, "tinterval"},
|
||||
{705, "unknown"},
|
||||
{718, "circle"},
|
||||
{719, "_circle"},
|
||||
{790, "money"},
|
||||
{791, "_money"},
|
||||
{829, "macaddr"},
|
||||
{869, "inet"},
|
||||
{650, "cidr"},
|
||||
{1000, "_bool"},
|
||||
{1001, "_bytea"},
|
||||
{1002, "_char"},
|
||||
{1003, "_name"},
|
||||
{1005, "_int2"},
|
||||
{1006, "_int2vector"},
|
||||
{1007, "_int4"},
|
||||
{1008, "_regproc"},
|
||||
{1009, "_text"},
|
||||
{1028, "_oid"},
|
||||
{1010, "_tid"},
|
||||
{1011, "_xid"},
|
||||
{1012, "_cid"},
|
||||
{1013, "_oidvector"},
|
||||
{1014, "_bpchar"},
|
||||
{1015, "_varchar"},
|
||||
{1016, "_int8"},
|
||||
{1017, "_point"},
|
||||
{1018, "_lseg"},
|
||||
{1019, "_path"},
|
||||
{1020, "_box"},
|
||||
{1021, "_float4"},
|
||||
{1022, "_float8"},
|
||||
{1023, "_abstime"},
|
||||
{1024, "_reltime"},
|
||||
{1025, "_tinterval"},
|
||||
{1027, "_polygon"},
|
||||
{1033, "aclitem"},
|
||||
{1034, "_aclitem"},
|
||||
{1040, "_macaddr"},
|
||||
{1041, "_inet"},
|
||||
{651, "_cidr"},
|
||||
{1263, "_cstring"},
|
||||
{1042, "bpchar"},
|
||||
{1043, "varchar"},
|
||||
{1082, "date"},
|
||||
{1083, "time"},
|
||||
{1114, "timestamp"},
|
||||
{1115, "_timestamp"},
|
||||
{1182, "_date"},
|
||||
{1183, "_time"},
|
||||
{1184, "timestamptz"},
|
||||
{1185, "_timestamptz"},
|
||||
{1186, "interval"},
|
||||
{1187, "_interval"},
|
||||
{1231, "_numeric"},
|
||||
{1266, "timetz"},
|
||||
{1270, "_timetz"},
|
||||
{1560, "bit"},
|
||||
{1561, "_bit"},
|
||||
{1562, "varbit"},
|
||||
{1563, "_varbit"},
|
||||
{1700, "numeric"},
|
||||
{1790, "refcursor"},
|
||||
{2201, "_refcursor"},
|
||||
{2202, "regprocedure"},
|
||||
{2203, "regoper"},
|
||||
{2204, "regoperator"},
|
||||
{2205, "regclass"},
|
||||
{2206, "regtype"},
|
||||
{4096, "regrole"},
|
||||
{4089, "regnamespace"},
|
||||
{2207, "_regprocedure"},
|
||||
{2208, "_regoper"},
|
||||
{2209, "_regoperator"},
|
||||
{2210, "_regclass"},
|
||||
{2211, "_regtype"},
|
||||
{4097, "_regrole"},
|
||||
{4090, "_regnamespace"},
|
||||
{2950, "uuid"},
|
||||
{2951, "_uuid"},
|
||||
{3220, "pg_lsn"},
|
||||
{3221, "_pg_lsn"},
|
||||
{3614, "tsvector"},
|
||||
{3642, "gtsvector"},
|
||||
{3615, "tsquery"},
|
||||
{3734, "regconfig"},
|
||||
{3769, "regdictionary"},
|
||||
{3643, "_tsvector"},
|
||||
{3644, "_gtsvector"},
|
||||
{3645, "_tsquery"},
|
||||
{3735, "_regconfig"},
|
||||
{3770, "_regdictionary"},
|
||||
{3802, "jsonb"},
|
||||
{3807, "_jsonb"},
|
||||
{2970, "txid_snapshot"},
|
||||
{2949, "_txid_snapshot"},
|
||||
{3904, "int4range"},
|
||||
{3905, "_int4range"},
|
||||
{3906, "numrange"},
|
||||
{3907, "_numrange"},
|
||||
{3908, "tsrange"},
|
||||
{3909, "_tsrange"},
|
||||
{3910, "tstzrange"},
|
||||
{3911, "_tstzrange"},
|
||||
{3912, "daterange"},
|
||||
{3913, "_daterange"},
|
||||
{3926, "int8range"},
|
||||
{3927, "_int8range"},
|
||||
{2249, "record"},
|
||||
{2287, "_record"},
|
||||
{2275, "cstring"},
|
||||
{2276, "any"},
|
||||
{2277, "anyarray"},
|
||||
{2278, "void"},
|
||||
{2279, "trigger"},
|
||||
{3838, "event_trigger"},
|
||||
{2280, "language_handler"},
|
||||
{2281, "internal"},
|
||||
{2282, "opaque"},
|
||||
{2283, "anyelement"},
|
||||
{2776, "anynonarray"},
|
||||
{3500, "anyenum"},
|
||||
{3115, "fdw_handler"},
|
||||
{325, "index_am_handler"},
|
||||
{3310, "tsm_handler"},
|
||||
{3831, "anyrange"},
|
||||
{51367, "gbtreekey4"},
|
||||
{51370, "_gbtreekey4"},
|
||||
{51371, "gbtreekey8"},
|
||||
{51374, "_gbtreekey8"},
|
||||
{51375, "gbtreekey16"},
|
||||
{51378, "_gbtreekey16"},
|
||||
{51379, "gbtreekey32"},
|
||||
{51382, "_gbtreekey32"},
|
||||
{51383, "gbtreekey_var"},
|
||||
{51386, "_gbtreekey_var"},
|
||||
{51921, "hstore"},
|
||||
{51926, "_hstore"},
|
||||
{52005, "ghstore"},
|
||||
{52008, "_ghstore"},
|
||||
}
|
||||
|
||||
for _, rv := range rowVals {
|
||||
step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
|
||||
steps = append(steps, step)
|
||||
}
|
||||
|
||||
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"}))
|
||||
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||
|
||||
return steps
|
||||
}
|
||||
|
||||
type dataRowValue struct {
|
||||
Value interface{}
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
|
||||
dr, err := buildDataRow(values, formatCodes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dr
|
||||
}
|
||||
|
||||
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
|
||||
dr := &pgproto3.DataRow{
|
||||
Values: make([][]byte, len(values)),
|
||||
}
|
||||
|
||||
if len(formatCodes) == 1 {
|
||||
for i := 1; i < len(values); i++ {
|
||||
formatCodes = append(formatCodes, formatCodes[0])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
switch v := values[i].(type) {
|
||||
case string:
|
||||
values[i] = &pgtype.Text{String: v, Status: pgtype.Present}
|
||||
case int16:
|
||||
values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present}
|
||||
case int32:
|
||||
values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present}
|
||||
case int64:
|
||||
values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
switch formatCodes[i] {
|
||||
case pgproto3.TextFormat:
|
||||
if e, ok := values[i].(pgtype.TextEncoder); ok {
|
||||
buf, err := e.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode values[%d]", i)
|
||||
}
|
||||
dr.Values[i] = buf
|
||||
} else {
|
||||
return nil, fmt.Errorf("values[%d] does not implement TextExcoder", i)
|
||||
}
|
||||
|
||||
case pgproto3.BinaryFormat:
|
||||
if e, ok := values[i].(pgtype.BinaryEncoder); ok {
|
||||
buf, err := e.EncodeBinary(nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode values[%d]", i)
|
||||
}
|
||||
dr.Values[i] = buf
|
||||
} else {
|
||||
return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i)
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unknown FormatCode")
|
||||
}
|
||||
}
|
||||
|
||||
return dr, nil
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
protocolVersionNumber = 196608 // 3.0
|
||||
ProtocolVersionNumber = 196608 // 3.0
|
||||
sslRequestNumber = 80877103
|
||||
)
|
||||
|
||||
|
@ -31,8 +31,8 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
|||
return fmt.Errorf("can't handle ssl connection request")
|
||||
}
|
||||
|
||||
if dst.ProtocolVersion != protocolVersionNumber {
|
||||
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion)
|
||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||
}
|
||||
|
||||
dst.Parameters = make(map[string]string)
|
||||
|
|
|
@ -4,9 +4,12 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgmock"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/stdlib"
|
||||
)
|
||||
|
||||
|
@ -686,6 +689,106 @@ func TestConnBeginTxReadOnly(t *testing.T) {
|
|||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestBeginTxContextCancel(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
||||
_, err := db.Exec("drop table if exists t")
|
||||
if err != nil {
|
||||
t.Fatalf("db.Exec failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("create table t(id serial)")
|
||||
if err != nil {
|
||||
t.Fatalf("tx.Exec failed: %v", err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
err = tx.Commit()
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
var n int
|
||||
err = db.QueryRow("select count(*) from t").Scan(&n)
|
||||
if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" {
|
||||
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func acceptStandardPgxConn(backend *pgproto3.Backend) error {
|
||||
script := pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
|
||||
err := script.Run(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
typeScript := pgmock.Script{
|
||||
Steps: pgmock.PgxInitSteps(),
|
||||
}
|
||||
|
||||
return typeScript.Run(backend)
|
||||
}
|
||||
|
||||
func TestBeginTxContextCancelWithDeadConn(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
||||
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx failed: %v", err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
err = tx.Commit()
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Fatalf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireConn(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
|
Loading…
Reference in New Issue