mirror of https://github.com/jackc/pgx.git
94 lines
2.4 KiB
Go
94 lines
2.4 KiB
Go
package pgmock_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/internal/pgmock"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestScript(t *testing.T) {
|
|
script := &pgmock.Script{
|
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
|
}
|
|
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
|
Fields: []pgproto3.FieldDescription{
|
|
pgproto3.FieldDescription{
|
|
Name: []byte("?column?"),
|
|
TableOID: 0,
|
|
TableAttributeNumber: 0,
|
|
DataTypeOID: 23,
|
|
DataTypeSize: 4,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
},
|
|
}))
|
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
|
|
Values: [][]byte{[]byte("42")},
|
|
}))
|
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
|
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
|
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
parts := strings.Split(ln.Addr().String(), ":")
|
|
host := parts[0]
|
|
port := parts[1]
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
pgConn, err := pgconn.Connect(ctx, connStr)
|
|
require.NoError(t, err)
|
|
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "42", string(results[0].Rows[0][0]))
|
|
|
|
pgConn.Close(ctx)
|
|
|
|
assert.NoError(t, <-serverErrChan)
|
|
}
|