diff --git a/stdlib/sql.go b/stdlib/sql.go index 44e82d98..ddb15ff6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -66,6 +66,7 @@ import ( "math" "math/rand" "reflect" + "sort" "strconv" "strings" "sync" @@ -85,7 +86,13 @@ func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } - sql.Register("pgx", pgxDriver) + + drivers := sql.Drivers() + // if pgx driver was already registered by different pgx major version then we skip registration under the default name. + if i := sort.SearchStrings(sql.Drivers(), "pgx"); len(drivers) >= i || drivers[i] != "pgx" { + sql.Register("pgx", pgxDriver) + } + sql.Register("pgx/v5", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ pgtype.BoolOID: 1, diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 314f460b..00fbbb72 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -146,9 +146,22 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) { } func TestSQLOpen(t *testing.T) { - db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - closeDB(t, db) + tests := []struct { + driverName string + }{ + {driverName: "pgx"}, + {driverName: "pgx/v5"}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.driverName, func(t *testing.T) { + db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + closeDB(t, db) + }) + } } func TestNormalLifeCycle(t *testing.T) {