diff --git a/hstore.go b/hstore.go index c2de1ccf..f46eeaf6 100644 --- a/hstore.go +++ b/hstore.go @@ -40,6 +40,16 @@ func (dst *Hstore) Set(src interface{}) error { m[k] = Text{String: v, Status: Present} } *dst = Hstore{Map: m, Status: Present} + case map[string]*string: + m := make(map[string]Text, len(value)) + for k, v := range value { + if v == nil { + m[k] = Text{Status: Null} + } else { + m[k] = Text{String: *v, Status: Present} + } + } + *dst = Hstore{Map: m, Status: Present} default: return fmt.Errorf("cannot convert %v to Hstore", src) } @@ -71,6 +81,19 @@ func (src *Hstore) AssignTo(dst interface{}) error { (*v)[k] = val.String } return nil + case *map[string]*string: + *v = make(map[string]*string, len(src.Map)) + for k, val := range src.Map { + switch val.Status { + case Null: + (*v)[k] = nil + case Present: + (*v)[k] = &val.String + default: + return fmt.Errorf("cannot decode %#v into %T", src, dst) + } + } + return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/hstore_test.go b/hstore_test.go index 48b4b42e..73ee0612 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -69,6 +69,50 @@ func TestHstoreTranscode(t *testing.T) { }) } +func TestHstoreTranscodeNullable(t *testing.T) { + text := func(s string, status pgtype.Status) pgtype.Text { + return pgtype.Text{String: s, Status: status} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", pgtype.Null)}, Status: pgtype.Present}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", pgtype.Null)}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", pgtype.Null)}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", pgtype.Null)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", pgtype.Null)}, Status: pgtype.Present}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + func TestHstoreSet(t *testing.T) { successfulTests := []struct { src map[string]string @@ -90,6 +134,27 @@ func TestHstoreSet(t *testing.T) { } } +func TestHstoreSetNullable(t *testing.T) { + successfulTests := []struct { + src map[string]*string + result pgtype.Hstore + }{ + {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + func TestHstoreAssignTo(t *testing.T) { var m map[string]string @@ -113,3 +178,27 @@ func TestHstoreAssignTo(t *testing.T) { } } } + +func TestHstoreAssignToNullable(t *testing.T) { + var m map[string]*string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]*string + expected map[string]*string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +}