diff --git a/app.go b/app.go index 70101ee2..7b033b99 100644 --- a/app.go +++ b/app.go @@ -31,7 +31,7 @@ import ( ) // Version of current fiber package -const Version = "2.2.0" +const Version = "2.2.1" // Handler defines a function to serve HTTP requests. type Handler = func(*Ctx) error diff --git a/internal/gotiny/LICENSE b/internal/gotiny/LICENSE new file mode 100644 index 00000000..b48c7ddf --- /dev/null +++ b/internal/gotiny/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2016 zheng-ji.info + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/internal/gotiny/decEngine.go b/internal/gotiny/decEngine.go new file mode 100644 index 00000000..07617a7e --- /dev/null +++ b/internal/gotiny/decEngine.go @@ -0,0 +1,203 @@ +package gotiny + +import ( + "reflect" + "sync" + "time" + "unsafe" +) + +type decEng func(*Decoder, unsafe.Pointer) // 解码器 + +var ( + rt2decEng = map[reflect.Type]decEng{ + reflect.TypeOf((*bool)(nil)).Elem(): decBool, + reflect.TypeOf((*int)(nil)).Elem(): decInt, + reflect.TypeOf((*int8)(nil)).Elem(): decInt8, + reflect.TypeOf((*int16)(nil)).Elem(): decInt16, + reflect.TypeOf((*int32)(nil)).Elem(): decInt32, + reflect.TypeOf((*int64)(nil)).Elem(): decInt64, + reflect.TypeOf((*uint)(nil)).Elem(): decUint, + reflect.TypeOf((*uint8)(nil)).Elem(): decUint8, + reflect.TypeOf((*uint16)(nil)).Elem(): decUint16, + reflect.TypeOf((*uint32)(nil)).Elem(): decUint32, + reflect.TypeOf((*uint64)(nil)).Elem(): decUint64, + reflect.TypeOf((*uintptr)(nil)).Elem(): decUintptr, + reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): decPointer, + reflect.TypeOf((*float32)(nil)).Elem(): decFloat32, + reflect.TypeOf((*float64)(nil)).Elem(): decFloat64, + reflect.TypeOf((*complex64)(nil)).Elem(): decComplex64, + reflect.TypeOf((*complex128)(nil)).Elem(): decComplex128, + reflect.TypeOf((*[]byte)(nil)).Elem(): decBytes, + reflect.TypeOf((*string)(nil)).Elem(): decString, + reflect.TypeOf((*time.Time)(nil)).Elem(): decTime, + reflect.TypeOf((*struct{})(nil)).Elem(): decIgnore, + reflect.TypeOf(nil): decIgnore, + } + + baseDecEngines = []decEng{ + reflect.Invalid: decIgnore, + reflect.Bool: decBool, + reflect.Int: decInt, + reflect.Int8: decInt8, + reflect.Int16: decInt16, + reflect.Int32: decInt32, + reflect.Int64: decInt64, + reflect.Uint: decUint, + reflect.Uint8: decUint8, + reflect.Uint16: decUint16, + reflect.Uint32: decUint32, + reflect.Uint64: decUint64, + reflect.Uintptr: decUintptr, + reflect.UnsafePointer: decPointer, + reflect.Float32: decFloat32, + reflect.Float64: decFloat64, + reflect.Complex64: decComplex64, + reflect.Complex128: decComplex128, + reflect.String: decString, + } + decLock sync.RWMutex +) + +func getDecEngine(rt reflect.Type) decEng { + decLock.RLock() + engine := rt2decEng[rt] + decLock.RUnlock() + if engine != nil { + return engine + } + decLock.Lock() + buildDecEngine(rt, &engine) + decLock.Unlock() + return engine +} + +func buildDecEngine(rt reflect.Type, engPtr *decEng) { + engine, has := rt2decEng[rt] + if has { + *engPtr = engine + return + } + + if _, engine = implementOtherSerializer(rt); engine != nil { + rt2decEng[rt] = engine + *engPtr = engine + return + } + + kind := rt.Kind() + var eEng decEng + switch kind { + case reflect.Ptr: + et := rt.Elem() + defer buildDecEngine(et, &eEng) + engine = func(d *Decoder, p unsafe.Pointer) { + if d.decIsNotNil() { + if isNil(p) { + *(*unsafe.Pointer)(p) = unsafe.Pointer(reflect.New(et).Elem().UnsafeAddr()) + } + eEng(d, *(*unsafe.Pointer)(p)) + } else if !isNil(p) { + *(*unsafe.Pointer)(p) = nil + } + } + case reflect.Array: + l, et := rt.Len(), rt.Elem() + size := et.Size() + defer buildDecEngine(et, &eEng) + engine = func(d *Decoder, p unsafe.Pointer) { + for i := 0; i < l; i++ { + eEng(d, unsafe.Pointer(uintptr(p)+uintptr(i)*size)) + } + } + case reflect.Slice: + et := rt.Elem() + size := et.Size() + defer buildDecEngine(et, &eEng) + engine = func(d *Decoder, p unsafe.Pointer) { + header := (*reflect.SliceHeader)(p) + if d.decIsNotNil() { + l := d.decLength() + if isNil(p) || header.Cap < l { + *header = reflect.SliceHeader{Data: reflect.MakeSlice(rt, l, l).Pointer(), Len: l, Cap: l} + } else { + header.Len = l + } + for i := 0; i < l; i++ { + eEng(d, unsafe.Pointer(header.Data+uintptr(i)*size)) + } + } else if !isNil(p) { + *header = reflect.SliceHeader{} + } + } + case reflect.Map: + kt, vt := rt.Key(), rt.Elem() + skt, svt := reflect.SliceOf(kt), reflect.SliceOf(vt) + var kEng, vEng decEng + defer buildDecEngine(kt, &kEng) + defer buildDecEngine(vt, &vEng) + engine = func(d *Decoder, p unsafe.Pointer) { + if d.decIsNotNil() { + l := d.decLength() + var v reflect.Value + if isNil(p) { + v = reflect.MakeMapWithSize(rt, l) + *(*unsafe.Pointer)(p) = unsafe.Pointer(v.Pointer()) + } else { + v = reflect.NewAt(rt, p).Elem() + } + keys, vals := reflect.MakeSlice(skt, l, l), reflect.MakeSlice(svt, l, l) + for i := 0; i < l; i++ { + key, val := keys.Index(i), vals.Index(i) + kEng(d, unsafe.Pointer(key.UnsafeAddr())) + vEng(d, unsafe.Pointer(val.UnsafeAddr())) + v.SetMapIndex(key, val) + } + } else if !isNil(p) { + *(*unsafe.Pointer)(p) = nil + } + } + case reflect.Struct: + fields, offs := getFieldType(rt, 0) + nf := len(fields) + fEngines := make([]decEng, nf) + defer func() { + for i := 0; i < nf; i++ { + buildDecEngine(fields[i], &fEngines[i]) + } + }() + engine = func(d *Decoder, p unsafe.Pointer) { + for i := 0; i < len(fEngines) && i < len(offs); i++ { + fEngines[i](d, unsafe.Pointer(uintptr(p)+offs[i])) + } + } + case reflect.Interface: + engine = func(d *Decoder, p unsafe.Pointer) { + if d.decIsNotNil() { + name := "" + decString(d, unsafe.Pointer(&name)) + et, has := name2type[name] + if !has { + panic("unknown typ:" + name) + } + v := reflect.NewAt(rt, p).Elem() + var ev reflect.Value + if v.IsNil() || v.Elem().Type() != et { + ev = reflect.New(et).Elem() + } else { + ev = v.Elem() + } + getDecEngine(et)(d, getUnsafePointer(&ev)) + v.Set(ev) + } else if !isNil(p) { + *(*unsafe.Pointer)(p) = nil + } + } + case reflect.Chan, reflect.Func: + panic("not support " + rt.String() + " type") + default: + engine = baseDecEngines[kind] + } + rt2decEng[rt] = engine + *engPtr = engine +} diff --git a/internal/gotiny/decbase.go b/internal/gotiny/decbase.go new file mode 100644 index 00000000..6a5a05cd --- /dev/null +++ b/internal/gotiny/decbase.go @@ -0,0 +1,161 @@ +package gotiny + +import ( + "time" + "unsafe" +) + +func (d *Decoder) decBool() (b bool) { + if d.boolBit == 0 { + d.boolBit = 1 + d.boolPos = d.buf[d.index] + d.index++ + } + b = d.boolPos&d.boolBit != 0 + d.boolBit <<= 1 + return +} + +func (d *Decoder) decUint64() uint64 { + buf, i := d.buf, d.index + x := uint64(buf[i]) + if x < 0x80 { + d.index++ + return x + } + x1 := buf[i+1] + x += uint64(x1) << 7 + if x1 < 0x80 { + d.index += 2 + return x - 1<<7 + } + x2 := buf[i+2] + x += uint64(x2) << 14 + if x2 < 0x80 { + d.index += 3 + return x - (1<<7 + 1<<14) + } + x3 := buf[i+3] + x += uint64(x3) << 21 + if x3 < 0x80 { + d.index += 4 + return x - (1<<7 + 1<<14 + 1<<21) + } + x4 := buf[i+4] + x += uint64(x4) << 28 + if x4 < 0x80 { + d.index += 5 + return x - (1<<7 + 1<<14 + 1<<21 + 1<<28) + } + x5 := buf[i+5] + x += uint64(x5) << 35 + if x5 < 0x80 { + d.index += 6 + return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35) + } + x6 := buf[i+6] + x += uint64(x6) << 42 + if x6 < 0x80 { + d.index += 7 + return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42) + } + x7 := buf[i+7] + x += uint64(x7) << 49 + if x7 < 0x80 { + d.index += 8 + return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49) + } + d.index += 9 + return x + uint64(buf[i+8])<<56 - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49 + 1<<56) +} + +func (d *Decoder) decUint16() uint16 { + buf, i := d.buf, d.index + x := uint16(buf[i]) + if x < 0x80 { + d.index++ + return x + } + x1 := buf[i+1] + x += uint16(x1) << 7 + if x1 < 0x80 { + d.index += 2 + return x - 1<<7 + } + d.index += 3 + return x + uint16(buf[i+2])<<14 - (1<<7 + 1<<14) +} + +func (d *Decoder) decUint32() uint32 { + buf, i := d.buf, d.index + x := uint32(buf[i]) + if x < 0x80 { + d.index++ + return x + } + x1 := buf[i+1] + x += uint32(x1) << 7 + if x1 < 0x80 { + d.index += 2 + return x - 1<<7 + } + x2 := buf[i+2] + x += uint32(x2) << 14 + if x2 < 0x80 { + d.index += 3 + return x - (1<<7 + 1<<14) + } + x3 := buf[i+3] + x += uint32(x3) << 21 + if x3 < 0x80 { + d.index += 4 + return x - (1<<7 + 1<<14 + 1<<21) + } + x4 := buf[i+4] + x += uint32(x4) << 28 + d.index += 5 + return x - (1<<7 + 1<<14 + 1<<21 + 1<<28) +} + +func (d *Decoder) decLength() int { return int(d.decUint32()) } +func (d *Decoder) decIsNotNil() bool { return d.decBool() } + +func decIgnore(*Decoder, unsafe.Pointer) {} +func decBool(d *Decoder, p unsafe.Pointer) { *(*bool)(p) = d.decBool() } +func decInt(d *Decoder, p unsafe.Pointer) { *(*int)(p) = int(uint64ToInt64(d.decUint64())) } +func decInt8(d *Decoder, p unsafe.Pointer) { *(*int8)(p) = int8(d.buf[d.index]); d.index++ } +func decInt16(d *Decoder, p unsafe.Pointer) { *(*int16)(p) = uint16ToInt16(d.decUint16()) } +func decInt32(d *Decoder, p unsafe.Pointer) { *(*int32)(p) = uint32ToInt32(d.decUint32()) } +func decInt64(d *Decoder, p unsafe.Pointer) { *(*int64)(p) = uint64ToInt64(d.decUint64()) } +func decUint(d *Decoder, p unsafe.Pointer) { *(*uint)(p) = uint(d.decUint64()) } +func decUint8(d *Decoder, p unsafe.Pointer) { *(*uint8)(p) = d.buf[d.index]; d.index++ } +func decUint16(d *Decoder, p unsafe.Pointer) { *(*uint16)(p) = d.decUint16() } +func decUint32(d *Decoder, p unsafe.Pointer) { *(*uint32)(p) = d.decUint32() } +func decUint64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() } +func decUintptr(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) } +func decPointer(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) } +func decFloat32(d *Decoder, p unsafe.Pointer) { *(*float32)(p) = uint32ToFloat32(d.decUint32()) } +func decFloat64(d *Decoder, p unsafe.Pointer) { *(*float64)(p) = uint64ToFloat64(d.decUint64()) } +func decTime(d *Decoder, p unsafe.Pointer) { *(*time.Time)(p) = time.Unix(0, int64(d.decUint64())) } +func decComplex64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() } +func decComplex128(d *Decoder, p unsafe.Pointer) { + *(*uint64)(p) = d.decUint64() + *(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size)) = d.decUint64() +} + +func decString(d *Decoder, p unsafe.Pointer) { + l, val := int(d.decUint32()), (*string)(p) + *val = string(d.buf[d.index : d.index+l]) + d.index += l +} + +func decBytes(d *Decoder, p unsafe.Pointer) { + bytes := (*[]byte)(p) + if d.decIsNotNil() { + l := int(d.decUint32()) + *bytes = d.buf[d.index : d.index+l] + d.index += l + } else if !isNil(p) { + *bytes = nil + } +} diff --git a/internal/gotiny/decoder.go b/internal/gotiny/decoder.go new file mode 100644 index 00000000..abf58c60 --- /dev/null +++ b/internal/gotiny/decoder.go @@ -0,0 +1,97 @@ +package gotiny + +import ( + "reflect" + "unsafe" +) + +type Decoder struct { + buf []byte //buf + index int //下一个要使用的字节在buf中的下标 + boolPos byte //下一次要读取的bool在buf中的下标,即buf[boolPos] + boolBit byte //下一次要读取的bool的buf[boolPos]中的bit位 + + engines []decEng //解码器集合 + length int //解码器数量 +} + +func Unmarshal(buf []byte, is ...interface{}) int { + return NewDecoderWithPtr(is...).Decode(buf, is...) +} + +func NewDecoderWithPtr(is ...interface{}) *Decoder { + l := len(is) + engines := make([]decEng, l) + for i := 0; i < l; i++ { + rt := reflect.TypeOf(is[i]) + if rt.Kind() != reflect.Ptr { + panic("must a pointer type!") + } + engines[i] = getDecEngine(rt.Elem()) + } + return &Decoder{ + length: l, + engines: engines, + } +} + +func NewDecoder(is ...interface{}) *Decoder { + l := len(is) + engines := make([]decEng, l) + for i := 0; i < l; i++ { + engines[i] = getDecEngine(reflect.TypeOf(is[i])) + } + return &Decoder{ + length: l, + engines: engines, + } +} + +func NewDecoderWithType(ts ...reflect.Type) *Decoder { + l := len(ts) + des := make([]decEng, l) + for i := 0; i < l; i++ { + des[i] = getDecEngine(ts[i]) + } + return &Decoder{ + length: l, + engines: des, + } +} + +func (d *Decoder) reset() int { + index := d.index + d.index = 0 + d.boolPos = 0 + d.boolBit = 0 + return index +} + +// is is pointer of variable +func (d *Decoder) Decode(buf []byte, is ...interface{}) int { + d.buf = buf + engines := d.engines + for i := 0; i < len(engines) && i < len(is); i++ { + engines[i](d, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1]) + } + return d.reset() +} + +// ps is a unsafe.Pointer of the variable +func (d *Decoder) DecodePtr(buf []byte, ps ...unsafe.Pointer) int { + d.buf = buf + engines := d.engines + for i := 0; i < len(engines) && i < len(ps); i++ { + engines[i](d, ps[i]) + } + return d.reset() +} + +func (d *Decoder) DecodeValue(buf []byte, vs ...reflect.Value) int { + d.buf = buf + engines := d.engines + for i := 0; i < len(engines) && i < len(vs); i++ { + engines[i](d, unsafe.Pointer(vs[i].UnsafeAddr())) + } + return d.reset() +} diff --git a/internal/gotiny/encEngine.go b/internal/gotiny/encEngine.go new file mode 100644 index 00000000..9964ea17 --- /dev/null +++ b/internal/gotiny/encEngine.go @@ -0,0 +1,196 @@ +package gotiny + +import ( + "reflect" + "sync" + "time" + "unsafe" +) + +type encEng func(*Encoder, unsafe.Pointer) //编码器 + +var ( + rt2encEng = map[reflect.Type]encEng{ + reflect.TypeOf((*bool)(nil)).Elem(): encBool, + reflect.TypeOf((*int)(nil)).Elem(): encInt, + reflect.TypeOf((*int8)(nil)).Elem(): encInt8, + reflect.TypeOf((*int16)(nil)).Elem(): encInt16, + reflect.TypeOf((*int32)(nil)).Elem(): encInt32, + reflect.TypeOf((*int64)(nil)).Elem(): encInt64, + reflect.TypeOf((*uint)(nil)).Elem(): encUint, + reflect.TypeOf((*uint8)(nil)).Elem(): encUint8, + reflect.TypeOf((*uint16)(nil)).Elem(): encUint16, + reflect.TypeOf((*uint32)(nil)).Elem(): encUint32, + reflect.TypeOf((*uint64)(nil)).Elem(): encUint64, + reflect.TypeOf((*uintptr)(nil)).Elem(): encUintptr, + reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): encPointer, + reflect.TypeOf((*float32)(nil)).Elem(): encFloat32, + reflect.TypeOf((*float64)(nil)).Elem(): encFloat64, + reflect.TypeOf((*complex64)(nil)).Elem(): encComplex64, + reflect.TypeOf((*complex128)(nil)).Elem(): encComplex128, + reflect.TypeOf((*[]byte)(nil)).Elem(): encBytes, + reflect.TypeOf((*string)(nil)).Elem(): encString, + reflect.TypeOf((*time.Time)(nil)).Elem(): encTime, + reflect.TypeOf((*struct{})(nil)).Elem(): encIgnore, + reflect.TypeOf(nil): encIgnore, + } + + encEngines = [...]encEng{ + reflect.Invalid: encIgnore, + reflect.Bool: encBool, + reflect.Int: encInt, + reflect.Int8: encInt8, + reflect.Int16: encInt16, + reflect.Int32: encInt32, + reflect.Int64: encInt64, + reflect.Uint: encUint, + reflect.Uint8: encUint8, + reflect.Uint16: encUint16, + reflect.Uint32: encUint32, + reflect.Uint64: encUint64, + reflect.Uintptr: encUintptr, + reflect.UnsafePointer: encPointer, + reflect.Float32: encFloat32, + reflect.Float64: encFloat64, + reflect.Complex64: encComplex64, + reflect.Complex128: encComplex128, + reflect.String: encString, + } + + encLock sync.RWMutex +) + +func UnusedUnixNanoEncodeTimeType() { + delete(rt2encEng, reflect.TypeOf((*time.Time)(nil)).Elem()) + delete(rt2decEng, reflect.TypeOf((*time.Time)(nil)).Elem()) +} + +func getEncEngine(rt reflect.Type) encEng { + encLock.RLock() + engine := rt2encEng[rt] + encLock.RUnlock() + if engine != nil { + return engine + } + encLock.Lock() + buildEncEngine(rt, &engine) + encLock.Unlock() + return engine +} + +func buildEncEngine(rt reflect.Type, engPtr *encEng) { + engine := rt2encEng[rt] + if engine != nil { + *engPtr = engine + return + } + + if engine, _ = implementOtherSerializer(rt); engine != nil { + rt2encEng[rt] = engine + *engPtr = engine + return + } + + kind := rt.Kind() + var eEng encEng + switch kind { + case reflect.Ptr: + defer buildEncEngine(rt.Elem(), &eEng) + engine = func(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + eEng(e, *(*unsafe.Pointer)(p)) + } + } + case reflect.Array: + et, l := rt.Elem(), rt.Len() + defer buildEncEngine(et, &eEng) + size := et.Size() + engine = func(e *Encoder, p unsafe.Pointer) { + for i := 0; i < l; i++ { + eEng(e, unsafe.Pointer(uintptr(p)+uintptr(i)*size)) + } + } + case reflect.Slice: + et := rt.Elem() + size := et.Size() + defer buildEncEngine(et, &eEng) + engine = func(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + header := (*reflect.SliceHeader)(p) + l := header.Len + e.encLength(l) + for i := 0; i < l; i++ { + eEng(e, unsafe.Pointer(header.Data+uintptr(i)*size)) + } + } + } + case reflect.Map: + var kEng encEng + defer buildEncEngine(rt.Key(), &kEng) + defer buildEncEngine(rt.Elem(), &eEng) + engine = func(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + v := reflect.NewAt(rt, p).Elem() + e.encLength(v.Len()) + keys := v.MapKeys() + for i := 0; i < len(keys); i++ { + val := v.MapIndex(keys[i]) + kEng(e, getUnsafePointer(&keys[i])) + eEng(e, getUnsafePointer(&val)) + } + } + } + case reflect.Struct: + fields, offs := getFieldType(rt, 0) + nf := len(fields) + fEngines := make([]encEng, nf) + defer func() { + for i := 0; i < nf; i++ { + buildEncEngine(fields[i], &fEngines[i]) + } + }() + engine = func(e *Encoder, p unsafe.Pointer) { + for i := 0; i < len(fEngines) && i < len(offs); i++ { + fEngines[i](e, unsafe.Pointer(uintptr(p)+offs[i])) + } + } + case reflect.Interface: + if rt.NumMethod() > 0 { + engine = func(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + v := reflect.ValueOf(*(*interface { + M() + })(p)) + et := v.Type() + e.encString(getNameOfType(et)) + getEncEngine(et)(e, getUnsafePointer(&v)) + } + } + } else { + engine = func(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + v := reflect.ValueOf(*(*interface{})(p)) + et := v.Type() + e.encString(getNameOfType(et)) + getEncEngine(et)(e, getUnsafePointer(&v)) + } + } + } + case reflect.Chan, reflect.Func: + panic("not support " + rt.String() + " type") + default: + engine = encEngines[kind] + } + rt2encEng[rt] = engine + *engPtr = engine +} diff --git a/internal/gotiny/encbase.go b/internal/gotiny/encbase.go new file mode 100644 index 00000000..0b94fdea --- /dev/null +++ b/internal/gotiny/encbase.go @@ -0,0 +1,108 @@ +package gotiny + +import ( + "time" + "unsafe" +) + +func (e *Encoder) encBool(v bool) { + if e.boolBit == 0 { + e.boolPos = len(e.buf) + e.buf = append(e.buf, 0) + e.boolBit = 1 + } + if v { + e.buf[e.boolPos] |= e.boolBit + } + e.boolBit <<= 1 +} + +func (e *Encoder) encUint64(v uint64) { + switch { + case v < 1<<7-1: + e.buf = append(e.buf, byte(v)) + case v < 1<<14-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)) + case v < 1<<21-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)) + case v < 1<<28-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)) + case v < 1<<35-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)) + case v < 1<<42-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)) + case v < 1<<49-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)) + case v < 1<<56-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49)) + default: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49)|0x80, byte(v>>56)) + } +} + +func (e *Encoder) encUint16(v uint16) { + if v < 1<<7-1 { + e.buf = append(e.buf, byte(v)) + } else if v < 1<<14-1 { + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)) + } else { + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)) + } +} + +func (e *Encoder) encUint32(v uint32) { + switch { + case v < 1<<7-1: + e.buf = append(e.buf, byte(v)) + case v < 1<<14-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)) + case v < 1<<21-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)) + case v < 1<<28-1: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)) + default: + e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)) + } +} + +func (e *Encoder) encLength(v int) { e.encUint32(uint32(v)) } +func (e *Encoder) encString(s string) { e.encUint32(uint32(len(s))); e.buf = append(e.buf, s...) } +func (e *Encoder) encIsNotNil(v bool) { e.encBool(v) } + +func encIgnore(*Encoder, unsafe.Pointer) {} +func encBool(e *Encoder, p unsafe.Pointer) { e.encBool(*(*bool)(p)) } +func encInt(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(int64(*(*int)(p)))) } +func encInt8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) } +func encInt16(e *Encoder, p unsafe.Pointer) { e.encUint16(int16ToUint16(*(*int16)(p))) } +func encInt32(e *Encoder, p unsafe.Pointer) { e.encUint32(int32ToUint32(*(*int32)(p))) } +func encInt64(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(*(*int64)(p))) } +func encUint8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) } +func encUint16(e *Encoder, p unsafe.Pointer) { e.encUint16(*(*uint16)(p)) } +func encUint32(e *Encoder, p unsafe.Pointer) { e.encUint32(*(*uint32)(p)) } +func encUint64(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint64)(p))) } +func encUint(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint)(p))) } +func encUintptr(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) } +func encPointer(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) } +func encFloat32(e *Encoder, p unsafe.Pointer) { e.encUint32(float32ToUint32(p)) } +func encFloat64(e *Encoder, p unsafe.Pointer) { e.encUint64(float64ToUint64(p)) } +func encString(e *Encoder, p unsafe.Pointer) { + s := *(*string)(p) + e.encUint32(uint32(len(s))) + e.buf = append(e.buf, s...) +} +func encTime(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64((*time.Time)(p).UnixNano())) } +func encComplex64(e *Encoder, p unsafe.Pointer) { e.encUint64(*(*uint64)(p)) } +func encComplex128(e *Encoder, p unsafe.Pointer) { + e.encUint64(*(*uint64)(p)) + e.encUint64(*(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size))) +} + +func encBytes(e *Encoder, p unsafe.Pointer) { + isNotNil := !isNil(p) + e.encIsNotNil(isNotNil) + if isNotNil { + buf := *(*[]byte)(p) + e.encLength(len(buf)) + e.buf = append(e.buf, buf...) + } +} diff --git a/internal/gotiny/encoder.go b/internal/gotiny/encoder.go new file mode 100644 index 00000000..de5a46cd --- /dev/null +++ b/internal/gotiny/encoder.go @@ -0,0 +1,103 @@ +package gotiny + +import ( + "reflect" + "unsafe" +) + +type Encoder struct { + buf []byte //编码目的数组 + off int + boolPos int //下一次要设置的bool在buf中的下标,即buf[boolPos] + boolBit byte //下一次要设置的bool的buf[boolPos]中的bit位 + + engines []encEng + length int +} + +func Marshal(is ...interface{}) []byte { + return NewEncoderWithPtr(is...).Encode(is...) +} + +// 创建一个编码ps 指向类型的编码器 +func NewEncoderWithPtr(ps ...interface{}) *Encoder { + l := len(ps) + engines := make([]encEng, l) + for i := 0; i < l; i++ { + rt := reflect.TypeOf(ps[i]) + if rt.Kind() != reflect.Ptr { + panic("must a pointer type!") + } + engines[i] = getEncEngine(rt.Elem()) + } + return &Encoder{ + length: l, + engines: engines, + } +} + +// 创建一个编码is 类型的编码器 +func NewEncoder(is ...interface{}) *Encoder { + l := len(is) + engines := make([]encEng, l) + for i := 0; i < l; i++ { + engines[i] = getEncEngine(reflect.TypeOf(is[i])) + } + return &Encoder{ + length: l, + engines: engines, + } +} + +func NewEncoderWithType(ts ...reflect.Type) *Encoder { + l := len(ts) + engines := make([]encEng, l) + for i := 0; i < l; i++ { + engines[i] = getEncEngine(ts[i]) + } + return &Encoder{ + length: l, + engines: engines, + } +} + +// 入参是要编码值的指针 +func (e *Encoder) Encode(is ...interface{}) []byte { + engines := e.engines + for i := 0; i < len(engines) && i < len(is); i++ { + engines[i](e, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1]) + } + return e.reset() +} + +// 入参是要编码的值得unsafe.Pointer 指针 +func (e *Encoder) EncodePtr(ps ...unsafe.Pointer) []byte { + engines := e.engines + for i := 0; i < len(engines) && i < len(ps); i++ { + engines[i](e, ps[i]) + } + return e.reset() +} + +// vs 是持有要编码的值 +func (e *Encoder) EncodeValue(vs ...reflect.Value) []byte { + engines := e.engines + for i := 0; i < len(engines) && i < len(vs); i++ { + engines[i](e, getUnsafePointer(&vs[i])) + } + return e.reset() +} + +// 编码产生的数据将append到buf上 +func (e *Encoder) AppendTo(buf []byte) { + e.off = len(buf) + e.buf = buf +} + +func (e *Encoder) reset() []byte { + buf := e.buf + e.buf = buf[:e.off] + e.boolBit = 0 + e.boolPos = 0 + return buf +} diff --git a/internal/gotiny/register.go b/internal/gotiny/register.go new file mode 100644 index 00000000..94732a1d --- /dev/null +++ b/internal/gotiny/register.go @@ -0,0 +1,144 @@ +package gotiny + +import ( + "reflect" + "strconv" +) + +var ( + type2name = map[reflect.Type]string{} + name2type = map[string]reflect.Type{} +) + +func GetName(obj interface{}) string { + return GetNameByType(reflect.TypeOf(obj)) +} +func GetNameByType(rt reflect.Type) string { + return string(getName([]byte(nil), rt)) +} + +func getName(prefix []byte, rt reflect.Type) []byte { + if rt == nil || rt.Kind() == reflect.Invalid { + return append(prefix, []byte("")...) + } + if rt.Name() == "" { //未命名的,组合类型 + switch rt.Kind() { + case reflect.Ptr: + return getName(append(prefix, '*'), rt.Elem()) + case reflect.Array: + return getName(append(prefix, "["+strconv.Itoa(rt.Len())+"]"...), rt.Elem()) + case reflect.Slice: + return getName(append(prefix, '[', ']'), rt.Elem()) + case reflect.Struct: + prefix = append(prefix, "struct {"...) + nf := rt.NumField() + if nf > 0 { + prefix = append(prefix, ' ') + } + for i := 0; i < nf; i++ { + field := rt.Field(i) + if field.Anonymous { + prefix = getName(prefix, field.Type) + } else { + prefix = getName(append(prefix, field.Name+" "...), field.Type) + } + if i != nf-1 { + prefix = append(prefix, ';', ' ') + } else { + prefix = append(prefix, ' ') + } + } + return append(prefix, '}') + case reflect.Map: + return getName(append(getName(append(prefix, "map["...), rt.Key()), ']'), rt.Elem()) + case reflect.Interface: + prefix = append(prefix, "interface {"...) + nm := rt.NumMethod() + if nm > 0 { + prefix = append(prefix, ' ') + } + for i := 0; i < nm; i++ { + method := rt.Method(i) + fn := getName([]byte(nil), method.Type) + prefix = append(prefix, method.Name+string(fn[4:])...) + if i != nm-1 { + prefix = append(prefix, ';', ' ') + } else { + prefix = append(prefix, ' ') + } + } + return append(prefix, '}') + case reflect.Func: + prefix = append(prefix, "func("...) + for i := 0; i < rt.NumIn(); i++ { + prefix = getName(prefix, rt.In(i)) + if i != rt.NumIn()-1 { + prefix = append(prefix, ',', ' ') + } + } + prefix = append(prefix, ')') + no := rt.NumOut() + if no > 0 { + prefix = append(prefix, ' ') + } + if no > 1 { + prefix = append(prefix, '(') + } + for i := 0; i < no; i++ { + prefix = getName(prefix, rt.Out(i)) + if i != no-1 { + prefix = append(prefix, ',', ' ') + } + } + if no > 1 { + prefix = append(prefix, ')') + } + return prefix + } + } + + if rt.PkgPath() == "" { + prefix = append(prefix, rt.Name()...) + } else { + prefix = append(prefix, rt.PkgPath()+"."+rt.Name()...) + } + return prefix +} + +func getNameOfType(rt reflect.Type) string { + if name, has := type2name[rt]; has { + return name + } else { + return registerType(rt) + } +} + +func Register(i interface{}) string { + return registerType(reflect.TypeOf(i)) +} + +func registerType(rt reflect.Type) string { + name := GetNameByType(rt) + RegisterName(name, rt) + return name +} + +func RegisterName(name string, rt reflect.Type) { + if name == "" { + panic("attempt to register empty name") + } + + if rt == nil || rt.Kind() == reflect.Invalid { + panic("attempt to register nil type or invalid type") + } + + if _, has := type2name[rt]; has { + panic("gotiny: registering duplicate types for " + GetNameByType(rt)) + } + + if _, has := name2type[name]; has { + panic("gotiny: registering name" + name + " is exist") + } + name2type[name] = rt + type2name[rt] = name +} diff --git a/internal/gotiny/unsafe.go b/internal/gotiny/unsafe.go new file mode 100644 index 00000000..16acc928 --- /dev/null +++ b/internal/gotiny/unsafe.go @@ -0,0 +1,57 @@ +package gotiny + +import ( + "reflect" + "unsafe" +) + +const ( + kindDirectIface = 1 << 5 +) + +// rtype is the common implementation of most values. +// It is embedded in other struct types. +// +// rtype must be kept in sync with reflect/type.go:/^type._type. +type rtype struct { + _ uintptr + _ uintptr // number of bytes in the type that can contain pointers + _ uint32 // hash of type; avoids computation in hash tables + _ uint8 // extra type information flags + _ uint8 // alignment of variable with this type + _ uint8 // alignment of struct field with this type + kind uint8 // enumeration for C + _ uintptr // algorithm table + _ uintptr // garbage collection data + _ int32 // string form + _ int32 // type for pointer to this type, may be zero +} + +// ifaceIndir reports whether t is stored indirectly in an interface value. +func ifaceDirect(t *rtype) bool { + return t.kind&kindDirectIface != 0 +} + +func directType(rt *reflect.Type) bool { + return ifaceDirect((*rtype)((*[2]unsafe.Pointer)(unsafe.Pointer(rt))[1])) +} + +type refVal struct { + _ unsafe.Pointer + ptr unsafe.Pointer + flag flag +} + +type flag uintptr + +//go:linkname flagIndir reflect.flagIndir +const flagIndir flag = 1 << 7 + +func getUnsafePointer(rv *reflect.Value) unsafe.Pointer { + vv := (*refVal)(unsafe.Pointer(rv)) + if vv.flag&flagIndir == 0 { + return unsafe.Pointer(&vv.ptr) + } else { + return vv.ptr + } +} diff --git a/internal/gotiny/utils.go b/internal/gotiny/utils.go new file mode 100644 index 00000000..e42393fd --- /dev/null +++ b/internal/gotiny/utils.go @@ -0,0 +1,185 @@ +package gotiny + +import ( + "encoding" + "encoding/gob" + "reflect" + "strings" + "unsafe" +) + +const ( + ptr1Size = 4 << (^uintptr(0) >> 63) // unsafe.Sizeof(uintptr(0)) but an ideal const +) + +func float64ToUint64(v unsafe.Pointer) uint64 { + return reverse64Byte(*(*uint64)(v)) +} + +func uint64ToFloat64(u uint64) float64 { + u = reverse64Byte(u) + return *((*float64)(unsafe.Pointer(&u))) +} + +func reverse64Byte(u uint64) uint64 { + u = (u << 32) | (u >> 32) + u = ((u << 16) & 0xFFFF0000FFFF0000) | ((u >> 16) & 0xFFFF0000FFFF) + u = ((u << 8) & 0xFF00FF00FF00FF00) | ((u >> 8) & 0xFF00FF00FF00FF) + return u +} + +func float32ToUint32(v unsafe.Pointer) uint32 { + return reverse32Byte(*(*uint32)(v)) +} + +func uint32ToFloat32(u uint32) float32 { + u = reverse32Byte(u) + return *((*float32)(unsafe.Pointer(&u))) +} + +func reverse32Byte(u uint32) uint32 { + u = (u << 16) | (u >> 16) + return ((u << 8) & 0xFF00FF00) | ((u >> 8) & 0xFF00FF) +} + +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +func int64ToUint64(v int64) uint64 { + return uint64((v << 1) ^ (v >> 63)) +} + +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +func uint64ToInt64(u uint64) int64 { + v := int64(u) + return (-(v & 1)) ^ (v>>1)&0x7FFFFFFFFFFFFFFF +} + +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +func int32ToUint32(v int32) uint32 { + return uint32((v << 1) ^ (v >> 31)) +} + +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +func uint32ToInt32(u uint32) int32 { + v := int32(u) + return (-(v & 1)) ^ (v>>1)&0x7FFFFFFF +} + +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +func int16ToUint16(v int16) uint16 { + return uint16((v << 1) ^ (v >> 15)) +} + +// uint 9 7 5 3 1 0 2 4 6 8 10 12 +// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6 +func uint16ToInt16(u uint16) int16 { + v := int16(u) + return (-(v & 1)) ^ (v>>1)&0x7FFF +} + +func isNil(p unsafe.Pointer) bool { + return *(*unsafe.Pointer)(p) == nil +} + +type gobInter interface { + gob.GobEncoder + gob.GobDecoder +} + +type binInter interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler +} + +// 只应该由指针来实现该接口 +type GoTinySerializer interface { + // 编码方法,将对象的序列化结果append到入参数并返回,方法不应该修改入参数值原有的值 + GotinyEncode([]byte) []byte + // 解码方法,将入参解码到对象里并返回使用的长度。方法从入参的第0个字节开始使用,并且不应该修改入参中的任何数据 + GotinyDecode([]byte) int +} + +func implementOtherSerializer(rt reflect.Type) (encEng encEng, decEng decEng) { + rtNil := reflect.Zero(reflect.PtrTo(rt)).Interface() + if _, ok := rtNil.(GoTinySerializer); ok { + encEng = func(e *Encoder, p unsafe.Pointer) { + e.buf = reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyEncode(e.buf) + } + decEng = func(d *Decoder, p unsafe.Pointer) { + d.index += reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyDecode(d.buf[d.index:]) + } + return + } + + if _, ok := rtNil.(binInter); ok { + encEng = func(e *Encoder, p unsafe.Pointer) { + buf, err := reflect.NewAt(rt, p).Interface().(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + panic(err) + } + e.encLength(len(buf)) + e.buf = append(e.buf, buf...) + } + + decEng = func(d *Decoder, p unsafe.Pointer) { + length := d.decLength() + start := d.index + d.index += length + if err := reflect.NewAt(rt, p).Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(d.buf[start:d.index]); err != nil { + panic(err) + } + } + return + } + + if _, ok := rtNil.(gobInter); ok { + encEng = func(e *Encoder, p unsafe.Pointer) { + buf, err := reflect.NewAt(rt, p).Interface().(gob.GobEncoder).GobEncode() + if err != nil { + panic(err) + } + e.encLength(len(buf)) + e.buf = append(e.buf, buf...) + } + decEng = func(d *Decoder, p unsafe.Pointer) { + length := d.decLength() + start := d.index + d.index += length + if err := reflect.NewAt(rt, p).Interface().(gob.GobDecoder).GobDecode(d.buf[start:d.index]); err != nil { + panic(err) + } + } + } + return +} + +// rt.kind is reflect.struct +func getFieldType(rt reflect.Type, baseOff uintptr) (fields []reflect.Type, offs []uintptr) { + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + if ignoreField(field) { + continue + } + ft := field.Type + if ft.Kind() == reflect.Struct { + if _, engine := implementOtherSerializer(ft); engine == nil { + fFields, fOffs := getFieldType(ft, field.Offset+baseOff) + fields = append(fields, fFields...) + offs = append(offs, fOffs...) + continue + } + } + fields = append(fields, ft) + offs = append(offs, field.Offset+baseOff) + } + return +} + +func ignoreField(field reflect.StructField) bool { + tinyTag, ok := field.Tag.Lookup("gotiny") + return ok && strings.TrimSpace(tinyTag) == "-" +} diff --git a/middleware/session/README.md b/middleware/session/README.md index 9ecaca51..525541f3 100644 --- a/middleware/session/README.md +++ b/middleware/session/README.md @@ -10,9 +10,22 @@ Session middleware for [Fiber](https://github.com/gofiber/fiber) ### Signatures ```go -func New(config ...Config) fiber.Handler +func New(config ...Config) *Store +func (s *Store) Get(c *fiber.Ctx) (*Session, error) +func (s *Store) Reset() error + +func (s *Session) Get(key string) interface{} +func (s *Session) Set(key string, val interface{}) +func (s *Session) Delete(key string) +func (s *Session) Destroy() error +func (s *Session) Regenerate() error +func (s *Session) Save() error +func (s *Session) Fresh() bool +func (s *Session) ID() string ``` +**⚠ _Storing `interface{}` values are limited to built-ins Go types_** + ### Examples Import the middleware package that is part of the Fiber web framework ```go @@ -35,9 +48,6 @@ app.Get("/", func(c *fiber.Ctx) error { panic(err) } - // save session - defer sess.Save() - // Get value name := sess.Get("name") @@ -52,6 +62,11 @@ app.Get("/", func(c *fiber.Ctx) error { panic(err) } + // save session + if err := sess.Save(); err != nil { + panic(err) + } + return fmt.Fprintf(ctx, "Welcome %v", name) }) ``` diff --git a/middleware/session/data.go b/middleware/session/data.go index dbc32592..8d682ab4 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -1,24 +1,23 @@ package session -import "sync" +import ( + "sync" +) // go:generate msgp // msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported // don't forget to replace the msgp import path to: // "github.com/gofiber/fiber/v2/internal/msgp" type data struct { - d []kv -} - -// go:generate msgp -type kv struct { - k string - v interface{} + sync.RWMutex `gotiny:"-"` + d map[string]interface{} `gotiny:"d"` } var dataPool = sync.Pool{ New: func() interface{} { - return new(data) + d := new(data) + d.d = make(map[string]interface{}) + return d }, } @@ -32,70 +31,32 @@ func releaseData(d *data) { } func (d *data) Reset() { - d.d = d.d[:0] + d.Lock() + for key := range d.d { + delete(d.d, key) + } + d.Unlock() } func (d *data) Get(key string) interface{} { - idx := d.indexOf(key) - if idx > -1 { - return d.d[idx].v - } - return nil + d.RLock() + v := d.d[key] + d.RUnlock() + return v } func (d *data) Set(key string, value interface{}) { - idx := d.indexOf(key) - if idx > -1 { - kv := &d.d[idx] - kv.v = value - } else { - d.append(key, value) - } + d.Lock() + d.d[key] = value + d.Unlock() } func (d *data) Delete(key string) { - idx := d.indexOf(key) - if idx > -1 { - n := len(d.d) - 1 - d.swap(idx, n) - d.d = d.d[:n] - } + d.Lock() + delete(d.d, key) + d.Unlock() } func (d *data) Len() int { return len(d.d) } - -func (d *data) swap(i, j int) { - iKey, iValue := d.d[i].k, d.d[i].v - jKey, jValue := d.d[j].k, d.d[j].v - - d.d[i].k, d.d[i].v = jKey, jValue - d.d[j].k, d.d[j].v = iKey, iValue -} - -func (d *data) allocPage() *kv { - n := len(d.d) - if cap(d.d) > n { - d.d = d.d[:n+1] - } else { - d.d = append(d.d, kv{}) - } - return &d.d[n] -} - -func (d *data) append(key string, value interface{}) { - kv := d.allocPage() - kv.k = key - kv.v = value -} - -func (d *data) indexOf(key string) int { - n := len(d.d) - for i := 0; i < n; i++ { - if d.d[i].k == key { - return i - } - } - return -1 -} diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go deleted file mode 100644 index 1586f922..00000000 --- a/middleware/session/data_msgp.go +++ /dev/null @@ -1,365 +0,0 @@ -package session - -// Code generated by github.com/tinylib/msgp DO NOT EDIT. - -import ( - "github.com/gofiber/fiber/v2/internal/msgp" -) - -// DecodeMsg implements msgp.Decodable -func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, err = dc.ReadMapKeyPtr() - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "d": - var zb0002 uint32 - zb0002, err = dc.ReadArrayHeader() - if err != nil { - err = msgp.WrapError(err, "d") - return - } - if cap(z.d) >= int(zb0002) { - z.d = (z.d)[:zb0002] - } else { - z.d = make([]kv, zb0002) - } - for za0001 := range z.d { - var zb0003 uint32 - zb0003, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - for zb0003 > 0 { - zb0003-- - field, err = dc.ReadMapKeyPtr() - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - switch msgp.UnsafeString(field) { - case "k": - z.d[za0001].k, err = dc.ReadString() - if err != nil { - err = msgp.WrapError(err, "d", za0001, "k") - return - } - case "v": - z.d[za0001].v, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "d", za0001, "v") - return - } - default: - err = dc.Skip() - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - } - } - } - default: - err = dc.Skip() - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - return -} - -// EncodeMsg implements msgp.Encodable -func (z *data) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 1 - // write "d" - err = en.Append(0x81, 0xa1, 0x64) - if err != nil { - return - } - err = en.WriteArrayHeader(uint32(len(z.d))) - if err != nil { - err = msgp.WrapError(err, "d") - return - } - for za0001 := range z.d { - // map header, size 2 - // write "k" - err = en.Append(0x82, 0xa1, 0x6b) - if err != nil { - return - } - err = en.WriteString(z.d[za0001].k) - if err != nil { - err = msgp.WrapError(err, "d", za0001, "k") - return - } - // write "v" - err = en.Append(0xa1, 0x76) - if err != nil { - return - } - err = en.WriteIntf(z.d[za0001].v) - if err != nil { - err = msgp.WrapError(err, "d", za0001, "v") - return - } - } - return -} - -// MarshalMsg implements msgp.Marshaler -func (z *data) MarshalMsg(b []byte) (o []byte, err error) { - o = msgp.Require(b, z.Msgsize()) - // map header, size 1 - // string "d" - o = append(o, 0x81, 0xa1, 0x64) - o = msgp.AppendArrayHeader(o, uint32(len(z.d))) - for za0001 := range z.d { - // map header, size 2 - // string "k" - o = append(o, 0x82, 0xa1, 0x6b) - o = msgp.AppendString(o, z.d[za0001].k) - // string "v" - o = append(o, 0xa1, 0x76) - o, err = msgp.AppendIntf(o, z.d[za0001].v) - if err != nil { - err = msgp.WrapError(err, "d", za0001, "v") - return - } - } - return -} - -// UnmarshalMsg implements msgp.Unmarshaler -func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "d": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "d") - return - } - if cap(z.d) >= int(zb0002) { - z.d = (z.d)[:zb0002] - } else { - z.d = make([]kv, zb0002) - } - for za0001 := range z.d { - var zb0003 uint32 - zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - for zb0003 > 0 { - zb0003-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - switch msgp.UnsafeString(field) { - case "k": - z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts) - if err != nil { - err = msgp.WrapError(err, "d", za0001, "k") - return - } - case "v": - z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "d", za0001, "v") - return - } - default: - bts, err = msgp.Skip(bts) - if err != nil { - err = msgp.WrapError(err, "d", za0001) - return - } - } - } - } - default: - bts, err = msgp.Skip(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - o = bts - return -} - -// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z *data) Msgsize() (s int) { - s = 1 + 2 + msgp.ArrayHeaderSize - for za0001 := range z.d { - s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v) - } - return -} - -// DecodeMsg implements msgp.Decodable -func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, err = dc.ReadMapKeyPtr() - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "k": - z.k, err = dc.ReadString() - if err != nil { - err = msgp.WrapError(err, "k") - return - } - case "v": - z.v, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "v") - return - } - default: - err = dc.Skip() - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - return -} - -// EncodeMsg implements msgp.Encodable -func (z kv) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 2 - // write "k" - err = en.Append(0x82, 0xa1, 0x6b) - if err != nil { - return - } - err = en.WriteString(z.k) - if err != nil { - err = msgp.WrapError(err, "k") - return - } - // write "v" - err = en.Append(0xa1, 0x76) - if err != nil { - return - } - err = en.WriteIntf(z.v) - if err != nil { - err = msgp.WrapError(err, "v") - return - } - return -} - -// MarshalMsg implements msgp.Marshaler -func (z kv) MarshalMsg(b []byte) (o []byte, err error) { - o = msgp.Require(b, z.Msgsize()) - // map header, size 2 - // string "k" - o = append(o, 0x82, 0xa1, 0x6b) - o = msgp.AppendString(o, z.k) - // string "v" - o = append(o, 0xa1, 0x76) - o, err = msgp.AppendIntf(o, z.v) - if err != nil { - err = msgp.WrapError(err, "v") - return - } - return -} - -// UnmarshalMsg implements msgp.Unmarshaler -func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "k": - z.k, bts, err = msgp.ReadStringBytes(bts) - if err != nil { - err = msgp.WrapError(err, "k") - return - } - case "v": - z.v, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "v") - return - } - default: - bts, err = msgp.Skip(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - o = bts - return -} - -// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z kv) Msgsize() (s int) { - s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v) - return -} diff --git a/middleware/session/session.go b/middleware/session/session.go index 15217382..b967967c 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -5,11 +5,13 @@ import ( "time" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/internal/gotiny" "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" ) type Session struct { + // sync.RWMutex id string // session id fresh bool // if new session ctx *fiber.Ctx // fiber context @@ -26,7 +28,7 @@ var sessionPool = sync.Pool{ func acquireSession() *Session { s := sessionPool.Get().(*Session) if s.data == nil { - s.data = new(data) + s.data = acquireData() } s.fresh = true return s @@ -115,6 +117,7 @@ func (s *Session) Regenerate() error { // Save will update the storage and client cookie func (s *Session) Save() error { + // Better safe than sorry if s.data == nil { return nil @@ -126,18 +129,19 @@ func (s *Session) Save() error { } // Convert data to bytes - data, err := s.data.MarshalMsg(nil) - if err != nil { - return err - } + mux.Lock() + data := gotiny.Marshal(&s.data) + mux.Unlock() // pass raw bytes with session id to provider if err := s.config.Storage.Set(s.id, data, s.config.Expiration); err != nil { return err } - // Create cookie with the session ID - s.setCookie() + // Create cookie with the session ID if fresh + if s.fresh { + s.setCookie() + } // Release session // TODO: It's not safe to use the Session after called Save() diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index de58efcd..018f66d1 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -67,6 +67,108 @@ func Test_Session(t *testing.T) { utils.AssertEqual(t, 36, len(id)) } +// go test -run Test_Session_Types +func Test_Session_Types(t *testing.T) { + t.Parallel() + + // session store + store := New() + + // fiber instance + app := fiber.New() + + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // set cookie + ctx.Request().Header.SetCookie(store.CookieName, "123") + + // get session + sess, err := store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, sess.Fresh()) + + type User struct { + Name string + } + var vuser = User{ + Name: "John", + } + // set value + var vbool bool = true + var vstring string = "str" + var vint int = 13 + var vint8 int8 = 13 + var vint16 int16 = 13 + var vint32 int32 = 13 + var vint64 int64 = 13 + var vuint uint = 13 + var vuint8 uint8 = 13 + var vuint16 uint16 = 13 + var vuint32 uint32 = 13 + var vuint64 uint64 = 13 + var vuintptr uintptr = 13 + var vbyte byte = 'k' + var vrune rune = 'k' + var vfloat32 float32 = 13 + var vfloat64 float64 = 13 + var vcomplex64 complex64 = 13 + var vcomplex128 complex128 = 13 + sess.Set("vuser", vuser) + sess.Set("vbool", vbool) + sess.Set("vstring", vstring) + sess.Set("vint", vint) + sess.Set("vint8", vint8) + sess.Set("vint16", vint16) + sess.Set("vint32", vint32) + sess.Set("vint64", vint64) + sess.Set("vuint", vuint) + sess.Set("vuint8", vuint8) + sess.Set("vuint16", vuint16) + sess.Set("vuint32", vuint32) + sess.Set("vuint32", vuint32) + sess.Set("vuint64", vuint64) + sess.Set("vuintptr", vuintptr) + sess.Set("vbyte", vbyte) + sess.Set("vrune", vrune) + sess.Set("vfloat32", vfloat32) + sess.Set("vfloat64", vfloat64) + sess.Set("vcomplex64", vcomplex64) + sess.Set("vcomplex128", vcomplex128) + + // save session + err = sess.Save() + utils.AssertEqual(t, nil, err) + + // get session + sess, err = store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, false, sess.Fresh()) + + // get value + utils.AssertEqual(t, vuser, sess.Get("vuser").(User)) + utils.AssertEqual(t, vbool, sess.Get("vbool").(bool)) + utils.AssertEqual(t, vstring, sess.Get("vstring").(string)) + utils.AssertEqual(t, vint, sess.Get("vint").(int)) + utils.AssertEqual(t, vint8, sess.Get("vint8").(int8)) + utils.AssertEqual(t, vint16, sess.Get("vint16").(int16)) + utils.AssertEqual(t, vint32, sess.Get("vint32").(int32)) + utils.AssertEqual(t, vint64, sess.Get("vint64").(int64)) + utils.AssertEqual(t, vuint, sess.Get("vuint").(uint)) + utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8)) + utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16)) + utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32)) + utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64)) + utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr)) + utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte)) + utils.AssertEqual(t, vrune, sess.Get("vrune").(rune)) + utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32)) + utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64)) + utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64)) + utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128)) +} + // go test -run Test_Session_Store_Reset func Test_Session_Store_Reset(t *testing.T) { t.Parallel() diff --git a/middleware/session/store.go b/middleware/session/store.go index 2513f94f..17f24732 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -1,7 +1,10 @@ package session import ( + "sync" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/internal/gotiny" "github.com/gofiber/fiber/v2/internal/storage/memory" ) @@ -9,6 +12,8 @@ type Store struct { Config } +var mux sync.Mutex + // Storage ErrNotExist var errNotExist = "key does not exist" @@ -49,9 +54,9 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) { raw, err := s.Storage.Get(id) // Unmashal if we found data if err == nil { - if _, err = sess.data.UnmarshalMsg(raw); err != nil { - return nil, err - } + mux.Lock() + gotiny.Unmarshal(raw, &sess.data) + mux.Unlock() sess.fresh = false } else if raw != nil && err.Error() != "key does not exist" { return nil, err diff --git a/utils/assertions.go b/utils/assertions.go index a107a465..30b7fc02 100644 --- a/utils/assertions.go +++ b/utils/assertions.go @@ -16,17 +16,19 @@ import ( ) // AssertEqual checks if values are equal -func AssertEqual(t testing.TB, expected interface{}, actual interface{}, description ...string) { +func AssertEqual(t testing.TB, expected, actual interface{}, description ...string) { if reflect.DeepEqual(expected, actual) { return } + var aType = "" var bType = "" - if reflect.ValueOf(expected).IsValid() { - aType = reflect.TypeOf(expected).Name() + + if expected != nil { + aType = fmt.Sprintf("%s", reflect.TypeOf(expected)) } - if reflect.ValueOf(actual).IsValid() { - bType = reflect.TypeOf(actual).Name() + if actual != nil { + bType = fmt.Sprintf("%s", reflect.TypeOf(actual)) } testName := "AssertEqual" @@ -40,13 +42,11 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip w := tabwriter.NewWriter(&buf, 0, 0, 5, ' ', 0) fmt.Fprintf(w, "\nTest:\t%s", testName) fmt.Fprintf(w, "\nTrace:\t%s:%d", filepath.Base(file), line) - fmt.Fprintf(w, "\nError:\tNot equal") - fmt.Fprintf(w, "\nExpect:\t%v\t[%s]", expected, aType) - fmt.Fprintf(w, "\nResult:\t%v\t[%s]", actual, bType) - if len(description) > 0 { fmt.Fprintf(w, "\nDescription:\t%s", description[0]) } + fmt.Fprintf(w, "\nExpect:\t%v\t(%s)", expected, aType) + fmt.Fprintf(w, "\nResult:\t%v\t(%s)", actual, bType) result := "" if err := w.Flush(); err != nil { @@ -54,6 +54,7 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip } else { result = buf.String() } + if t != nil { t.Fatal(result) } else {