diff --git a/pgtype/array.go b/pgtype/array.go index 564936d2..3f5ca15b 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -101,7 +101,9 @@ type UntypedTextArray struct { } func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - uta := &UntypedTextArray{} + uta := &UntypedTextArray{ + Elements: []string{}, + } buf := bytes.NewBufferString(src) @@ -112,83 +114,80 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { return nil, fmt.Errorf("invalid array: %v", err) } - var explicitBounds bool - // Array has explicit bounds - if r == '[' { + var explicitDimensions []ArrayDimension + // Array has explicit dimensions + if r == '[' { + // TODO - parse explicit dimensions + panic(explicitDimensions) } - // Parse values + // Consume all initial opening brackets. This provides number of dimensions. + var implicitDimensions []ArrayDimension if r != '{' { return nil, fmt.Errorf("invalid array, expected '{': %v", err) } - if !explicitBounds { - uta.Dimensions = append(uta.Dimensions, ArrayDimension{LowerBound: 1}) - } - currentDimension := 0 - - for currentDimension >= 0 { - - } - - switch r { - case '(': - utr.LowerType = Exclusive - case '[': - utr.LowerType = Inclusive - default: - return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) - } buf.UnreadRune() - if r == ',' { - utr.LowerType = Unbounded - } else { - utr.Lower, err = rangeParseValue(buf) + for { + r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break } } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + elemCount := 0 - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) - } - if r != ',' { - return nil, fmt.Errorf("missing range separator: %v", r) - } + fmt.Println("-------", currentDim, buf.String()) - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) - } - buf.UnreadRune() - - if r == ')' || r == ']' { - utr.UpperType = Unbounded - } else { - utr.Upper, err = rangeParseValue(buf) + for { + r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + fmt.Println("{", buf.String()) + + if counterDim == currentDim { + elemCount++ + } + currentDim++ + case ',': + case '}': + fmt.Println("}", buf.String()) + if counterDim == currentDim { + implicitDimensions[counterDim].Length = int32(elemCount) + elemCount = 0 + } + + currentDim-- + default: + buf.UnreadRune() + fmt.Println("default", buf.String()) + value, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if counterDim == currentDim { + elemCount++ + } + uta.Elements = append(uta.Elements, value) + } + + if currentDim < 0 { + break } - } - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) - } - switch r { - case ')': - utr.UpperType = Exclusive - case ']': - utr.UpperType = Inclusive - default: - return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) } skipWhitespace(buf) @@ -197,7 +196,16 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } - return utr, nil + if len(explicitDimensions) > 0 { + uta.Dimensions = explicitDimensions + } else { + uta.Dimensions = implicitDimensions + if len(uta.Dimensions) == 1 && uta.Dimensions[0].Length == 0 { + uta.Dimensions = []ArrayDimension{} + } + } + + return uta, nil } func skipWhitespace(buf *bytes.Buffer) { @@ -211,13 +219,13 @@ func skipWhitespace(buf *bytes.Buffer) { } } -func rangeParseValue(buf *bytes.Buffer) (string, error) { +func arrayParseValue(buf *bytes.Buffer) (string, error) { r, _, err := buf.ReadRune() if err != nil { return "", err } if r == '"' { - return rangeParseQuotedValue(buf) + return arrayParseQuotedValue(buf) } buf.UnreadRune() @@ -230,12 +238,7 @@ func rangeParseValue(buf *bytes.Buffer) (string, error) { } switch r { - case '\\': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - case ',', '[', ']', '(', ')': + case ',', '}': buf.UnreadRune() return s.String(), nil } @@ -244,7 +247,7 @@ func rangeParseValue(buf *bytes.Buffer) (string, error) { } } -func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { +func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { s := &bytes.Buffer{} for { @@ -264,10 +267,8 @@ func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { if err != nil { return "", err } - if r != '"' { - buf.UnreadRune() - return s.String(), nil - } + buf.UnreadRune() + return s.String(), nil } s.WriteRune(r) } diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..6ef65419 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,63 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: []string{}, + Dimensions: []pgtype.ArrayDimension{}, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +}