mirror of https://github.com/jackc/pgx.git
249 lines
5.0 KiB
Go
249 lines
5.0 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"encoding/hex"
|
|
"fmt"
|
|
)
|
|
|
|
type UUIDScanner interface {
|
|
ScanUUID(v UUID) error
|
|
}
|
|
|
|
type UUIDValuer interface {
|
|
UUIDValue() (UUID, error)
|
|
}
|
|
|
|
// UUID holds a 16-byte UUID together with a validity flag. The zero value
// (Valid == false) is what this file uses to represent SQL NULL / JSON null.
type UUID struct {
	// Bytes is the raw 16-byte UUID value.
	Bytes [16]byte
	// Valid reports whether Bytes holds a value; false means NULL.
	Valid bool
}
|
|
|
|
func (b *UUID) ScanUUID(v UUID) error {
|
|
*b = v
|
|
return nil
|
|
}
|
|
|
|
func (b UUID) UUIDValue() (UUID, error) {
|
|
return b, nil
|
|
}
|
|
|
|
// parseUUID converts a string UUID in standard form to a byte array.
//
// Two layouts are accepted: the 36-character dashed form
// "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" and the bare 32-character hex form.
// An error and a zero array are returned when the length is anything else,
// when a 36-character input does not have '-' at offsets 8, 13, 18 and 23, or
// when the remaining characters are not hexadecimal digits.
func parseUUID(src string) (dst [16]byte, err error) {
	switch len(src) {
	case 36:
		// The separators must really be dashes. Without this check the bytes
		// at these offsets are discarded silently, so inputs such as 36 hex
		// digits in a row would be accepted as a valid UUID.
		if src[8] != '-' || src[13] != '-' || src[18] != '-' || src[23] != '-' {
			return dst, fmt.Errorf("cannot parse UUID %v", src)
		}
		src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:]
	case 32:
		// dashes already stripped; the hex decoding below validates the digits
	default:
		// assume invalid.
		return dst, fmt.Errorf("cannot parse UUID %v", src)
	}

	buf, err := hex.DecodeString(src)
	if err != nil {
		return dst, err
	}

	copy(dst[:], buf)
	return dst, nil
}
|
|
|
|
// encodeUUID converts a uuid byte array to UUID standard string form
// ("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", lowercase hex).
//
// The groups are hex-encoded directly into a fixed 36-byte buffer rather than
// going through fmt.Sprintf, which avoids boxing five slices and fmt's
// reflection on every call; the resulting string is identical.
func encodeUUID(src [16]byte) string {
	var out [36]byte

	hex.Encode(out[0:8], src[0:4])
	out[8] = '-'
	hex.Encode(out[9:13], src[4:6])
	out[13] = '-'
	hex.Encode(out[14:18], src[6:8])
	out[18] = '-'
	hex.Encode(out[19:23], src[8:10])
	out[23] = '-'
	hex.Encode(out[24:36], src[10:16])

	return string(out[:])
}
|
|
|
|
// Scan implements the database/sql Scanner interface.
|
|
func (dst *UUID) Scan(src interface{}) error {
|
|
if src == nil {
|
|
*dst = UUID{}
|
|
return nil
|
|
}
|
|
|
|
switch src := src.(type) {
|
|
case string:
|
|
buf, err := parseUUID(src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*dst = UUID{Bytes: buf, Valid: true}
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("cannot scan %T", src)
|
|
}
|
|
|
|
// Value implements the database/sql/driver Valuer interface.
|
|
func (src UUID) Value() (driver.Value, error) {
|
|
if !src.Valid {
|
|
return nil, nil
|
|
}
|
|
|
|
return encodeUUID(src.Bytes), nil
|
|
}
|
|
|
|
func (src UUID) MarshalJSON() ([]byte, error) {
|
|
if !src.Valid {
|
|
return []byte("null"), nil
|
|
}
|
|
|
|
var buff bytes.Buffer
|
|
buff.WriteByte('"')
|
|
buff.WriteString(encodeUUID(src.Bytes))
|
|
buff.WriteByte('"')
|
|
return buff.Bytes(), nil
|
|
}
|
|
|
|
func (dst *UUID) UnmarshalJSON(src []byte) error {
|
|
if bytes.Compare(src, []byte("null")) == 0 {
|
|
*dst = UUID{}
|
|
return nil
|
|
}
|
|
if len(src) != 38 {
|
|
return fmt.Errorf("invalid length for UUID: %v", len(src))
|
|
}
|
|
buf, err := parseUUID(string(src[1 : len(src)-1]))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*dst = UUID{Bytes: buf, Valid: true}
|
|
return nil
|
|
}
|
|
|
|
// UUIDCodec is a stateless type whose methods plan the encoding and scanning
// of UUID values in text and binary formats.
type UUIDCodec struct{}
|
|
|
|
func (UUIDCodec) FormatSupported(format int16) bool {
|
|
return format == TextFormatCode || format == BinaryFormatCode
|
|
}
|
|
|
|
func (UUIDCodec) PreferredFormat() int16 {
|
|
return BinaryFormatCode
|
|
}
|
|
|
|
func (UUIDCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
|
|
if _, ok := value.(UUIDValuer); !ok {
|
|
return nil
|
|
}
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
return encodePlanUUIDCodecBinaryUUIDValuer{}
|
|
case TextFormatCode:
|
|
return encodePlanUUIDCodecTextUUIDValuer{}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// encodePlanUUIDCodecBinaryUUIDValuer is the encode plan returned by
// UUIDCodec.PlanEncode for the binary format.
type encodePlanUUIDCodecBinaryUUIDValuer struct{}
|
|
|
|
func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
|
uuid, err := value.(UUIDValuer).UUIDValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !uuid.Valid {
|
|
return nil, nil
|
|
}
|
|
|
|
return append(buf, uuid.Bytes[:]...), nil
|
|
}
|
|
|
|
// encodePlanUUIDCodecTextUUIDValuer is the encode plan returned by
// UUIDCodec.PlanEncode for the text format.
type encodePlanUUIDCodecTextUUIDValuer struct{}
|
|
|
|
func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
|
uuid, err := value.(UUIDValuer).UUIDValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !uuid.Valid {
|
|
return nil, nil
|
|
}
|
|
|
|
return append(buf, encodeUUID(uuid.Bytes)...), nil
|
|
}
|
|
|
|
func (UUIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
switch target.(type) {
|
|
case UUIDScanner:
|
|
return scanPlanBinaryUUIDToUUIDScanner{}
|
|
}
|
|
case TextFormatCode:
|
|
switch target.(type) {
|
|
case UUIDScanner:
|
|
return scanPlanTextAnyToUUIDScanner{}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// scanPlanBinaryUUIDToUUIDScanner is the scan plan returned by
// UUIDCodec.PlanScan for binary-format data and a UUIDScanner target.
type scanPlanBinaryUUIDToUUIDScanner struct{}
|
|
|
|
func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst interface{}) error {
|
|
scanner := (dst).(UUIDScanner)
|
|
|
|
if src == nil {
|
|
return scanner.ScanUUID(UUID{})
|
|
}
|
|
|
|
if len(src) != 16 {
|
|
return fmt.Errorf("invalid length for UUID: %v", len(src))
|
|
}
|
|
|
|
uuid := UUID{Valid: true}
|
|
copy(uuid.Bytes[:], src)
|
|
|
|
return scanner.ScanUUID(uuid)
|
|
}
|
|
|
|
// scanPlanTextAnyToUUIDScanner is the scan plan returned by
// UUIDCodec.PlanScan for text-format data and a UUIDScanner target.
type scanPlanTextAnyToUUIDScanner struct{}
|
|
|
|
func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst interface{}) error {
|
|
scanner := (dst).(UUIDScanner)
|
|
|
|
if src == nil {
|
|
return scanner.ScanUUID(UUID{})
|
|
}
|
|
|
|
buf, err := parseUUID(string(src))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return scanner.ScanUUID(UUID{Bytes: buf, Valid: true})
|
|
}
|
|
|
|
func (c UUIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var uuid UUID
|
|
err := codecScan(c, ci, oid, format, src, &uuid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return encodeUUID(uuid.Bytes), nil
|
|
}
|
|
|
|
func (c UUIDCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var uuid UUID
|
|
err := codecScan(c, ci, oid, format, src, &uuid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return uuid.Bytes, nil
|
|
}
|