Check for overflow on uint16 sizes in pgproto3

pull/1927/head
Jack Christensen 2024-03-02 11:56:44 -06:00 committed by Jack Christensen
parent adbb38f298
commit 20344dfae8
9 changed files with 53 additions and 0 deletions

View File

@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -116,11 +118,17 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
@ -132,6 +140,9 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, p...)
}
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -47,6 +48,9 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -48,6 +49,9 @@ func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -48,6 +49,9 @@ func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)

View File

@ -4,6 +4,8 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -66,6 +68,9 @@ func (dst *DataRow) Decode(src []byte) error {
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {

View File

@ -2,6 +2,8 @@ package pgproto3
import (
"encoding/binary"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -74,10 +76,18 @@ func (dst *FunctionCall) Decode(src []byte) error {
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'F')
dst = pgio.AppendUint32(dst, src.Function)
if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -42,6 +44,9 @@ func (dst *ParameterDescription) Decode(src []byte) error {
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 't')
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -60,6 +62,9 @@ func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, src.Query...)
dst = append(dst, 0)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -102,6 +104,9 @@ func (dst *RowDescription) Decode(src []byte) error {
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'T')
if len(src.Fields) > math.MaxUint16 {
return nil, errors.New("too many fields")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {
dst = append(dst, fd.Name...)