diff --git a/conn.go b/conn.go index 2c38fd2c..26d99169 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -593,6 +593,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) + case BoolArrayOid: + err = encodeBoolArray(wbuf, arguments[i]) case Int2ArrayOid: err = encodeInt2Array(wbuf, arguments[i]) case Int4ArrayOid: diff --git a/conn_test.go b/conn_test.go index c6264f77..9aed24e1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -515,3 +515,19 @@ func TestCommandTag(t *testing.T) { } } } + +func TestInsertBoolArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} diff --git a/values.go b/values.go index 00091513..33515b41 100644 --- a/values.go +++ b/values.go @@ -21,6 +21,7 @@ const ( OidOid = 26 Float4Oid = 700 Float8Oid = 701 + BoolArrayOid = 1000 Int2ArrayOid = 1005 Int4ArrayOid = 1007 TextArrayOid = 1009 @@ -1232,6 +1233,33 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } +func encodeBoolArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]bool) + if !ok { + return fmt.Errorf("Expected []bool, received %T", value) + } + + size := 20 + len(slice)*5 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(BoolOid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + func encodeInt2Array(w *WriteBuf, value interface{}) error { slice, ok := value.([]int16) if !ok {