From bf1edc77d70465b4097a59c08c581033d2033ac6 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 28 May 2022 19:40:33 -0500
Subject: [PATCH] fixed putting wrong size bufs

---
 internal/iobufpool/iobufpool.go               | 41 ++++++----
 internal/iobufpool/iobufpool_internal_test.go |  2 +-
 internal/iobufpool/iobufpool_test.go          | 75 +++++++++++++++----
 internal/nbbconn/nbbconn.go                   |  1 -
 4 files changed, 87 insertions(+), 32 deletions(-)

diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go
index 52c52f45..95fd21f5 100644
--- a/internal/iobufpool/iobufpool.go
+++ b/internal/iobufpool/iobufpool.go
@@ -14,26 +14,16 @@ func init() {
 	}
 }
 
-// Get gets a []byte with len >= size and len <= size*2.
+// Get gets a []byte of len size with cap <= size*2.
 func Get(size int) []byte {
-	i := poolIdx(size)
+	i := getPoolIdx(size)
 	if i >= len(pools) {
 		return make([]byte, size)
 	}
-	return pools[i].Get().([]byte)
+	return pools[i].Get().([]byte)[:size]
 }
 
-// Put returns buf to the pool.
-func Put(buf []byte) {
-	i := poolIdx(len(buf))
-	if i >= len(pools) {
-		return
-	}
-
-	pools[i].Put(buf)
-}
-
-func poolIdx(size int) int {
+func getPoolIdx(size int) int {
 	size--
 	size >>= minPoolExpOf2
 	i := 0
@@ -44,3 +34,26 @@ func poolIdx(size int) int {
 
 	return i
 }
+
+// Put returns buf to the pool.
+func Put(buf []byte) {
+	buf = buf[:cap(buf)]
+
+	i := putPoolIdx(len(buf))
+	if i < 0 {
+		return
+	}
+
+	pools[i].Put(buf)
+}
+
+func putPoolIdx(size int) int {
+	minPoolSize := 1 << minPoolExpOf2
+	for i := range pools {
+		if size == minPoolSize<<i {
+			return i
+		}
+	}
+
+	return -1
+}
diff --git a/internal/iobufpool/iobufpool_internal_test.go b/internal/iobufpool/iobufpool_internal_test.go
index 38b499f9..23e63c25 100644
--- a/internal/iobufpool/iobufpool_internal_test.go
+++ b/internal/iobufpool/iobufpool_internal_test.go
@@ -30,7 +30,7 @@ func TestPoolIdx(t *testing.T) {
 		{size: 8388609, expected: 16},
 	}
 	for _, tt := range tests {
-		idx := poolIdx(tt.size)
+		idx := getPoolIdx(tt.size)
 		assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
 	}
 }
diff --git a/internal/iobufpool/iobufpool_test.go b/internal/iobufpool/iobufpool_test.go
index 9ad7417d..51b08215 100644
--- a/internal/iobufpool/iobufpool_test.go
+++ b/internal/iobufpool/iobufpool_test.go
@@ -5,31 +5,74 @@ import (
 
 	"github.com/jackc/pgx/v5/internal/iobufpool"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
-func TestGet(t *testing.T) {
+func TestGetCap(t *testing.T) {
 	tests := []struct {
 		requestedLen int
-		expectedLen  int
+		expectedCap  int
 	}{
-		{requestedLen: 0, expectedLen: 256},
-		{requestedLen: 128, expectedLen: 256},
-		{requestedLen: 255, expectedLen: 256},
-		{requestedLen: 256, expectedLen: 256},
-		{requestedLen: 257, expectedLen: 512},
-		{requestedLen: 511, expectedLen: 512},
-		{requestedLen: 512, expectedLen: 512},
-		{requestedLen: 513, expectedLen: 1024},
-		{requestedLen: 1023, expectedLen: 1024},
-		{requestedLen: 1024, expectedLen: 1024},
-		{requestedLen: 33554431, expectedLen: 33554432},
-		{requestedLen: 33554432, expectedLen: 33554432},
+		{requestedLen: 0, expectedCap: 256},
+		{requestedLen: 128, expectedCap: 256},
+		{requestedLen: 255, expectedCap: 256},
+		{requestedLen: 256, expectedCap: 256},
+		{requestedLen: 257, expectedCap: 512},
+		{requestedLen: 511, expectedCap: 512},
+		{requestedLen: 512, expectedCap: 512},
+		{requestedLen: 513, expectedCap: 1024},
+		{requestedLen: 1023, expectedCap: 1024},
+		{requestedLen: 1024, expectedCap: 1024},
+		{requestedLen: 33554431, expectedCap: 33554432},
+		{requestedLen: 33554432, expectedCap: 33554432},
 
 		// Above 32 MiB skip the pool and allocate exactly the requested size.
-		{requestedLen: 33554433, expectedLen: 33554433},
+		{requestedLen: 33554433, expectedCap: 33554433},
 	}
 	for _, tt := range tests {
 		buf := iobufpool.Get(tt.requestedLen)
-		assert.Equalf(t, tt.expectedLen, len(buf), "requestedLen: %d", tt.requestedLen)
+		assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
+		assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
 	}
 }
+
+func TestPutHandlesWrongSizedBuffers(t *testing.T) {
+	for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
+		putBuf := make([]byte, putBufSize)
+		iobufpool.Put(putBuf)
+
+		tests := []struct {
+			requestedLen int
+			expectedCap  int
+		}{
+			{requestedLen: 0, expectedCap: 256},
+			{requestedLen: 128, expectedCap: 256},
+			{requestedLen: 255, expectedCap: 256},
+			{requestedLen: 256, expectedCap: 256},
+			{requestedLen: 257, expectedCap: 512},
+			{requestedLen: 511, expectedCap: 512},
+			{requestedLen: 512, expectedCap: 512},
+			{requestedLen: 513, expectedCap: 1024},
+			{requestedLen: 1023, expectedCap: 1024},
+			{requestedLen: 1024, expectedCap: 1024},
+			{requestedLen: 33554431, expectedCap: 33554432},
+			{requestedLen: 33554432, expectedCap: 33554432},
+
+			// Above 32 MiB skip the pool and allocate exactly the requested size.
+			{requestedLen: 33554433, expectedCap: 33554433},
+		}
+		for _, tt := range tests {
+			getBuf := iobufpool.Get(tt.requestedLen)
+			assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
+			assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
+		}
+	}
+}
+
+func TestPutGetBufferReuse(t *testing.T) {
+	buf := iobufpool.Get(4)
+	buf[0] = 1
+	iobufpool.Put(buf)
+	buf = iobufpool.Get(4)
+	require.Equal(t, byte(1), buf[0])
+}
diff --git a/internal/nbbconn/nbbconn.go b/internal/nbbconn/nbbconn.go
index ce567803..12204cd4 100644
--- a/internal/nbbconn/nbbconn.go
+++ b/internal/nbbconn/nbbconn.go
@@ -83,7 +83,6 @@ func (c *Conn) Write(b []byte) (n int, err error) {
 	}
 
 	buf := iobufpool.Get(len(b))
-	buf = buf[:len(b)]
 	copy(buf, b)
 	c.writeQueue.pushBack(buf)
 	return len(b), nil