mirror of https://github.com/jackc/pgx.git
136 lines
2.6 KiB
Go
136 lines
2.6 KiB
Go
// Package pgmock provides the ability to mock a PostgreSQL server.
|
|
package pgmock
|
|
|
|
import (
	"errors"
	"fmt"
	"io"
	"reflect"

	"github.com/jackc/pgx/v5/pgproto3"
)
|
|
|
|
type Step interface {
|
|
Step(*pgproto3.Backend) error
|
|
}
|
|
|
|
type Script struct {
|
|
Steps []Step
|
|
}
|
|
|
|
func (s *Script) Run(backend *pgproto3.Backend) error {
|
|
for _, step := range s.Steps {
|
|
err := step.Step(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Script) Step(backend *pgproto3.Backend) error {
|
|
return s.Run(backend)
|
|
}
|
|
|
|
type expectMessageStep struct {
|
|
want pgproto3.FrontendMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type expectStartupMessageStep struct {
|
|
want *pgproto3.StartupMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.ReceiveStartupMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, false)
|
|
}
|
|
|
|
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, true)
|
|
}
|
|
|
|
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
|
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
|
return &expectStartupMessageStep{want: want, any: any}
|
|
}
|
|
|
|
return &expectMessageStep{want: want, any: any}
|
|
}
|
|
|
|
type sendMessageStep struct {
|
|
msg pgproto3.BackendMessage
|
|
}
|
|
|
|
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
|
return backend.Send(e.msg)
|
|
}
|
|
|
|
func SendMessage(msg pgproto3.BackendMessage) Step {
|
|
return &sendMessageStep{msg: msg}
|
|
}
|
|
|
|
type waitForCloseMessageStep struct{}
|
|
|
|
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
|
for {
|
|
msg, err := backend.Receive()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, ok := msg.(*pgproto3.Terminate); ok {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func WaitForClose() Step {
|
|
return &waitForCloseMessageStep{}
|
|
}
|
|
|
|
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
|
return []Step{
|
|
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
|
SendMessage(&pgproto3.AuthenticationOk{}),
|
|
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}
|
|
}
|