diff --git a/batch_test.go b/batch_test.go index 32901830..c2e944a1 100644 --- a/batch_test.go +++ b/batch_test.go @@ -803,7 +803,7 @@ func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() diff --git a/conn.go b/conn.go index ba0d9d00..177e21ff 100644 --- a/conn.go +++ b/conn.go @@ -29,13 +29,11 @@ type ConnConfig struct { // to nil to disable automatic prepared statements. BuildStatementCache BuildStatementCacheFunc - // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended - // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client - // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) - // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be - // used by default. The same functionality can be controlled on a per query basis by setting - // QueryExOptions.SimpleProtocol. - PreferSimpleProtocol bool + // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol + // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as + // PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + DefaultQueryExecMode QueryExecMode createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -125,8 +123,9 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { // server. "describe" is primarily useful when the environment does not allow prepared statements such as when // running a connection pooler like PgBouncer. Default: "prepare" // -// prefer_simple_protocol -// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false +// default_query_exec_mode +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { @@ -163,13 +162,22 @@ func ParseConfig(connString string) (*ConnConfig, error) { } } - preferSimpleProtocol := false - if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { - delete(config.RuntimeParams, "prefer_simple_protocol") - if b, err := strconv.ParseBool(s); err == nil { - preferSimpleProtocol = b - } else { - return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") + switch s { + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol + default: + return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) } } @@ -178,7 +186,7 @@ func ParseConfig(connString string) (*ConnConfig, error) { createdByParseConfig: true, LogLevel: LogLevelInfo, BuildStatementCache: buildStatementCache, - PreferSimpleProtocol: preferSimpleProtocol, + DefaultQueryExecMode: defaultQueryExecMode, connString: connString, } @@ -403,13 +411,13 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + simpleProtocol = arg == QueryExecModeSimpleProtocol arguments = arguments[1:] default: break optionLoop @@ -525,8 +533,39 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con return r } -// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. -type QuerySimpleProtocol bool +type QueryExecMode int32 + +const ( + _ QueryExecMode = iota + + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single + // round trip after the statement is cached. This is the default. + QueryExecModeCacheStatement + + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the + // extended protocol. Queries are executed in a single round trip after the description is cached. If the database + // schema is modified or the search_path is changed this may result in undetected result decoding errors. + QueryExecModeCacheDescribe + + // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips + // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even + // when the the database schema is modified concurrently. + QueryExecModeDescribeExec + + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended + // protocol. Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + QueryExecModeExec + + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. + // Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + QueryExecModeSimpleProtocol +) // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 @@ -547,7 +586,7 @@ type QueryResultFormatsByOID map[uint32]int16 func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol optionLoop: for len(args) > 0 { @@ -558,8 +597,8 @@ optionLoop: case QueryResultFormatsByOID: resultFormatsByOID = arg args = args[1:] - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + simpleProtocol = arg == QueryExecModeSimpleProtocol args = args[1:] default: break optionLoop @@ -709,7 +748,7 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol var sb strings.Builder if simpleProtocol { for i, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 3240c954..f5a4319f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -81,7 +81,7 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - connConfig.PreferSimpleProtocol = true + connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn := mustConnect(t, connConfig) defer closeConn(t, conn) @@ -164,23 +164,24 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { require.Equal(t, stmtcache.ModeDescribe, c.Mode()) } -func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) { +func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { t.Parallel() for _, tt := range []struct { connString string - preferSimpleProtocol bool + defaultQueryExecMode pgx.QueryExecMode }{ - {"", false}, - {"prefer_simple_protocol=false", false}, - {"prefer_simple_protocol=0", false}, - {"prefer_simple_protocol=true", true}, - {"prefer_simple_protocol=1", true}, + {"", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe}, + {"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec}, + {"default_query_exec_mode=exec", pgx.QueryExecModeExec}, + {"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol}, } { config, err := pgx.ParseConfig(tt.connString) require.NoError(t, err) - require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString) - require.Empty(t, config.RuntimeParams["prefer_simple_protocol"]) + require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString) + require.Empty(t, config.RuntimeParams["default_query_exec_mode"]) } } @@ -384,7 +385,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { commandTag, err = conn.Exec(ctx, "insert into foo(name) values($1);", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "bar'; drop table foo;--", ) if err != nil { diff --git a/helper_test.go b/helper_test.go index 74c17431..22cc8872 100644 --- a/helper_test.go +++ b/helper_test.go @@ -18,7 +18,7 @@ func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, c config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn, err := pgx.ConnectConfig(context.Background(), config) require.NoError(t, err) defer func() { @@ -130,7 +130,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) diff --git a/large_objects_test.go b/large_objects_test.go index e42a90e7..f86f35e9 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -32,7 +32,7 @@ func TestLargeObjects(t *testing.T) { testLargeObjects(t, ctx, tx) } -func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { +func TestLargeObjectsSimpleProtocol(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -43,7 +43,7 @@ func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { t.Fatal(err) } - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn, err := pgx.ConnectConfig(ctx, config) if err != nil { diff --git a/pgbouncer_test.go b/pgbouncer_test.go index eeae6db4..e80861a0 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -34,7 +34,7 @@ func TestPgbouncerSimpleProtocol(t *testing.T) { config := mustParseConfig(t, connString) config.BuildStatementCache = nil - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol testPgbouncer(t, config, 10, 100) } diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 7b9f9f29..93e1940d 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -168,7 +168,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) diff --git a/query_test.go b/query_test.go index 3728f8a3..a0b75313 100644 --- a/query_test.go +++ b/query_test.go @@ -291,7 +291,7 @@ func TestConnQueryRawValues(t *testing.T) { rows, err := conn.Query( context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, 10, ) require.NoError(t, err) @@ -1385,7 +1385,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1402,7 +1402,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1419,7 +1419,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1436,7 +1436,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bytea", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1453,7 +1453,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1478,7 +1478,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1499,7 +1499,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1520,7 +1520,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1541,7 +1541,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1562,7 +1562,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1583,7 +1583,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1604,7 +1604,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1625,7 +1625,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1646,7 +1646,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1667,7 +1667,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float4[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1688,7 +1688,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1706,7 +1706,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::circle", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, &expected, ).Scan(&actual) if err != nil { @@ -1734,7 +1734,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int8, $2::float8, $3, $4::bytea, $5::text", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) if err != nil { @@ -1765,7 +1765,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1793,7 +1793,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "test", ).Scan(&expected) if err == nil { @@ -1817,7 +1817,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, `\'; drop table users; --`, ).Scan(&expected) if err == nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 07498843..8695e4ad 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -80,7 +80,7 @@ func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, d config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol db := stdlib.OpenDB(*config) defer func() { err := db.Close()