pgx/pgtype/inet.go

241 lines
4.6 KiB
Go

package pgtype
import (
"fmt"
"io"
"net"
"reflect"
"github.com/jackc/pgx/pgio"
)
// Network address family is dependent on server socket.h value for AF_INET.
// In practice, all platforms appear to have the same value. See
// src/include/utils/inet.h for more information.
//
// These are the family bytes EncodeBinary writes. DecodeBinary reads the
// family byte but ignores its value.
const (
// defaultAFInet is the family byte written for 4-byte (net.IPv4len) addresses.
defaultAFInet = 2
// defaultAFInet6 is the family byte written for 16-byte (net.IPv6len) addresses.
defaultAFInet6 = 3
)
// Inet represents both inet and cidr PostgreSQL types.
//
// The decoders in this file set Status to Present together with a non-nil
// IPNet, or set Status to Null and leave IPNet nil.
type Inet struct {
// IPNet holds the address and its netmask. The encoders dereference it
// whenever the value is not handled by encodeNotPresent.
IPNet *net.IPNet
// Status records whether a value is held (Present) or is SQL NULL (Null).
Status Status
}
// ConvertFrom sets dst from src.
//
// Accepted source types:
//   - Inet: copied as-is, including its Status.
//   - net.IPNet / *net.IPNet: stored as a Present value (the pointer form is
//     stored without copying).
//   - net.IP: stored as a Present value whose mask covers every bit of the
//     address as given (len(ip)*8 bits).
//   - string: parsed with net.ParseCIDR, so it must be in CIDR notation; the
//     network returned by ParseCIDR is what gets stored.
//
// Any other type is handed to underlyingPtrType (defined elsewhere in the
// package; presumably it unwraps pointers and named types -- confirm there).
// If that yields a value, conversion is retried with it; otherwise an error is
// returned. dst is left unmodified when an error is returned by this call.
func (dst *Inet) ConvertFrom(src interface{}) error {
	var converted Inet

	switch v := src.(type) {
	case Inet:
		converted = v
	case net.IPNet:
		converted = Inet{IPNet: &v, Status: Present}
	case *net.IPNet:
		converted = Inet{IPNet: v, Status: Present}
	case net.IP:
		// A bare address is a single host: every bit belongs to the prefix.
		bits := 8 * len(v)
		hostMask := net.CIDRMask(bits, bits)
		converted = Inet{
			IPNet:  &net.IPNet{Mask: hostMask, IP: v},
			Status: Present,
		}
	case string:
		_, parsed, err := net.ParseCIDR(v)
		if err != nil {
			return err
		}
		converted = Inet{IPNet: parsed, Status: Present}
	default:
		unwrapped, ok := underlyingPtrType(src)
		if !ok {
			return fmt.Errorf("cannot convert %v to Inet", src)
		}
		return dst.ConvertFrom(unwrapped)
	}

	*dst = converted
	return nil
}
// AssignTo copies the value held by src into dst.
//
// Supported destinations:
//   - *net.IPNet: requires Status == Present. The destination receives its own
//     copy of the address and mask bytes, so later changes to either side do
//     not affect the other.
//   - *net.IP: when Present, the mask must cover every bit of the address
//     (i.e. the value is a single host); the destination receives its own copy
//     of the address. When not Present the destination is set to nil.
//   - pointer to pointer: a Null value sets the inner pointer to nil; otherwise
//     the inner pointer is allocated if needed and assignment is retried on it.
//
// An error is returned for any other destination, for a netmask that does not
// describe a single host when assigning to *net.IP, and for a Present value
// whose IPNet is nil.
func (src *Inet) AssignTo(dst interface{}) error {
	// A Present value with no address cannot be assigned anywhere; report it
	// rather than dereferencing a nil pointer below.
	if src.Status == Present && src.IPNet == nil {
		return fmt.Errorf("cannot assign %v to %T", src, dst)
	}

	switch v := dst.(type) {
	case *net.IPNet:
		if src.Status != Present {
			return fmt.Errorf("cannot assign %v to %T", src, dst)
		}
		// Copy the backing arrays so the destination does not alias src.
		*v = net.IPNet{
			IP:   append(net.IP(nil), src.IPNet.IP...),
			Mask: append(net.IPMask(nil), src.IPNet.Mask...),
		}
	case *net.IP:
		if src.Status != Present {
			*v = nil
			return nil
		}
		if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount {
			return fmt.Errorf("cannot assign %v to %T", src, dst)
		}
		*v = append(net.IP(nil), src.IPNet.IP...)
	default:
		if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
			el := v.Elem()
			switch el.Kind() {
			// if dst is a pointer to pointer, strip the pointer and try again
			case reflect.Ptr:
				if src.Status == Null {
					el.Set(reflect.Zero(el.Type()))
					return nil
				}
				if el.IsNil() {
					// allocate destination
					el.Set(reflect.New(el.Type().Elem()))
				}
				return src.AssignTo(el.Interface())
			}
		}
		return fmt.Errorf("cannot decode %v into %T", src, dst)
	}

	return nil
}
// DecodeText reads a length-prefixed, text-format inet/cidr value from r and
// stores it in dst.
//
// The value starts with an int32 length. A length of -1 means NULL and sets
// dst to Inet{Status: Null}; any other negative length is rejected. The
// payload is either a bare address, which is stored with a mask covering every
// bit, or an address in CIDR notation. In the CIDR case the address is stored
// exactly as written, host bits included, so the result agrees with what
// DecodeBinary produces for the same value. IPv4 addresses are stored in their
// 4-byte form.
//
// Errors from reading r and from net.ParseCIDR are returned unchanged; dst is
// not modified in that case.
func (dst *Inet) DecodeText(r io.Reader) error {
	size, err := pgio.ReadInt32(r)
	if err != nil {
		return err
	}

	if size == -1 {
		*dst = Inet{Status: Null}
		return nil
	}
	// Only -1 is a valid negative length; make would panic on anything else.
	if size < 0 {
		return fmt.Errorf("invalid length for inet: %d", size)
	}

	buf := make([]byte, int(size))
	if _, err := io.ReadFull(r, buf); err != nil {
		return err
	}
	text := string(buf)

	var ipnet *net.IPNet
	if ip := net.ParseIP(text); ip != nil {
		if ipv4 := ip.To4(); ipv4 != nil {
			ip = ipv4
		}
		bitCount := len(ip) * 8
		ipnet = &net.IPNet{Mask: net.CIDRMask(bitCount, bitCount), IP: ip}
	} else {
		var ip net.IP
		ip, ipnet, err = net.ParseCIDR(text)
		if err != nil {
			return err
		}
		// ParseCIDR clears the host bits in ipnet.IP; restore the address as
		// written. Shrink it to 4 bytes only when the mask is an IPv4 mask so
		// that address and mask lengths always match.
		if len(ipnet.Mask) == net.IPv4len {
			if ipv4 := ip.To4(); ipv4 != nil {
				ip = ipv4
			}
		}
		ipnet.IP = ip
	}

	*dst = Inet{IPNet: ipnet, Status: Present}
	return nil
}
// DecodeBinary reads a length-prefixed, binary-format inet/cidr value from r
// and stores it in dst.
//
// The value starts with an int32 length. A length of -1 means NULL and sets
// dst to Inet{Status: Null}. Otherwise the length must be 8 or 20, and the
// payload is four header bytes (family, prefix bits, is_cidr, address length)
// followed by the address bytes. The family and is_cidr bytes are read and
// ignored.
//
// An error is returned when the length is not 8 or 20, when the address length
// does not account for the rest of the payload, when the prefix is longer than
// the address, or when r fails or ends early. dst is not modified on error.
func (dst *Inet) DecodeBinary(r io.Reader) error {
	size, err := pgio.ReadInt32(r)
	if err != nil {
		return err
	}

	if size == -1 {
		*dst = Inet{Status: Null}
		return nil
	}

	if size != 8 && size != 20 {
		return fmt.Errorf("Received an invalid size for a inet: %d", size)
	}

	// ignore family
	if _, err := pgio.ReadByte(r); err != nil {
		return err
	}

	bits, err := pgio.ReadByte(r)
	if err != nil {
		return err
	}

	// ignore is_cidr
	if _, err := pgio.ReadByte(r); err != nil {
		return err
	}

	addressLength, err := pgio.ReadByte(r)
	if err != nil {
		return err
	}

	// Four header bytes have been consumed; the address must be exactly the
	// remainder, otherwise the reader would be left mid-value.
	if int32(addressLength) != size-4 {
		return fmt.Errorf("invalid address length for inet: %d (size %d)", addressLength, size)
	}
	// net.CIDRMask returns nil for a prefix longer than the address.
	if int(bits) > int(addressLength)*8 {
		return fmt.Errorf("invalid prefix length for inet: %d", bits)
	}

	ip := make(net.IP, int(addressLength))
	// ReadFull, because a single Read may legally return fewer bytes.
	if _, err := io.ReadFull(r, ip); err != nil {
		return err
	}

	*dst = Inet{
		IPNet: &net.IPNet{
			IP:   ip,
			Mask: net.CIDRMask(int(bits), int(addressLength)*8),
		},
		Status: Present,
	}
	return nil
}
// EncodeText writes src to w in text format: an int32 byte length followed by
// the output of src.IPNet.String().
//
// Values that are not Present are delegated to encodeNotPresent, and whatever
// error it reports is returned. Otherwise the first write error encountered is
// returned.
func (src Inet) EncodeText(w io.Writer) error {
	if done, err := encodeNotPresent(w, src.Status); done {
		return err
	}

	s := src.IPNet.String()
	// Report a failed length write; continuing would emit a payload without
	// its length prefix and hide the failure from the caller.
	if _, err := pgio.WriteInt32(w, int32(len(s))); err != nil {
		return err
	}

	_, err := w.Write([]byte(s))
	return err
}
// EncodeBinary writes src to w in binary format.
//
// Values that are not Present are delegated to encodeNotPresent. Otherwise the
// output is an int32 payload size (8 for 4-byte addresses, 20 for 16-byte
// addresses), then the family byte, the number of leading one bits in the
// mask, a zero is_cidr byte, the address length, and finally the address
// bytes. An address of any other length is rejected before anything is
// written. The first write error encountered is returned.
func (src Inet) EncodeBinary(w io.Writer) error {
	if done, err := encodeNotPresent(w, src.Status); done {
		return err
	}

	addr := src.IPNet.IP

	var (
		payloadSize int32
		family      byte
	)
	switch n := len(addr); n {
	case net.IPv4len:
		payloadSize, family = 8, defaultAFInet
	case net.IPv6len:
		payloadSize, family = 20, defaultAFInet6
	default:
		return fmt.Errorf("Unexpected IP length: %v", n)
	}

	if _, err := pgio.WriteInt32(w, payloadSize); err != nil {
		return err
	}

	prefixLen, _ := src.IPNet.Mask.Size()

	// Header bytes in wire order; is_cidr is ignored on server, so send 0.
	header := [4]byte{family, byte(prefixLen), 0, byte(len(addr))}
	for _, b := range header {
		if err := pgio.WriteByte(w, b); err != nil {
			return err
		}
	}

	_, err := w.Write(addr)
	return err
}