diff --git a/conn.go b/conn.go index a689bfee..03a2221d 100644 --- a/conn.go +++ b/conn.go @@ -470,14 +470,14 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) + case string: + wbuf.WriteInt16(TextFormatCode) default: switch oid { case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid: wbuf.WriteInt16(BinaryFormatCode) - case TextOid, VarcharOid, DateOid, TimestampOid: - wbuf.WriteInt16(TextFormatCode) default: - return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg)) + wbuf.WriteInt16(TextFormatCode) } } } @@ -492,6 +492,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: err = arg.Encode(wbuf, oid) + case string: + err = encodeText(wbuf, arguments[i]) default: switch oid { case BoolOid: @@ -517,7 +519,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) default: - return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg)) + return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg)) } } if err != nil { diff --git a/query_test.go b/query_test.go index 3d8faa19..e5a8fab4 100644 --- a/query_test.go +++ b/query_test.go @@ -272,14 +272,19 @@ func TestQueryEncodeError(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.Query("select $1::integer", "wrong") - switch { - case err == nil: - t.Error("Expected transcode error to return error, but it didn't") - case err.Error() == "Expected integer representable in int32, received string wrong": - // Correct behavior - default: - t.Errorf("Expected transcode error, received %v", err) + rows, err := conn.Query("select $1::integer", "wrong") + if err != nil { + t.Errorf("conn.Query failure: %v", err) + } + defer rows.Close() + + rows.Next() + + if rows.Err() == nil { + t.Error("Expected rows.Err() to return error, but it didn't") + } + if rows.Err().Error() != `ERROR: invalid input syntax for integer: "wrong" (SQLSTATE 22P02)` { + t.Error("Expected rows.Err() to return different error:", rows.Err()) } } @@ -399,6 +404,29 @@ func TestQueryRowCoreBytea(t *testing.T) { ensureConnValid(t, conn) } +func TestQueryRowUnknownType(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::inet" + expected := "127.0.0.1" + var actual string + + err := conn.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + + } + + ensureConnValid(t, conn) +} + func TestQueryRowErrors(t *testing.T) { t.Parallel() diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 670224f2..2abe50d1 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -295,6 +295,24 @@ func TestConnQueryFailure(t *testing.T) { } } +func TestConnQueryRowUnknownType(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + sql := "select $1::inet" + expected := "127.0.0.1" + var actual string + + err := db.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + } +} + func TestTransactionLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db)