mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Compare commits
703 Commits
v5.0.0-bet
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
fc334e4c75 | ||
|
3f5509fe98 | ||
|
15bca4a4e1 | ||
|
1d557f9116 | ||
|
de7fe81d78 | ||
|
d9eb089bd7 | ||
|
6be24eb08d | ||
|
07871c0a34 | ||
|
de806a11e7 | ||
|
ce13266e90 | ||
|
777e7e5cdf | ||
|
151bd026ec | ||
|
540fcaa9b9 | ||
|
3a248e3822 | ||
|
baca2d848a | ||
|
c911d86cff | ||
|
2bac99e2ae | ||
|
c92d0a9045 | ||
|
e9aad0fb0b | ||
|
9e7f38cd50 | ||
|
e779a5c072 | ||
|
ff9c26d851 | ||
|
0f77a2d028 | ||
|
ddd966f09f | ||
|
924834b5b4 | ||
|
9b15554c51 | ||
|
037e4cf9a2 | ||
|
04bcc0219d | ||
|
0e0a7d8344 | ||
|
63422c7d6c | ||
|
5c1fbf4806 | ||
|
05fe5f8b05 | ||
|
70c9a147a2 | ||
|
6603ddfbe4 | ||
|
70f7cad222 | ||
|
6bf1b0b1b9 | ||
|
14bda65a0c | ||
|
9e3c4fb40f | ||
|
05e72a5ab1 | ||
|
47d631e34b | ||
|
58b05f567c | ||
|
dcb7193669 | ||
|
1abf7d9050 | ||
|
b5efc90a32 | ||
|
a26c93551f | ||
|
2100e1da46 | ||
|
2d21a2b80d | ||
|
5f33ee5f07 | ||
|
228cfffc20 | ||
|
a5353af354 | ||
|
0bc29e3000 | ||
|
9cce05944a | ||
|
9c0ad690a9 | ||
|
03f08abda3 | ||
|
2c1b1c389a | ||
|
329cb45913 | ||
|
c96a55f8c0 | ||
|
e87760682f | ||
|
f681632c68 | ||
|
3c640a44b6 | ||
|
de3f868c1d | ||
|
5424d3c873 | ||
|
42d3d00734 | ||
|
cdc672cf3f | ||
|
52e2858629 | ||
|
e352784fed | ||
|
c2175fe46e | ||
|
659823f8f3 | ||
|
ca04098fab | ||
|
4ff0a454e0 | ||
|
00b86ca3db | ||
|
61a0227241 | ||
|
2190a8e0d1 | ||
|
6e9fa42fef | ||
|
6d9e6a726e | ||
|
02e387ea64 | ||
|
e452f80b1d | ||
|
da0315d1a4 | ||
|
120c89fe0d | ||
|
057937db27 | ||
|
47cbd8edb8 | ||
|
90a77b13b2 | ||
|
59d6aa87b9 | ||
|
39ffc8b7a4 | ||
|
c4c1076d28 | ||
|
4293b25262 | ||
|
ea1e13a660 | ||
|
58d4c0c94f | ||
|
1752f7b4c1 | ||
|
ee718a110d | ||
|
546ad2f4e2 | ||
|
efc2c9ff44 | ||
|
aabed18db8 | ||
|
afa974fb05 | ||
|
12b37f3218 | ||
|
bcf3fbd780 | ||
|
f7c3d190ad | ||
|
473a241b96 | ||
|
311f72afdc | ||
|
877111ceeb | ||
|
dc3aea06b5 | ||
|
e5d321f920 | ||
|
17cd36818c | ||
|
24fbe353ed | ||
|
3a1593b25b | ||
|
9d851d7c98 | ||
|
dacffdc7e2 | ||
|
bc7c840770 | ||
|
043685147f | ||
|
25329273da | ||
|
ad87d47089 | ||
|
7cf7bc6054 | ||
|
76593f37f7 | ||
|
3e6c719698 | ||
|
5ee33320c6 | ||
|
ac0b46f2f9 | ||
|
e3c81cc153 | ||
|
4b7e9942b2 | ||
|
b9e2b20fb1 | ||
|
06a0abb75e | ||
|
c76a650f75 | ||
|
f57b2854f8 | ||
|
5c9b565116 | ||
|
2ec900454b | ||
|
8723855d95 | ||
|
3f84e891de | ||
|
cc05954369 | ||
|
123b59a57e | ||
|
10e11952bd | ||
|
32a6b1b200 | ||
|
f0783c6fbe | ||
|
0290507ff2 | ||
|
8f8470edaf | ||
|
a95cfbb433 | ||
|
7803ec3661 | ||
|
64ca07e31b | ||
|
fd0c65478e | ||
|
672c4a3a24 | ||
|
f8a5a5c9e3 | ||
|
ab36c2c0dd | ||
|
ce66b1dae4 | ||
|
d1205a6dbc | ||
|
97d20ccfad | ||
|
e9bd382c51 | ||
|
603f2337d6 | ||
|
035bbbe0cb | ||
|
73bbced270 | ||
|
4171f554d4 | ||
|
b197994b1f | ||
|
57fd684068 | ||
|
926913ad66 | ||
|
b9f77cb1b3 | ||
|
218c15a4eb | ||
|
4f7e19d67d | ||
|
0cbc5db39d | ||
|
5747f37d9c | ||
|
d6fc8b02b4 | ||
|
c457de62c9 | ||
|
216049c62b | ||
|
a68e14fe5a | ||
|
ea9610f672 | ||
|
7af618e423 | ||
|
3f270eec7d | ||
|
8e46d2117c | ||
|
9530aea47b | ||
|
a8aaa37363 | ||
|
67aa0e5a65 | ||
|
96791c88cd | ||
|
71a8e53574 | ||
|
13e212430d | ||
|
b25d092d20 | ||
|
7fceb64dee | ||
|
7a35585143 | ||
|
a787630988 | ||
|
37681a4f48 | ||
|
c7b9dc0e00 | ||
|
f007d84675 | ||
|
3563a2b048 | ||
|
b770252a3b | ||
|
c64fa0f0f2 | ||
|
dced53f796 | ||
|
161ce73ec1 | ||
|
fa57a20518 | ||
|
dd71547340 | ||
|
47977703e1 | ||
|
a764746906 | ||
|
6b9ff972a4 | ||
|
c407c42692 | ||
|
9907b874c2 | ||
|
ec557e87d5 | ||
|
9f4a264f89 | ||
|
572d7fff32 | ||
|
b4911f1da7 | ||
|
29751194ef | ||
|
c1f4cbb5cd | ||
|
24c0a5e8ff | ||
|
9ca9203afb | ||
|
79cab4640f | ||
|
6ea2d248a3 | ||
|
c1075bfff0 | ||
|
cf6074fe5c | ||
|
13beb380f5 | ||
|
fec45c802b | ||
|
3b7fa4ce87 | ||
|
732889728f | ||
|
e1b90cf620 | ||
|
2a36a7032e | ||
|
ded01c0cd9 | ||
|
532bf8f583 | ||
|
169067a364 | ||
|
659525c961 | ||
|
4dd1810d8b | ||
|
25914e21f3 | ||
|
19fcb54564 | ||
|
a39632db43 | ||
|
c05cce7d41 | ||
|
0080acf318 | ||
|
c81bba8690 | ||
|
523411a3fb | ||
|
a966716860 | ||
|
cf50c60869 | ||
|
8db971660e | ||
|
48cdd7bab0 | ||
|
579a320c1c | ||
|
01d649b2bf | ||
|
48ae1f4b2c | ||
|
e4f72071f8 | ||
|
6f0deff015 | ||
|
8649231bb3 | ||
|
33360ab479 | ||
|
c31619d08b | ||
|
ec9bb2ace7 | ||
|
93a579754b | ||
|
42c9e9070a | ||
|
60a01d044a | ||
|
8f69e45a53 | ||
|
ec98406207 | ||
|
8db0f280fb | ||
|
fc416d237a | ||
|
a3d9120636 | ||
|
78b22c3d2f | ||
|
221ad1b84c | ||
|
b6e5548341 | ||
|
1b6227af11 | ||
|
c1fce377ee | ||
|
7fd6f2a4f5 | ||
|
78a0a2bf41 | ||
|
a17f064492 | ||
|
49b6aad319 | ||
|
0cc4c14e62 | ||
|
da6f2c98f2 | ||
|
c543134753 | ||
|
20344dfae8 | ||
|
adbb38f298 | ||
|
c1b0a01ca7 | ||
|
88dfc22ae4 | ||
|
2e84dccaf5 | ||
|
d149d3fe5c | ||
|
046f497efb | ||
|
8896bd6977 | ||
|
85f15c4b3c | ||
|
654dcab93e | ||
|
5c63f646f8 | ||
|
6f8f6ede6c | ||
|
576b6c88f6 | ||
|
7caa448ac8 | ||
|
832b4f9771 | ||
|
fd4411453f | ||
|
34da2fed95 | ||
|
7b5fcac465 | ||
|
0819a17da8 | ||
|
bf1c1d7848 | ||
|
0fa533386c | ||
|
c90f82a4e3 | ||
|
a57bb8caea | ||
|
517c654e2c | ||
|
a4ca0917da | ||
|
0c35c9e630 | ||
|
b7de418d46 | ||
|
b99e2bb7e0 | ||
|
52f2151422 | ||
|
dfb6489612 | ||
|
9346d48035 | ||
|
1fdd17041a | ||
|
f654d61d79 | ||
|
5d26bbefd8 | ||
|
44768b5a01 | ||
|
6f2ce92356 | ||
|
4367ee0598 | ||
|
d2c9ebc2ef | ||
|
0c7acf9481 | ||
|
cbc5a7055f | ||
|
4c14caae07 | ||
|
22fe50149b | ||
|
dfd198003a | ||
|
603c8c1e90 | ||
|
9ab9e3c40b | ||
|
2daeb8dc5f | ||
|
df3c5f4df8 | ||
|
b1631e8e35 | ||
|
ba05097642 | ||
|
384fe7775c | ||
|
20bf953a17 | ||
|
12582a0fd4 | ||
|
905f252667 | ||
|
9927e14bbf | ||
|
95b2f85e60 | ||
|
913e4c8487 | ||
|
31321c2017 | ||
|
319c3172f2 | ||
|
4678e69599 | ||
|
89d699c2e8 | ||
|
7ebced92b5 | ||
|
94e56e61ba | ||
|
9103457384 | ||
|
9782306287 | ||
|
7d5a3969d0 | ||
|
e5015e2fac | ||
|
4dbd57a7ed | ||
|
0570b0e196 | ||
|
df5d00eb60 | ||
|
d38dd85756 | ||
|
9b6d3809d6 | ||
|
b4d72d4fce | ||
|
ccdd85a5eb | ||
|
96f5f9cd95 | ||
|
d3fb6e00da | ||
|
cf6ef75f91 | ||
|
7a4bb7edb5 | ||
|
6f7400f428 | ||
|
304697de36 | ||
|
5d0f904831 | ||
|
6ca3d8ed4e | ||
|
81ddcfdefb | ||
|
45f807fdb4 | ||
|
8a09979417 | ||
|
7a2b93323c | ||
|
1484fec57f | ||
|
3957163808 | ||
|
7fc908a5f2 | ||
|
0f0d236599 | ||
|
c6c50110db | ||
|
91530db629 | ||
|
24ed0e4257 | ||
|
163eb68866 | ||
|
a61517a83b | ||
|
d93f31b8fa | ||
|
cf72a00f52 | ||
|
c08cc72306 | ||
|
7de53a958b | ||
|
bbe2653bc5 | ||
|
4e7aa59d64 | ||
|
b301530a5f | ||
|
f42824cab3 | ||
|
18856482c4 | ||
|
639691c0ab | ||
|
3e716c4b06 | ||
|
51ade172e5 | ||
|
3d4540aa1b | ||
|
389931396e | ||
|
9ee7d29cf9 | ||
|
a7375cc503 | ||
|
d43bd349c1 | ||
|
5c6cf62b53 | ||
|
d17440d5c7 | ||
|
4c60839c48 | ||
|
e9087eacb8 | ||
|
d626dfe94e | ||
|
1a9b2a53a5 | ||
|
8fb309c631 | ||
|
f4533dc906 | ||
|
4091eedf03 | ||
|
87d771ef9c | ||
|
492283b90b | ||
|
e665f74c99 | ||
|
f90e86fd8d | ||
|
88b49d48f6 | ||
|
2506cf3666 | ||
|
d58fe2d53c | ||
|
ef9e26a5d5 | ||
|
6703484a0d | ||
|
c513e2e435 | ||
|
f47f0cf823 | ||
|
bd3e0d422c | ||
|
2f6fcf8eb0 | ||
|
038fc448c1 | ||
|
95aa87f2e8 | ||
|
f512b9688b | ||
|
05440f9d3f | ||
|
e0c70201dc | ||
|
524f661136 | ||
|
507a9e9ad3 | ||
|
0328d314ea | ||
|
cd46cdd450 | ||
|
2bf5a61401 | ||
|
dc94db6b3d | ||
|
b68e7b2a68 | ||
|
1dd69f86a1 | ||
|
8e6cf8f3a5 | ||
|
91cba90e8d | ||
|
0d14b87140 | ||
|
e79efdacf9 | ||
|
20a40120ed | ||
|
aa263d4352 | ||
|
7fccc604af | ||
|
34f17a6048 | ||
|
74ab538d2a | ||
|
7c386112e3 | ||
|
9a5ead9048 | ||
|
737b5af236 | ||
|
f20070650f | ||
|
e5db6a0467 | ||
|
5b7cc8e215 | ||
|
bc8b1ca320 | ||
|
2de94187f5 | ||
|
07670dddca | ||
|
d48d36dc02 | ||
|
eb2807bda5 | ||
|
b1f8055584 | ||
|
461b9fa36e | ||
|
45520d5a11 | ||
|
90f9aad67f | ||
|
5f28621394 | ||
|
c542df4fb4 | ||
|
34eddf9983 | ||
|
5d4f9018bf | ||
|
482e56a79b | ||
|
3ea2f57d8b | ||
|
26c79eb215 | ||
|
85136a8efe | ||
|
4410fc0a65 | ||
|
9cfdd21f1c | ||
|
4d643b75f5 | ||
|
490f70fc5f | ||
|
1b68b5970e | ||
|
ee04d4a74d | ||
|
d9560c78b8 | ||
|
608f39f426 | ||
|
229d2aaa49 | ||
|
b4314ddaf7 | ||
|
28bd5b3843 | ||
|
fb47e1abbb | ||
|
c861bce438 | ||
|
46d91255b0 | ||
|
ef363b59ab | ||
|
bad6b36c47 | ||
|
33d4fa0fa6 | ||
|
30d63caa6a | ||
|
b0fa429fd0 | ||
|
32c7858e61 | ||
|
c7733fe52e | ||
|
9720d0d63f | ||
|
5f6636d028 | ||
|
a1a97a7ca8 | ||
|
0ec512b504 | ||
|
f93b42b6ac | ||
|
9f00b6f750 | ||
|
4b9aa7c4f2 | ||
|
2c1973de46 | ||
|
b3739c1289 | ||
|
70a200cff4 | ||
|
c1c67e4e58 | ||
|
9de41fac75 | ||
|
11d892dfcf | ||
|
0292edecb0 | ||
|
eab316e200 | ||
|
8ceef73b84 | ||
|
bbcc4fc0b8 | ||
|
cead918e18 | ||
|
7f2bb9595f | ||
|
d8b38b28be | ||
|
2a86501e86 | ||
|
f59e8bf555 | ||
|
c27b9b49ea | ||
|
6defa2a607 | ||
|
a23a423f55 | ||
|
09371981f9 | ||
|
67f2a41587 | ||
|
2cf1541bb9 | ||
|
84eb2e460a | ||
|
847f888631 | ||
|
f72a147db3 | ||
|
8b7c699b8f | ||
|
215ffafc74 | ||
|
5eeaa201d9 | ||
|
be79f1c8f5 | ||
|
ca022267db | ||
|
2a653b4a8d | ||
|
7af80ae8a6 | ||
|
7555c43033 | ||
|
193bab416f | ||
|
e9d64ec29d | ||
|
2f1bba09c4 | ||
|
d829073b2f | ||
|
48da6435a5 | ||
|
34e3013153 | ||
|
009a377028 | ||
|
e05abb83ec | ||
|
89475c4c91 | ||
|
c3d62c8783 | ||
|
1298a835bc | ||
|
b2b4fbcf57 | ||
|
3db7d1774e | ||
|
a83faa67f5 | ||
|
8b5e8d9d89 | ||
|
9ae852eb58 | ||
|
19039e6dd1 | ||
|
0dbb0a52ab | ||
|
087b8b2ba8 | ||
|
c09ddaf440 | ||
|
80eb6e1859 | ||
|
7ec6ee7b0a | ||
|
6105ca5073 | ||
|
8f46c75e73 | ||
|
38e09bda4c | ||
|
9567297815 | ||
|
42d327f660 | ||
|
f17c743c3c | ||
|
a6ace8969b | ||
|
c2e278e5d4 | ||
|
c5daa3a814 | ||
|
f5d2da7a19 | ||
|
b8262ace75 | ||
|
2100a64dbe | ||
|
4484831550 | ||
|
1f43e2e490 | ||
|
b707faea8f | ||
|
255f16b00f | ||
|
a47e836471 | ||
|
5cd8468b99 | ||
|
fa5fbed497 | ||
|
190c05cc24 | ||
|
c875abea84 | ||
|
98543e0354 | ||
|
32c29a6edd | ||
|
9963c32d4f | ||
|
6bc327b3ce | ||
|
f46d35610e | ||
|
cf78472ce5 | ||
|
766d2bba4f | ||
|
384a581e99 | ||
|
898891a6ee | ||
|
7019ed1edf | ||
|
eee854fb06 | ||
|
bc754291c1 | ||
|
2c7d86a543 | ||
|
42a47194a2 | ||
|
7941518809 | ||
|
f839d501a7 | ||
|
f581584148 | ||
|
e48e7a7189 | ||
|
516300aabf | ||
|
62a7e19a04 | ||
|
672431c0bd | ||
|
7c0c7dc01e | ||
|
fcec008a4c | ||
|
d993cfa8fd | ||
|
a95cfe5cc5 | ||
|
c46d792c93 | ||
|
37c6f97b11 | ||
|
74f9b9f0a4 | ||
|
5177e1a8df | ||
|
d4fcd4a897 | ||
|
c514b2e0c3 | ||
|
e66ad1bcec | ||
|
c4ac6d810f | ||
|
456a242f5c | ||
|
d737852654 | ||
|
29ad306e47 | ||
|
f42af35884 | ||
|
11fa083a0d | ||
|
1ce3e0384a | ||
|
e58381ac94 | ||
|
279c3c0a20 | ||
|
17f8f7af63 | ||
|
f0a73424b1 | ||
|
88b373f9ee | ||
|
8e2de2fefa | ||
|
24c53259f8 | ||
|
8eb062f588 | ||
|
fbfafb3edf | ||
|
174224fa07 | ||
|
8ad1394f4c | ||
|
56633b3d51 | ||
|
ba4bbf92af | ||
|
b4d2eae777 | ||
|
3520c2ea43 | ||
|
c94c47f584 | ||
|
8678ed560f | ||
|
05924a9d6b | ||
|
2e9e2865f9 | ||
|
14be51536b | ||
|
1376a2c0ed | ||
|
932f676cfd | ||
|
5b6fb75669 | ||
|
b265fedd75 | ||
|
871f14e43b | ||
|
071d1c9467 | ||
|
29109487ec | ||
|
daf570c752 | ||
|
a86acf61e0 | ||
|
a968ce3437 | ||
|
39676004de | ||
|
6f90866f58 | ||
|
d8c04249d1 | ||
|
7fd064ab80 | ||
|
0013f6c7ca | ||
|
95498282bb | ||
|
6e77e0a09d | ||
|
1f0fd66623 | ||
|
45aeaed20a | ||
|
a2da398dff | ||
|
be419e25b4 | ||
|
dd07e24a6c | ||
|
0920c79b02 | ||
|
268af3903c | ||
|
4d711aaa73 | ||
|
dc85718658 | ||
|
6b52e0b5e0 | ||
|
9eaeb51e30 | ||
|
8b2ac8c18f | ||
|
05e9234c2e | ||
|
97d1012f42 | ||
|
6bedfa7def | ||
|
55b5067ddd | ||
|
1ec3816a20 | ||
|
c9c166b8b2 | ||
|
9a207178f6 | ||
|
3feeddd9f1 | ||
|
72c89108ad | ||
|
c130b2d74a | ||
|
7d3b9c1e44 | ||
|
6515e183ff | ||
|
e35041372d | ||
|
6fabd8f5b1 | ||
|
c00fb5d2a1 | ||
|
55d5d036c0 | ||
|
987de3874e | ||
|
3ad9995dfe | ||
|
3e825ec898 | ||
|
ba100785cc | ||
|
48b4807b33 | ||
|
6e40968cfc | ||
|
11e5f68ff6 | ||
|
7a9e70d1e0 | ||
|
f2e7c8144d | ||
|
aff180b192 | ||
|
a581124dea | ||
|
c4407fb36e | ||
|
094ad9c9d8 | ||
|
af0b896290 | ||
|
5655f9d593 | ||
|
f803c790d0 | ||
|
222e3b37bc | ||
|
89f69aaea9 | ||
|
63ae730fe8 | ||
|
305c4ddbc7 | ||
|
fb83fb0cc3 | ||
|
c48dd7e1f8 | ||
|
cd8b29b0fe | ||
|
0aa681f3a3 | ||
|
335c8621ff | ||
|
ac9d4f4d96 | ||
|
72e4b88e56 | ||
|
639fb28846 | ||
|
d7c7ddc594 | ||
|
4fc4f9a603 | ||
|
23a59d68fc | ||
|
5a055434f2 | ||
|
1a314bda3b | ||
|
4f1a8084f1 | ||
|
a05fb80b8a | ||
|
90b69c0ee0 | ||
|
ee2622a8e6 | ||
|
d42b399be3 | ||
|
f015ced1bf | ||
|
782133158f | ||
|
dfce986bb5 | ||
|
f8d088cfb6 | ||
|
f5cdf0d383 | ||
|
72fe594942 | ||
|
bce26b85d1 | ||
|
bb6c997102 | ||
|
fe3a4f3150 | ||
|
2e73d1e8ee | ||
|
0d5d8e0137 | ||
|
ae65a8007b | ||
|
dbee461dc9 | ||
|
ef5655c563 | ||
|
15f8e6323e | ||
|
e3406d95f9 | ||
|
067771b2e6 | ||
|
8eae4a2a3e | ||
|
faabb0696f | ||
|
1d748d9bbf | ||
|
c842802d65 | ||
|
7c6a31f9d2 | ||
|
02d9a5acd8 | ||
|
8256ab147f | ||
|
906f709e0c | ||
|
33b782a96d | ||
|
1453cd4b97 | ||
|
6871a0c4a6 |
4
.github/ISSUE_TEMPLATE/bug_report.md
vendored
4
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -23,7 +23,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -37,6 +37,8 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`.
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
|
168
.github/workflows/ci.yml
vendored
168
.github/workflows/ci.yml
vendored
@ -2,87 +2,155 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, v5-dev ]
|
||||
branches: [master]
|
||||
pull_request:
|
||||
branches: [master]
|
||||
|
||||
jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-20.04
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: [1.18]
|
||||
pg-version: [10, 11, 12, 13, 14, cockroachdb]
|
||||
go-version: ["1.23", "1.24"]
|
||||
pg-version: [13, 14, 15, 16, 17, cockroachdb]
|
||||
include:
|
||||
- pg-version: 10
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 11
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 12
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 13
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 14
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 15
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 16
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 17
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: cockroachdb
|
||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Setup database server for testing
|
||||
run: ci/setup_test.bash
|
||||
env:
|
||||
PGVERSION: ${{ matrix.pg-version }}
|
||||
|
||||
# - name: Setup upterm session
|
||||
# uses: lhotari/action-upterm@v1
|
||||
# with:
|
||||
# ## limits ssh access and adds the ssh public key for the user which triggered the workflow
|
||||
# limit-access-to-actor: true
|
||||
# env:
|
||||
# PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
# PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
# PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||
# PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
|
||||
# PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||
# PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
# PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
|
||||
# PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
|
||||
|
||||
- name: Check formatting
|
||||
run: |
|
||||
gofmt -l -s -w .
|
||||
git status
|
||||
git diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: go test -race ./...
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
|
||||
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
|
||||
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||
# TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
|
||||
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
|
||||
PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
|
||||
|
||||
test-windows:
|
||||
name: Test Windows
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.23", "1.24"]
|
||||
|
||||
steps:
|
||||
- name: Setup PostgreSQL
|
||||
id: postgres
|
||||
uses: ikalnytskyi/action-setup-postgres@v4
|
||||
with:
|
||||
database: pgx_test
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
- name: Initialize test database
|
||||
run: |
|
||||
psql -f testsetup/postgresql_setup.sql pgx_test
|
||||
env:
|
||||
PGSERVICE: ${{ steps.postgres.outputs.service-name }}
|
||||
shell: bash
|
||||
|
||||
- name: Test
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -22,3 +22,6 @@ _testmain.go
|
||||
*.exe
|
||||
|
||||
.envrc
|
||||
/.testdb
|
||||
|
||||
.DS_Store
|
||||
|
@ -1,9 +0,0 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
303
CHANGELOG.md
303
CHANGELOG.md
@ -1,4 +1,286 @@
|
||||
# Unreleased v5
|
||||
# 5.7.5 (May 17, 2025)
|
||||
|
||||
* Support sslnegotiation connection option (divyam234)
|
||||
* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487.
|
||||
* TraceLog now logs Acquire and Release at the debug level (dave sinclair)
|
||||
* Add support for PGTZ environment variable
|
||||
* Add support for PGOPTIONS environment variable
|
||||
* Unpin memory used by Rows quicker
|
||||
* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit.
|
||||
|
||||
# 5.7.4 (March 24, 2025)
|
||||
|
||||
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
|
||||
|
||||
# 5.7.3 (March 21, 2025)
|
||||
|
||||
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
|
||||
* Improve SQL sanitizer performance (ninedraft)
|
||||
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
|
||||
* Fix Values() for xml type always returning nil instead of []byte
|
||||
* Add ability to send Flush message in pipeline mode (zenkovev)
|
||||
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
|
||||
* Better error messages when scanning structs (logicbomb)
|
||||
* Fix handling of error on batch write (bonnefoa)
|
||||
* Match libpq's connection fallback behavior more closely (felix-roehrich)
|
||||
* Add MinIdleConns to pgxpool (djahandarie)
|
||||
|
||||
# 5.7.2 (December 21, 2024)
|
||||
|
||||
* Fix prepared statement already exists on batch prepare failure
|
||||
* Add commit query to tx options (Lucas Hild)
|
||||
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
|
||||
* Add message body size limits in frontend and backend (zene)
|
||||
* Add xid8 type
|
||||
* Ensure planning encodes and scans cannot infinitely recurse
|
||||
* Implement pgtype.UUID.String() (Konstantin Grachev)
|
||||
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
|
||||
* Update golang.org/x/crypto
|
||||
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
|
||||
|
||||
# 5.7.1 (September 10, 2024)
|
||||
|
||||
* Fix data race in tracelog.TraceLog
|
||||
* Update puddle to v2.2.2. This removes the import of nanotime via linkname.
|
||||
* Update golang.org/x/crypto and golang.org/x/text
|
||||
|
||||
# 5.7.0 (September 7, 2024)
|
||||
|
||||
* Add support for sslrootcert=system (Yann Soubeyrand)
|
||||
* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell)
|
||||
* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda)
|
||||
* Add MultiTrace (Stepan Rabotkin)
|
||||
* Add TraceLogConfig with customizable TimeKey (stringintech)
|
||||
* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin)
|
||||
* Support scanning binary formatted uint32 into string / TextScanner (jennifersp)
|
||||
* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce)
|
||||
* Update pgservicefile - fixes panic when parsing invalid file
|
||||
* Better error message when reading past end of batch
|
||||
* Don't print url when url.Parse returns an error (Kevin Biju)
|
||||
* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler)
|
||||
* Fix: Scan and encode types with underlying types of arrays
|
||||
|
||||
# 5.6.0 (May 25, 2024)
|
||||
|
||||
* Add StrictNamedArgs (Tomas Zahradnicek)
|
||||
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
|
||||
* Add SeverityUnlocalized field to PgError / Notice
|
||||
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
|
||||
* Allow customizing context canceled behavior for pgconn
|
||||
* Add ScanLocation to pgtype.Timestamp[tz]Codec
|
||||
* Add custom data to pgconn.PgConn
|
||||
* Fix ResultReader.Read() to handle nil values
|
||||
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
|
||||
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
|
||||
* Failed connection attempts include all errors
|
||||
* Optimize LargeObject.Read (Mitar)
|
||||
* Add tracing for connection acquire and release from pool (ngavinsir)
|
||||
* Fix encode driver.Valuer not called when nil
|
||||
* Add support for custom JSON marshal and unmarshal (Mitar)
|
||||
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
|
||||
|
||||
# 5.5.5 (March 9, 2024)
|
||||
|
||||
Use spaces instead of parentheses for SQL sanitization.
|
||||
|
||||
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
|
||||
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
|
||||
|
||||
# 5.5.4 (March 4, 2024)
|
||||
|
||||
Fix CVE-2024-27304
|
||||
|
||||
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
|
||||
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
|
||||
attacker's control.
|
||||
|
||||
Thanks to Paul Gerste for reporting this issue.
|
||||
|
||||
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
|
||||
* Fix simple protocol encoding of json.RawMessage
|
||||
* Fix *Pipeline.getResults should close pipeline on error
|
||||
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
|
||||
* Fix deallocation of invalidated cached statements in a transaction
|
||||
* Handle invalid sslkey file
|
||||
* Fix scan float4 into sql.Scanner
|
||||
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
|
||||
|
||||
# 5.5.3 (February 3, 2024)
|
||||
|
||||
* Fix: prepared statement already exists
|
||||
* Improve CopyFrom auto-conversion of text-ish values
|
||||
* Add ltree type support (Florent Viel)
|
||||
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
|
||||
* Add AppendRows function (Edoardo Spadolini)
|
||||
* Optimize convert UUID [16]byte to string (Kirill Malikov)
|
||||
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
|
||||
|
||||
# 5.5.2 (January 13, 2024)
|
||||
|
||||
* Allow NamedArgs to start with underscore
|
||||
* pgproto3: Maximum message body length support (jeremy.spriet)
|
||||
* Upgrade golang.org/x/crypto to v0.17.0
|
||||
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
|
||||
* Fix: update description cache after exec prepare (James Hartig)
|
||||
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
|
||||
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
|
||||
* Add OnPgError for easier centralized error handling (James Hartig)
|
||||
|
||||
# 5.5.1 (December 9, 2023)
|
||||
|
||||
* Add CopyFromFunc helper function. (robford)
|
||||
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
|
||||
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
|
||||
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
|
||||
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
|
||||
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
|
||||
|
||||
# 5.5.0 (November 4, 2023)
|
||||
|
||||
* Add CollectExactlyOneRow. (Julien GOTTELAND)
|
||||
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
|
||||
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
|
||||
* Statement cache now uses deterministic, stable statement names.
|
||||
* database/sql prepared statement names are deterministically generated.
|
||||
* Fix: SendBatch wasn't respecting context cancellation.
|
||||
* Fix: Timeout error from pipeline is now normalized.
|
||||
* Fix: database/sql encoding json.RawMessage to []byte.
|
||||
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
|
||||
* stdlib: Use Ping instead of CheckConn in ResetSession
|
||||
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
|
||||
|
||||
# 5.4.3 (August 5, 2023)
|
||||
|
||||
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
|
||||
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
|
||||
* Fix: pgxpool: background health check cannot overflow pool
|
||||
* Fix: Check for nil in defer when sending batch (recover properly from panic)
|
||||
* Fix: json scan of non-string pointer to pointer
|
||||
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
|
||||
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
|
||||
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
|
||||
* Optimization: improve text format numeric parsing (horpto)
|
||||
|
||||
# 5.4.2 (July 11, 2023)
|
||||
|
||||
* Fix: RowScanner errors are fatal to Rows
|
||||
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
|
||||
* Hstore text codec internal improvements (Evan Jones)
|
||||
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
|
||||
* Fix: Stop background reader as soon as possible.
|
||||
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
|
||||
|
||||
# 5.4.1 (June 18, 2023)
|
||||
|
||||
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
|
||||
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
|
||||
|
||||
# 5.4.0 (June 14, 2023)
|
||||
|
||||
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
|
||||
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
|
||||
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
|
||||
* CancelRequest: don't try to read the reply (Nicola Murino)
|
||||
* Fix: correctly handle bool type aliases (Wichert Akkerman)
|
||||
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
|
||||
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
|
||||
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
|
||||
* Fix: various hstore fixes and optimizations (Evan Jones)
|
||||
* Fix: RowToStructByPos with embedded unexported struct
|
||||
* Support different bool string representations (Lev Zakharov)
|
||||
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
|
||||
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
|
||||
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
|
||||
a callback.
|
||||
* Fix: scanning a table type into a struct
|
||||
* Fix: scan array of record to pointer to slice of struct
|
||||
* Fix: handle null for json (Cemre Mengu)
|
||||
* Batch Query callback is called even when there is an error
|
||||
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
|
||||
|
||||
# 5.3.1 (February 27, 2023)
|
||||
|
||||
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
|
||||
* Fix: sql.Scanner not being used in certain cases
|
||||
* Add text format jsonpath support
|
||||
* Fix: fake non-blocking read adaptive wait time
|
||||
|
||||
# 5.3.0 (February 11, 2023)
|
||||
|
||||
* Fix: json values work with sql.Scanner
|
||||
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
|
||||
* Fix: support scan into single dimensional arrays
|
||||
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
|
||||
* Fix: driver.Value representation of bytea should be []byte not string
|
||||
* Fix: better handling of unregistered OIDs
|
||||
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
|
||||
* Fix: encode to json ignoring driver.Valuer
|
||||
* Support sql.Scanner on renamed base type
|
||||
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
|
||||
* Fix: connect with multiple hostnames when one can't be resolved
|
||||
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
|
||||
* Fix: scanning json column into **string
|
||||
* Multiple reductions in memory allocations
|
||||
* Fake non-blocking read adapts its max wait time
|
||||
* Improve CopyFrom performance and reduce memory usage
|
||||
* Fix: encode []any to array
|
||||
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
|
||||
* Support v4 and v5 stdlib in same program
|
||||
* Fix: text format array decoding with string of "NULL"
|
||||
* Prefer binary format for arrays
|
||||
|
||||
# 5.2.0 (December 5, 2022)
|
||||
|
||||
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
|
||||
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
|
||||
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
|
||||
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
|
||||
|
||||
# 5.1.1 (November 17, 2022)
|
||||
|
||||
* Fix simple query sanitizer where query text contains a Unicode replacement character.
|
||||
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
|
||||
|
||||
# 5.1.0 (November 12, 2022)
|
||||
|
||||
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
|
||||
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
|
||||
* Expose `GetSSLPassword` support to pgx.
|
||||
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
|
||||
* Fix date text format encoding with 5 digit years.
|
||||
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
|
||||
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
|
||||
* Add domain type support to `Conn.LoadType()`.
|
||||
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
|
||||
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
|
||||
|
||||
# 5.0.4 (October 24, 2022)
|
||||
|
||||
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrorNoRows
|
||||
* Fix: some reflect Kind checks to first check for nil
|
||||
* Bump golang.org/x/text dependency to placate snyk
|
||||
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
|
||||
* Fix: Exec checks if tx is closed
|
||||
|
||||
# 5.0.3 (October 14, 2022)
|
||||
|
||||
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
|
||||
|
||||
# v5.0.2 (October 8, 2022)
|
||||
|
||||
* Fix date encoding in text format to always use 2 digits for month and day
|
||||
* Prefer driver.Valuer over wrap plans when encoding
|
||||
* Fix scan to pointer to pointer to renamed type
|
||||
* Allow scanning NULL even if PG and Go types are incompatible
|
||||
|
||||
# v5.0.1 (September 24, 2022)
|
||||
|
||||
* Fix 32-bit atomic usage
|
||||
* Add MarshalJSON for Float8 (yogipristiawan)
|
||||
* Add `[` and `]` to text encoding of `Lseg`
|
||||
* Fix sqlScannerWrapper NULL handling
|
||||
|
||||
# v5.0.0 (September 17, 2022)
|
||||
|
||||
## Merged Packages
|
||||
|
||||
@ -22,9 +304,11 @@ pgconn now supports pipeline mode.
|
||||
|
||||
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
|
||||
|
||||
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
|
||||
|
||||
## pgxpool
|
||||
|
||||
`Connect` and `ConnectConfig` have been renamed to `New` and `NewConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
|
||||
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
|
||||
|
||||
## pgtype
|
||||
|
||||
@ -33,7 +317,10 @@ The `pgtype` package has been significantly changed.
|
||||
### NULL Representation
|
||||
|
||||
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
|
||||
`Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable.
|
||||
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
|
||||
|
||||
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils
|
||||
whether typed or untyped now represent `NULL`.
|
||||
|
||||
### Codec and Value Split
|
||||
|
||||
@ -47,9 +334,9 @@ generally defined by implementing an interface that a particular `Codec` underst
|
||||
|
||||
### Array Types
|
||||
|
||||
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This
|
||||
significantly reduced the amount of code and the compiled binary size. This also means that less common array types such
|
||||
as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays.
|
||||
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
|
||||
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
|
||||
arrays.
|
||||
|
||||
### Composite Types
|
||||
|
||||
@ -63,7 +350,7 @@ easily be handled. Multirange types are handled similarly with `MultirangeCodec`
|
||||
|
||||
### pgxtype
|
||||
|
||||
load data type moved to conn
|
||||
`LoadDataType` moved to `*Conn` as `LoadType`.
|
||||
|
||||
### Bytea
|
||||
|
||||
@ -97,7 +384,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t
|
||||
|
||||
### 3rd Party Type Integrations
|
||||
|
||||
* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to
|
||||
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
|
||||
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
|
||||
the pgx dependency tree.
|
||||
|
||||
|
121
CONTRIBUTING.md
Normal file
121
CONTRIBUTING.md
Normal file
@ -0,0 +1,121 @@
|
||||
# Contributing
|
||||
|
||||
## Discuss Significant Changes
|
||||
|
||||
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
|
||||
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
|
||||
|
||||
## Avoid Dependencies
|
||||
|
||||
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
|
||||
that adds a dependency is no.
|
||||
|
||||
## Development Environment Setup
|
||||
|
||||
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE`
|
||||
environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or key-value pairs. In addition,
|
||||
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
|
||||
simplify environment variable handling.
|
||||
|
||||
### Using an Existing PostgreSQL Cluster
|
||||
|
||||
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
|
||||
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
|
||||
authentication methods).
|
||||
|
||||
Create and setup a test database:
|
||||
|
||||
```
|
||||
export PGDATABASE=pgx_test
|
||||
createdb
|
||||
psql -c 'create extension hstore;'
|
||||
psql -c 'create extension ltree;'
|
||||
psql -c 'create domain uint64 as numeric(20,0);'
|
||||
```
|
||||
|
||||
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods
|
||||
such as Homebrew do not.
|
||||
|
||||
```
|
||||
createuser -s postgres
|
||||
```
|
||||
|
||||
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
|
||||
|
||||
```
|
||||
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
|
||||
go test ./...
|
||||
```
|
||||
|
||||
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
|
||||
|
||||
### Creating a New PostgreSQL Cluster Exclusively for Testing
|
||||
|
||||
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
|
||||
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
|
||||
|
||||
```
|
||||
export PGPORT=5015
|
||||
export PGUSER=postgres
|
||||
export PGDATABASE=pgx_test
|
||||
export POSTGRESQL_DATA_DIR=postgresql
|
||||
|
||||
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
|
||||
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
|
||||
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
|
||||
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
|
||||
export PGX_SSL_PASSWORD=certpw
|
||||
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
|
||||
```
|
||||
|
||||
Create a new database cluster.
|
||||
|
||||
```
|
||||
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
|
||||
|
||||
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
|
||||
|
||||
cd .testdb
|
||||
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run ../testsetup/generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
|
||||
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
|
||||
chmod 600 $POSTGRESQL_DATA_DIR/server.key
|
||||
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
|
||||
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
Start the new cluster. This will be necessary whenever you are running pgx tests.
|
||||
|
||||
```
|
||||
postgres -D .testdb/$POSTGRESQL_DATA_DIR
|
||||
```
|
||||
|
||||
Setup the test database in the new cluster.
|
||||
|
||||
```
|
||||
createdb
|
||||
psql --no-psqlrc -f testsetup/postgresql_setup.sql
|
||||
```
|
||||
|
||||
### PgBouncer
|
||||
|
||||
There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
|
||||
|
||||
### Optional Tests
|
||||
|
||||
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
|
||||
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
|
||||
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). `go test ./... -v | grep SKIP` to see
|
||||
if any tests are being skipped.
|
132
README.md
132
README.md
@ -1,15 +1,12 @@
|
||||
[](https://pkg.go.dev/github.com/jackc/pgx/v5)
|
||||
[](https://travis-ci.org/jackc/pgx)
|
||||
[](https://pkg.go.dev/github.com/jackc/pgx/v5)
|
||||
[](https://github.com/jackc/pgx/actions/workflows/ci.yml)
|
||||
|
||||
# pgx - PostgreSQL Driver and Toolkit
|
||||
|
||||
*This is the v5 development branch. It is still in active development and testing.*
|
||||
|
||||
pgx is a pure Go driver and toolkit for PostgreSQL.
|
||||
|
||||
pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for.
|
||||
|
||||
The driver component of pgx can be used alongside the standard `database/sql` package.
|
||||
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
|
||||
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
|
||||
|
||||
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
|
||||
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
|
||||
@ -51,91 +48,55 @@ func main() {
|
||||
|
||||
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
|
||||
|
||||
## Choosing Between the pgx and database/sql Interfaces
|
||||
|
||||
It is recommended to use the pgx interface if:
|
||||
1. The application only targets PostgreSQL.
|
||||
2. No other libraries that require `database/sql` are in use.
|
||||
|
||||
The pgx interface is faster and exposes more features.
|
||||
|
||||
The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
|
||||
`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
|
||||
`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
|
||||
|
||||
## Features
|
||||
|
||||
pgx supports many features beyond what is available through `database/sql`:
|
||||
|
||||
* Support for approximately 70 different PostgreSQL types
|
||||
* Automatic statement preparation and caching
|
||||
* Batch queries
|
||||
* Single-round trip query mode
|
||||
* Full TLS connection control
|
||||
* Binary format support for custom types (allows for much quicker encoding/decoding)
|
||||
* COPY protocol support for faster bulk data loads
|
||||
* Extendable logging support
|
||||
* `COPY` protocol support for faster bulk data loads
|
||||
* Tracing and logging support
|
||||
* Connection pool with after-connect hook for arbitrary connection setup
|
||||
* Listen / notify
|
||||
* `LISTEN` / `NOTIFY`
|
||||
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
|
||||
* Hstore support
|
||||
* JSON and JSONB support
|
||||
* `hstore` support
|
||||
* `json` and `jsonb` support
|
||||
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
|
||||
* Large object support
|
||||
* NULL mapping to Null* struct or pointer to pointer
|
||||
* NULL mapping to pointer to pointer
|
||||
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
|
||||
* Notice response handling
|
||||
* Simulated nested transactions with savepoints
|
||||
|
||||
## Performance
|
||||
## Choosing Between the pgx and database/sql Interfaces
|
||||
|
||||
There are three areas in particular where pgx can provide a significant performance advantage over the standard
|
||||
`database/sql` interface and other drivers:
|
||||
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
|
||||
through the `database/sql` interface.
|
||||
|
||||
1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format.
|
||||
2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an
|
||||
significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can
|
||||
perform nearly 3x the number of queries per second.
|
||||
3. Batched queries - Multiple queries can be batched together to minimize network round trips.
|
||||
The pgx interface is recommended when:
|
||||
|
||||
1. The application only targets PostgreSQL.
|
||||
2. No other libraries that require `database/sql` are in use.
|
||||
|
||||
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
|
||||
|
||||
## Testing
|
||||
|
||||
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment
|
||||
variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or DSN. In addition, the standard `PG*` environment
|
||||
variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable
|
||||
handling.
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
|
||||
|
||||
### Example Test Environment
|
||||
## Architecture
|
||||
|
||||
Connect to your PostgreSQL server and run:
|
||||
|
||||
```
|
||||
create database pgx_test;
|
||||
```
|
||||
|
||||
Connect to the newly-created database and run:
|
||||
|
||||
```
|
||||
create domain uint64 as numeric(20,0);
|
||||
```
|
||||
|
||||
Now, you can run the tests:
|
||||
|
||||
```
|
||||
PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./...
|
||||
```
|
||||
|
||||
In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
|
||||
See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
|
||||
|
||||
## Supported Go and PostgreSQL Versions
|
||||
|
||||
~~pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).~~
|
||||
|
||||
`v5` is targeted at Go 1.18+. The general release of `v5` is not planned until second half of 2022 so it is expected that the policy of supporting the two most recent versions of Go will be maintained or restored soon after its release.
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
|
||||
## Version Policy
|
||||
|
||||
pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version.
|
||||
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
|
||||
|
||||
## PGX Family Libraries
|
||||
|
||||
@ -159,8 +120,14 @@ pgerrcode contains constants for the PostgreSQL error codes.
|
||||
|
||||
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
|
||||
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||
|
||||
|
||||
## Adapters for 3rd Party Tracers
|
||||
|
||||
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
|
||||
|
||||
## Adapters for 3rd Party Loggers
|
||||
|
||||
These adapters can be used with the tracelog package.
|
||||
@ -170,13 +137,50 @@ These adapters can be used with the tracelog package.
|
||||
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
|
||||
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
|
||||
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
|
||||
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
|
||||
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
|
||||
|
||||
## 3rd Party Libraries with PGX Support
|
||||
|
||||
### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
|
||||
|
||||
pgxmock is a mock library implementing pgx interfaces.
|
||||
pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.
|
||||
|
||||
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
|
||||
|
||||
Library for scanning data from a database into Go structs and more.
|
||||
|
||||
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
||||
### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
|
||||
|
||||
A carefully designed SQL client for making using SQL easier,
|
||||
more productive, and less error-prone on Golang.
|
||||
|
||||
### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
||||
|
||||
Adds GSSAPI / Kerberos authentication support.
|
||||
|
||||
### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
|
||||
|
||||
Explicit data mapping and scanning library for Go structs and slices.
|
||||
|
||||
### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
|
||||
|
||||
Type safe and flexible package for scanning database data into Go types.
|
||||
Supports, structs, maps, slices and custom mapping functions.
|
||||
|
||||
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
|
||||
|
||||
Code first migration library for native pgx (no database/sql abstraction).
|
||||
|
||||
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
|
||||
|
||||
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
|
||||
|
||||
### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
|
||||
|
||||
Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver.
|
||||
|
||||
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
||||
|
||||
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
||||
|
2
Rakefile
2
Rakefile
@ -2,7 +2,7 @@ require "erb"
|
||||
|
||||
rule '.go' => '.go.erb' do |task|
|
||||
erb = ERB.new(File.read(task.source))
|
||||
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
|
||||
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
|
||||
sh "goimports", "-w", task.name
|
||||
end
|
||||
|
||||
|
110
batch.go
110
batch.go
@ -10,9 +10,9 @@ import (
|
||||
|
||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||
type QueuedQuery struct {
|
||||
query string
|
||||
arguments []any
|
||||
fn batchItemFunc
|
||||
SQL string
|
||||
Arguments []any
|
||||
Fn batchItemFunc
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
|
||||
@ -20,14 +20,11 @@ type batchItemFunc func(br BatchResults) error
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
rows, _ := br.Query()
|
||||
defer rows.Close()
|
||||
|
||||
err = fn(rows)
|
||||
err := fn(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -39,7 +36,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
row := br.QueryRow()
|
||||
return fn(row)
|
||||
}
|
||||
@ -47,7 +44,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||
|
||||
// Exec sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -60,22 +57,28 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
type Batch struct {
|
||||
queuedQueries []*QueuedQuery
|
||||
QueuedQueries []*QueuedQuery
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
|
||||
// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
|
||||
//
|
||||
// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
|
||||
// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
|
||||
// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
|
||||
// include the current query may reference the wrong query.
|
||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||
qq := &QueuedQuery{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
SQL: query,
|
||||
Arguments: arguments,
|
||||
}
|
||||
b.queuedQueries = append(b.queuedQueries, qq)
|
||||
b.QueuedQueries = append(b.QueuedQueries, qq)
|
||||
return qq
|
||||
}
|
||||
|
||||
// Len returns number of queries that have been queued so far.
|
||||
func (b *Batch) Len() int {
|
||||
return len(b.queuedQueries)
|
||||
return len(b.QueuedQueries)
|
||||
}
|
||||
|
||||
type BatchResults interface {
|
||||
@ -129,7 +132,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if !br.mrr.NextResult() {
|
||||
err := br.mrr.Close()
|
||||
if err == nil {
|
||||
err = errors.New("no result")
|
||||
err = errors.New("no more results in batch")
|
||||
}
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
@ -142,7 +145,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||
}
|
||||
|
||||
commandTag, err := br.mrr.ResultReader().Close()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
br.mrr.Close()
|
||||
}
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
@ -178,7 +184,7 @@ func (br *batchResults) Query() (Rows, error) {
|
||||
if !br.mrr.NextResult() {
|
||||
rows.err = br.mrr.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = errors.New("no result")
|
||||
rows.err = errors.New("no more results in batch")
|
||||
}
|
||||
rows.closed = true
|
||||
|
||||
@ -225,10 +231,10 @@ func (br *batchResults) Close() error {
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
@ -251,10 +257,10 @@ func (br *batchResults) earlyError() error {
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
query = bi.SQL
|
||||
args = bi.Arguments
|
||||
ok = true
|
||||
br.qqIdx++
|
||||
}
|
||||
@ -285,12 +291,15 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
|
||||
query, arguments, _ := br.nextQueryAndArgs()
|
||||
query, arguments, err := br.nextQueryAndArgs()
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
return pgconn.CommandTag{}, err
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
var commandTag pgconn.CommandTag
|
||||
switch results := results.(type) {
|
||||
@ -309,7 +318,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
return commandTag, br.err
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
@ -328,9 +337,9 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
query, arguments, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
query = "batch query"
|
||||
query, arguments, err := br.nextQueryAndArgs()
|
||||
if err != nil {
|
||||
return &baseRows{err: err, closed: true}, err
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
@ -384,24 +393,20 @@ func (br *pipelineBatchResults) Close() error {
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
return br.err
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
@ -423,13 +428,16 @@ func (br *pipelineBatchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) {
|
||||
if br.b == nil {
|
||||
return "", nil, errors.New("no reference to batch")
|
||||
}
|
||||
|
||||
if br.qqIdx >= len(br.b.QueuedQueries) {
|
||||
return "", nil, errors.New("no more results in batch")
|
||||
}
|
||||
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
return bi.SQL, bi.Arguments, nil
|
||||
}
|
||||
|
263
batch_test.go
263
batch_test.go
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
@ -17,7 +18,10 @@ import (
|
||||
func TestConnSendBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
@ -36,7 +40,7 @@ func TestConnSendBatch(t *testing.T) {
|
||||
batch.Queue("select * from ledger where false")
|
||||
batch.Queue("select sum(amount) from ledger")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
@ -152,7 +156,10 @@ func TestConnSendBatch(t *testing.T) {
|
||||
func TestConnSendBatchQueuedQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
@ -237,7 +244,7 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
err := conn.SendBatch(context.Background(), batch).Close()
|
||||
err := conn.SendBatch(ctx, batch).Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@ -245,7 +252,10 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
|
||||
func TestConnSendBatchMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
@ -262,7 +272,7 @@ func TestConnSendBatchMany(t *testing.T) {
|
||||
}
|
||||
batch.Queue("select count(*) from ledger")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
for i := 0; i < numInserts; i++ {
|
||||
ct, err := br.Exec()
|
||||
@ -280,6 +290,45 @@ func TestConnSendBatchMany(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178
|
||||
func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
commandTag, err := br.Exec()
|
||||
require.Equal(t, "", commandTag.String())
|
||||
require.EqualError(t, err, "no more results in batch")
|
||||
err = br.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1")
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
commandTag, err := br.Exec()
|
||||
require.Equal(t, "SELECT 1", commandTag.String())
|
||||
require.NoError(t, err)
|
||||
commandTag, err = br.Exec()
|
||||
require.Equal(t, "", commandTag.String())
|
||||
require.EqualError(t, err, "no more results in batch")
|
||||
err = br.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -290,9 +339,12 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
pgx.QueryExecModeExec,
|
||||
// Don't test simple mode with prepared statements.
|
||||
}
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
||||
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
|
||||
_, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -304,7 +356,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
batch.Queue("ps1", 5)
|
||||
}
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
for i := 0; i < queryCount; i++ {
|
||||
rows, err := br.Query()
|
||||
@ -337,13 +389,16 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
|
||||
batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
|
||||
batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
var n int32
|
||||
err := br.QueryRow().Scan(&n)
|
||||
@ -368,6 +423,9 @@ func TestConnSendBatchWithQueryRewriter(t *testing.T) {
|
||||
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -380,7 +438,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
||||
|
||||
_, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
|
||||
_, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -392,7 +450,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
|
||||
batch.Queue("ps1", 5)
|
||||
}
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
for i := 0; i < queryCount; i++ {
|
||||
rows, err := br.Query()
|
||||
@ -426,13 +484,16 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
|
||||
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select n from generate_series(0,5) n")
|
||||
batch.Queue("select n from generate_series(0,5) n")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
@ -485,13 +546,16 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
func TestConnSendBatchQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
||||
batch.Queue("select n from generate_series(0,5) n")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
@ -523,12 +587,15 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
||||
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1 1")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
var n int32
|
||||
err := br.QueryRow().Scan(&n)
|
||||
@ -547,7 +614,10 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
||||
func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
@ -560,7 +630,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||
batch.Queue("select 1")
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
var value int
|
||||
err := br.QueryRow().Scan(&value)
|
||||
@ -584,7 +654,10 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
@ -597,7 +670,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||
batch.Queue("select 1 union all select 2 union all select 3")
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
@ -621,7 +694,10 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||
func TestTxSendBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
@ -635,7 +711,7 @@ func TestTxSendBatch(t *testing.T) {
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin(context.Background())
|
||||
tx, _ := conn.Begin(ctx)
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
||||
|
||||
@ -652,7 +728,7 @@ func TestTxSendBatch(t *testing.T) {
|
||||
batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
|
||||
batch.Queue("select amount from ledger2 where id = $1", id)
|
||||
|
||||
br = tx.SendBatch(context.Background(), batch)
|
||||
br = tx.SendBatch(ctx, batch)
|
||||
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
@ -669,10 +745,10 @@ func TestTxSendBatch(t *testing.T) {
|
||||
}
|
||||
|
||||
br.Close()
|
||||
tx.Commit(context.Background())
|
||||
tx.Commit(ctx)
|
||||
|
||||
var count int
|
||||
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
|
||||
conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("count => %v, want %v", count, 1)
|
||||
}
|
||||
@ -688,7 +764,10 @@ func TestTxSendBatch(t *testing.T) {
|
||||
func TestTxSendBatchRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
@ -696,11 +775,11 @@ func TestTxSendBatchRollback(t *testing.T) {
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin(context.Background())
|
||||
tx, _ := conn.Begin(ctx)
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
||||
|
||||
br := tx.SendBatch(context.Background(), batch)
|
||||
br := tx.SendBatch(ctx, batch)
|
||||
|
||||
var id int
|
||||
err := br.QueryRow().Scan(&id)
|
||||
@ -708,9 +787,9 @@ func TestTxSendBatchRollback(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
br.Close()
|
||||
tx.Rollback(context.Background())
|
||||
tx.Rollback(ctx)
|
||||
|
||||
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
|
||||
row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id)
|
||||
var count int
|
||||
row.Scan(&count)
|
||||
if count != 0 {
|
||||
@ -720,10 +799,62 @@ func TestTxSendBatchRollback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1578
|
||||
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 4 / $1::int", 0)
|
||||
|
||||
batchResult := conn.SendBatch(ctx, batch)
|
||||
|
||||
_, execErr := batchResult.Exec()
|
||||
require.Error(t, execErr)
|
||||
|
||||
closeErr := batchResult.Close()
|
||||
require.Equal(t, execErr, closeErr)
|
||||
|
||||
// Try to use the connection.
|
||||
_, err := conn.Exec(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 4 / n from generate_series(-2, 2) n")
|
||||
|
||||
batchResult := conn.SendBatch(ctx, batch)
|
||||
|
||||
_, execErr := batchResult.Exec()
|
||||
require.Error(t, execErr)
|
||||
|
||||
closeErr := batchResult.Close()
|
||||
require.Equal(t, execErr, closeErr)
|
||||
|
||||
// Try to use the connection.
|
||||
_, err := conn.Exec(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
||||
|
||||
@ -739,7 +870,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
|
||||
batch.Queue(`update t set n=n+1 where id='b' returning *`)
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
@ -768,6 +899,9 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConnSendBatchNoStatementCache(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||
config.StatementCacheCapacity = 0
|
||||
@ -776,10 +910,13 @@ func TestConnSendBatchNoStatementCache(t *testing.T) {
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
testConnSendBatch(t, conn, 3)
|
||||
testConnSendBatch(t, ctx, conn, 3)
|
||||
}
|
||||
|
||||
func TestConnSendBatchPrepareStatementCache(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 32
|
||||
@ -787,10 +924,13 @@ func TestConnSendBatchPrepareStatementCache(t *testing.T) {
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
testConnSendBatch(t, conn, 3)
|
||||
testConnSendBatch(t, ctx, conn, 3)
|
||||
}
|
||||
|
||||
func TestConnSendBatchDescribeStatementCache(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.DescriptionCacheCapacity = 32
|
||||
@ -798,16 +938,16 @@ func TestConnSendBatchDescribeStatementCache(t *testing.T) {
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
testConnSendBatch(t, conn, 3)
|
||||
testConnSendBatch(t, ctx, conn, 3)
|
||||
}
|
||||
|
||||
func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
|
||||
func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
|
||||
batch := &pgx.Batch{}
|
||||
for j := 0; j < queryCount; j++ {
|
||||
batch.Queue("select n from generate_series(0,5) n")
|
||||
}
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
br := conn.SendBatch(ctx, batch)
|
||||
|
||||
for j := 0; j < queryCount; j++ {
|
||||
rows, err := br.Query()
|
||||
@ -830,12 +970,12 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
|
||||
func TestSendBatchSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -868,8 +1008,41 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
|
||||
assert.False(t, rows.Next())
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
|
||||
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select col1 from foo")
|
||||
batch.Queue("select col1 from baz")
|
||||
err := conn.SendBatch(ctx, batch).Close()
|
||||
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
|
||||
|
||||
mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
|
||||
|
||||
// Since table baz now exists, the batch should succeed.
|
||||
|
||||
batch = &pgx.Batch{}
|
||||
batch.Queue("select col1 from foo")
|
||||
batch.Queue("select col1 from baz")
|
||||
err = conn.SendBatch(ctx, batch).Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func ExampleConn_SendBatch() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
@ -912,7 +1085,7 @@ func ExampleConn_SendBatch() {
|
||||
return err
|
||||
})
|
||||
|
||||
err = conn.SendBatch(context.Background(), batch).Close()
|
||||
err = conn.SendBatch(ctx, batch).Close()
|
||||
if err != nil {
|
||||
fmt.Printf("SendBatch error: %v", err)
|
||||
return
|
||||
|
201
bench_test.go
201
bench_test.go
@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -152,7 +151,7 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) {
|
||||
|
||||
for rr.NextRow() {
|
||||
for i := range rr.Values() {
|
||||
if bytes.Compare(rr.Values()[0], encodedBytes) != 0 {
|
||||
if !bytes.Equal(rr.Values()[0], encodedBytes) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes)
|
||||
}
|
||||
}
|
||||
@ -340,8 +339,9 @@ type benchmarkWriteTableCopyFromSrc struct {
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
||||
next := s.idx < s.count
|
||||
s.idx++
|
||||
return s.idx < s.count
|
||||
return next
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) {
|
||||
@ -407,6 +407,34 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkWriteNRowsViaBatchInsert(b *testing.B, n int) {
|
||||
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
mustExec(b, conn, benchmarkWriteTableCreateSQL)
|
||||
_, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
src := newBenchmarkWriteTableCopyFromSrc(n)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
for src.Next() {
|
||||
values, _ := src.Values()
|
||||
batch.Queue("insert_t", values...)
|
||||
}
|
||||
|
||||
err = conn.SendBatch(context.Background(), batch).Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type queryArgs []any
|
||||
|
||||
func (qa *queryArgs) Append(v any) string {
|
||||
@ -484,7 +512,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
||||
}
|
||||
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
return 0, nil
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return rowCount, nil
|
||||
@ -560,6 +588,22 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite2RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 2)
|
||||
}
|
||||
|
||||
func BenchmarkWrite2RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 2)
|
||||
}
|
||||
|
||||
func BenchmarkWrite2RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 2)
|
||||
}
|
||||
|
||||
func BenchmarkWrite2RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 2)
|
||||
}
|
||||
|
||||
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 5)
|
||||
}
|
||||
@ -567,6 +611,9 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) {
|
||||
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 5)
|
||||
}
|
||||
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 5)
|
||||
}
|
||||
|
||||
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 5)
|
||||
@ -579,6 +626,9 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) {
|
||||
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 10)
|
||||
}
|
||||
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 10)
|
||||
@ -591,6 +641,9 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) {
|
||||
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 100)
|
||||
}
|
||||
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 100)
|
||||
@ -604,6 +657,10 @@ func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 1000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite1000RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 1000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 1000)
|
||||
}
|
||||
@ -615,6 +672,9 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
|
||||
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 10000)
|
||||
}
|
||||
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaBatchInsert(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 10000)
|
||||
@ -884,6 +944,7 @@ type BenchRowSimple struct {
|
||||
BirthDate time.Time
|
||||
Weight int32
|
||||
Height int32
|
||||
Tags []string
|
||||
UpdateTime time.Time
|
||||
}
|
||||
|
||||
@ -897,13 +958,13 @@ func BenchmarkSelectRowsScanSimple(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
br := &BenchRowSimple{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
@ -922,6 +983,7 @@ type BenchRowStringBytes struct {
|
||||
BirthDate time.Time
|
||||
Weight int32
|
||||
Height int32
|
||||
Tags []string
|
||||
UpdateTime time.Time
|
||||
}
|
||||
|
||||
@ -935,13 +997,13 @@ func BenchmarkSelectRowsScanStringBytes(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
br := &BenchRowStringBytes{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
@ -960,6 +1022,7 @@ type BenchRowDecoder struct {
|
||||
BirthDate pgtype.Date
|
||||
Weight pgtype.Int4
|
||||
Height pgtype.Int4
|
||||
Tags pgtype.FlatArray[string]
|
||||
UpdateTime pgtype.Timestamptz
|
||||
}
|
||||
|
||||
@ -985,7 +1048,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query(
|
||||
context.Background(),
|
||||
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
|
||||
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
|
||||
pgx.QueryResultFormats{format.code},
|
||||
rowCount,
|
||||
)
|
||||
@ -994,7 +1057,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
|
||||
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
@ -1016,7 +1079,7 @@ func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
|
||||
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
|
||||
for mrr.NextResult() {
|
||||
rr := mrr.ResultReader()
|
||||
for rr.NextRow() {
|
||||
@ -1053,11 +1116,11 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := conn.PgConn().ExecParams(
|
||||
context.Background(),
|
||||
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
|
||||
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
|
||||
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
|
||||
nil,
|
||||
nil,
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
|
||||
)
|
||||
for rr.NextRow() {
|
||||
rr.Values()
|
||||
@ -1074,13 +1137,107 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if len(benchRows) != int(rowCount) {
|
||||
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
benchRows := make([]BenchRowSimple, 0, rowCount)
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchRows = benchRows[:0]
|
||||
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
var err error
|
||||
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if len(benchRows) != int(rowCount) {
|
||||
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if len(benchRows) != int(rowCount) {
|
||||
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
benchRows := make([]BenchRowSimple, 0, rowCount)
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchRows = benchRows[:0]
|
||||
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
var err error
|
||||
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if len(benchRows) != int(rowCount) {
|
||||
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
|
||||
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@ -1102,7 +1259,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
||||
"ps1",
|
||||
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
|
||||
nil,
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
|
||||
)
|
||||
for rr.NextRow() {
|
||||
rr.Values()
|
||||
@ -1120,7 +1277,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
||||
}
|
||||
|
||||
type queryRecorder struct {
|
||||
conn nbconn.Conn
|
||||
conn net.Conn
|
||||
writeBuf []byte
|
||||
readCount int
|
||||
}
|
||||
@ -1136,14 +1293,6 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
|
||||
return qr.conn.Write(b)
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) BufferReadUntilBlock() error {
|
||||
return qr.conn.BufferReadUntilBlock()
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) Flush() error {
|
||||
return qr.conn.Flush()
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) Close() error {
|
||||
return qr.conn.Close()
|
||||
}
|
||||
@ -1189,7 +1338,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn()
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
|
||||
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@ -1212,7 +1361,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
|
||||
"ps1",
|
||||
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
|
||||
nil,
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
|
||||
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
|
||||
)
|
||||
_, err := rr.Close()
|
||||
require.NoError(b, err)
|
||||
|
@ -9,40 +9,41 @@ then
|
||||
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
|
||||
sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
|
||||
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
fi
|
||||
|
||||
sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
|
||||
sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
|
||||
|
||||
cd testsetup
|
||||
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||
sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
|
||||
cp ca.pem /tmp
|
||||
cp pgx_sslcert.key /tmp
|
||||
cp pgx_sslcert.crt /tmp
|
||||
|
||||
cd ..
|
||||
|
||||
sudo /etc/init.d/postgresql restart
|
||||
|
||||
psql -U postgres -c 'create database pgx_test'
|
||||
psql -U postgres pgx_test -c 'create extension hstore'
|
||||
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
|
||||
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user `whoami`"
|
||||
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
|
||||
|
||||
# The tricky test user, below, has to actually exist so that it can be used in a test
|
||||
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
|
||||
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
|
||||
createdb -U postgres pgx_test
|
||||
psql -U postgres -f testsetup/postgresql_setup.sql pgx_test
|
||||
fi
|
||||
|
||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||
then
|
||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz
|
||||
sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/
|
||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
|
||||
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
|
||||
cockroach start-single-node --insecure --background --listen-addr=localhost
|
||||
cockroach sql --insecure -e 'create database pgx_test'
|
||||
fi
|
||||
|
528
conn.go
528
conn.go
@ -2,13 +2,15 @@ package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
@ -35,13 +37,18 @@ type ConnConfig struct {
|
||||
|
||||
// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
|
||||
// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
|
||||
// PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
|
||||
// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
|
||||
// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
|
||||
DefaultQueryExecMode QueryExecMode
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||
type ParseConfigOptions struct {
|
||||
pgconn.ParseConfigOptions
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||
// The only exception is the tls.Config:
|
||||
// according to the tls.Config docs it must not be modified after creation.
|
||||
@ -94,11 +101,33 @@ func (ident Identifier) Sanitize() string {
|
||||
return strings.Join(parts, ".")
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNoRows occurs when rows are expected but none are returned.
|
||||
var ErrNoRows = errors.New("no rows in result set")
|
||||
ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set")
|
||||
// ErrTooManyRows occurs when more rows than expected are returned.
|
||||
ErrTooManyRows = errors.New("too many rows in result set")
|
||||
)
|
||||
|
||||
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
func newProxyErr(background error, msg string) error {
|
||||
return &proxyError{
|
||||
msg: msg,
|
||||
background: background,
|
||||
}
|
||||
}
|
||||
|
||||
type proxyError struct {
|
||||
msg string
|
||||
background error
|
||||
}
|
||||
|
||||
func (err *proxyError) Error() string { return err.msg }
|
||||
|
||||
func (err *proxyError) Unwrap() error { return err.background }
|
||||
|
||||
var (
|
||||
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
)
|
||||
|
||||
// Connect establishes a connection with a PostgreSQL server with a connection string. See
|
||||
// pgconn.Connect for details.
|
||||
@ -110,6 +139,16 @@ func Connect(ctx context.Context, connString string) (*Conn, error) {
|
||||
return connect(ctx, connConfig)
|
||||
}
|
||||
|
||||
// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to
|
||||
// provide a GetSSLPassword function.
|
||||
func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
|
||||
connConfig, err := ParseConfigWithOptions(connString, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return connect(ctx, connConfig)
|
||||
}
|
||||
|
||||
// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
|
||||
// connConfig must have been created by ParseConfig.
|
||||
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
|
||||
@ -120,22 +159,10 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
|
||||
return connect(ctx, connConfig)
|
||||
}
|
||||
|
||||
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
|
||||
// does. In addition, it accepts the following options:
|
||||
//
|
||||
// default_query_exec_mode
|
||||
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
|
||||
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
|
||||
//
|
||||
// statement_cache_capacity
|
||||
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
|
||||
// Default: 512.
|
||||
//
|
||||
// description_cache_capacity
|
||||
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
|
||||
// Default: 512.
|
||||
func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
config, err := pgconn.ParseConfig(connString)
|
||||
// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is
|
||||
// only used to provide a GetSSLPassword function.
|
||||
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
|
||||
config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -145,7 +172,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
delete(config.RuntimeParams, "statement_cache_capacity")
|
||||
n, err := strconv.ParseInt(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
|
||||
return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err)
|
||||
}
|
||||
statementCacheCapacity = int(n)
|
||||
}
|
||||
@ -155,7 +182,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
delete(config.RuntimeParams, "description_cache_capacity")
|
||||
n, err := strconv.ParseInt(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
|
||||
return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err)
|
||||
}
|
||||
descriptionCacheCapacity = int(n)
|
||||
}
|
||||
@ -175,7 +202,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
case "simple_protocol":
|
||||
defaultQueryExecMode = QueryExecModeSimpleProtocol
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err)
|
||||
return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,6 +218,24 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
return connConfig, nil
|
||||
}
|
||||
|
||||
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig]
|
||||
// does. In addition, it accepts the following options:
|
||||
//
|
||||
// - default_query_exec_mode.
|
||||
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
|
||||
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
|
||||
//
|
||||
// - statement_cache_capacity.
|
||||
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
|
||||
// Default: 512.
|
||||
//
|
||||
// - description_cache_capacity.
|
||||
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
|
||||
// Default: 512.
|
||||
func ParseConfig(connString string) (*ConnConfig, error) {
|
||||
return ParseConfigWithOptions(connString, ParseConfigOptions{})
|
||||
}
|
||||
|
||||
// connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
|
||||
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
|
||||
@ -248,7 +293,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes a connection. It is safe to call Close on a already closed
|
||||
// Close closes a connection. It is safe to call Close on an already closed
|
||||
// connection.
|
||||
func (c *Conn) Close(ctx context.Context) error {
|
||||
if c.IsClosed() {
|
||||
@ -259,12 +304,15 @@ func (c *Conn) Close(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
||||
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
|
||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
|
||||
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
|
||||
// Exec to execute the statement. It can also be used with Batch.Queue.
|
||||
//
|
||||
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
|
||||
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
||||
// concern for if the statement has already been prepared.
|
||||
// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
|
||||
// name == sql.
|
||||
//
|
||||
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
|
||||
// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared.
|
||||
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
|
||||
if c.prepareTracer != nil {
|
||||
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
|
||||
@ -286,22 +334,60 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem
|
||||
}()
|
||||
}
|
||||
|
||||
sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
|
||||
var psName, psKey string
|
||||
if name == sql {
|
||||
digest := sha256.Sum256([]byte(sql))
|
||||
psName = "stmt_" + hex.EncodeToString(digest[0:24])
|
||||
psKey = sql
|
||||
} else {
|
||||
psName = name
|
||||
psKey = name
|
||||
}
|
||||
|
||||
sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
c.preparedStatements[name] = sd
|
||||
if psKey != "" {
|
||||
c.preparedStatements[psKey] = sd
|
||||
}
|
||||
|
||||
return sd, nil
|
||||
}
|
||||
|
||||
// Deallocate released a prepared statement
|
||||
// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
|
||||
func (c *Conn) Deallocate(ctx context.Context, name string) error {
|
||||
var psName string
|
||||
sd := c.preparedStatements[name]
|
||||
if sd != nil {
|
||||
psName = sd.Name
|
||||
} else {
|
||||
psName = name
|
||||
}
|
||||
|
||||
err := c.pgConn.Deallocate(ctx, psName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sd != nil {
|
||||
delete(c.preparedStatements, name)
|
||||
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
|
||||
func (c *Conn) DeallocateAll(ctx context.Context) error {
|
||||
c.preparedStatements = map[string]*pgconn.StatementDescription{}
|
||||
if c.config.StatementCacheCapacity > 0 {
|
||||
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
||||
}
|
||||
if c.config.DescriptionCacheCapacity > 0 {
|
||||
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
||||
}
|
||||
_, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
@ -334,7 +420,7 @@ func (c *Conn) IsClosed() bool {
|
||||
return c.pgConn.IsClosed()
|
||||
}
|
||||
|
||||
func (c *Conn) die(err error) {
|
||||
func (c *Conn) die() {
|
||||
if c.IsClosed() {
|
||||
return
|
||||
}
|
||||
@ -348,11 +434,9 @@ func quoteIdentifier(s string) string {
|
||||
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
// Ping executes an empty sql statement against the *Conn
|
||||
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
|
||||
// Ping delegates to the underlying *pgconn.PgConn.Ping.
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
_, err := c.Exec(ctx, ";")
|
||||
return err
|
||||
return c.pgConn.Ping(ctx)
|
||||
}
|
||||
|
||||
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
|
||||
@ -407,7 +491,10 @@ optionLoop:
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Always use simple protocol when there are no arguments.
|
||||
@ -426,7 +513,7 @@ optionLoop:
|
||||
}
|
||||
sd := c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
@ -444,6 +531,7 @@ optionLoop:
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
|
||||
return c.execParams(ctx, sd, arguments)
|
||||
@ -472,7 +560,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a
|
||||
|
||||
mrr := c.pgConn.Exec(ctx, sql)
|
||||
for mrr.NextResult() {
|
||||
commandTag, err = mrr.ResultReader().Close()
|
||||
commandTag, _ = mrr.ResultReader().Close()
|
||||
}
|
||||
err = mrr.Close()
|
||||
return commandTag, err
|
||||
@ -500,14 +588,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
|
||||
return result.CommandTag, result.Err
|
||||
}
|
||||
|
||||
type unknownArgumentTypeQueryExecModeExecError struct {
|
||||
arg any
|
||||
}
|
||||
|
||||
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
||||
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
||||
}
|
||||
|
||||
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
|
||||
err := c.eqb.Build(c.typeMap, nil, args)
|
||||
if err != nil {
|
||||
@ -538,40 +618,57 @@ type QueryExecMode int32
|
||||
const (
|
||||
_ QueryExecMode = iota
|
||||
|
||||
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single
|
||||
// round trip after the statement is cached. This is the default.
|
||||
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round
|
||||
// trip after the statement is cached. This is the default. If the database schema is modified or the search_path is
|
||||
// changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the
|
||||
// number of columns returned by a "SELECT *" changes or the type of a column is changed.
|
||||
QueryExecModeCacheStatement
|
||||
|
||||
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the
|
||||
// extended protocol. Queries are executed in a single round trip after the description is cached. If the database
|
||||
// schema is modified or the search_path is changed this may result in undetected result decoding errors.
|
||||
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended
|
||||
// protocol. Queries are executed in a single round trip after the description is cached. If the database schema is
|
||||
// modified or the search_path is changed after a statement is cached then the first execution of a previously cached
|
||||
// query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed.
|
||||
QueryExecModeCacheDescribe
|
||||
|
||||
// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
|
||||
// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even
|
||||
// when the the database schema is modified concurrently.
|
||||
// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
|
||||
// statement description on the first round trip and then uses it to execute the query on the second round trip. This
|
||||
// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
|
||||
// even when the database schema is modified concurrently.
|
||||
QueryExecModeDescribeExec
|
||||
|
||||
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
|
||||
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
|
||||
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
|
||||
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
|
||||
// unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
|
||||
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
|
||||
//
|
||||
// On rare occasions user defined types may behave differently when encoded in the text format instead of the binary
|
||||
// format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as
|
||||
// Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7.
|
||||
// But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it
|
||||
// is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query.
|
||||
// This should not occur with types pgx supports directly and can be avoided by registering the types with
|
||||
// pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the cas of RomanNumeral, it
|
||||
// should implement pgtype.Int64Valuer.
|
||||
QueryExecModeExec
|
||||
|
||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
|
||||
// Queries are executed in a single round trip. Type mappings can be registered with
|
||||
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
|
||||
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
|
||||
// a map[string]string directly as an argument. This mode cannot.
|
||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
|
||||
// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
|
||||
// instead for text type values including json and jsonb. Type mappings can be registered with
|
||||
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
|
||||
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
|
||||
// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
|
||||
//
|
||||
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
|
||||
// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
|
||||
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes
|
||||
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
|
||||
// other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single
|
||||
// string.
|
||||
//
|
||||
// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
|
||||
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
|
||||
// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
|
||||
// not support the extended protocol.
|
||||
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should
|
||||
// only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not
|
||||
// support the extended protocol.
|
||||
QueryExecModeSimpleProtocol
|
||||
)
|
||||
|
||||
@ -600,7 +697,7 @@ type QueryResultFormatsByOID map[uint32]int16
|
||||
|
||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||
type QueryRewriter interface {
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
|
||||
}
|
||||
|
||||
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
|
||||
@ -611,6 +708,9 @@ type QueryRewriter interface {
|
||||
// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It
|
||||
// is allowed to ignore the error returned from Query and handle it in Rows.
|
||||
//
|
||||
// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not
|
||||
// return an error.
|
||||
//
|
||||
// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
|
||||
// collected before processing rather than processed while receiving each row. This avoids the possibility of the
|
||||
// application processing rows from a query that the server rejected. The CollectRows function is useful here.
|
||||
@ -659,7 +759,16 @@ optionLoop:
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
var err error
|
||||
originalSQL := sql
|
||||
originalArgs := args
|
||||
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
if err != nil {
|
||||
rows := c.getRows(ctx, originalSQL, originalArgs)
|
||||
err = fmt.Errorf("rewrite query failed: %w", err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
// Bypass any statement caching.
|
||||
@ -668,51 +777,17 @@ optionLoop:
|
||||
}
|
||||
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
|
||||
var err error
|
||||
sd, explicitPreparedStatement := c.preparedStatements[sql]
|
||||
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
|
||||
if sd == nil {
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
err = errDisabledStatementCache
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
sd, err = c.getStatementDescription(ctx, mode, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
err = errDisabledDescriptionCache
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
@ -781,6 +856,47 @@ optionLoop:
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
// getStatementDescription returns the statement description of the sql query
|
||||
// according to the given mode.
|
||||
//
|
||||
// If the mode is one that doesn't require to know the param and result OIDs
|
||||
// then nil is returned without error.
|
||||
func (c *Conn) getStatementDescription(
|
||||
ctx context.Context,
|
||||
mode QueryExecMode,
|
||||
sql string,
|
||||
) (sd *pgconn.StatementDescription, err error) {
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
return nil, errDisabledStatementCache
|
||||
}
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
return nil, errDisabledDescriptionCache
|
||||
}
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
return c.Prepare(ctx, "", sql)
|
||||
}
|
||||
return sd, err
|
||||
}
|
||||
|
||||
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
||||
// querying is deferred until calling Scan on the returned Row. That Row will
|
||||
// error with ErrNoRows if no rows are returned.
|
||||
@ -792,6 +908,9 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
||||
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
|
||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||
// is used again.
|
||||
//
|
||||
// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table
|
||||
// and using it in a subsequent query in the same batch can fail.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
if c.batchTracer != nil {
|
||||
ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
|
||||
@ -807,15 +926,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
var queryRewriter QueryRewriter
|
||||
sql := bi.query
|
||||
arguments := bi.arguments
|
||||
sql := bi.SQL
|
||||
arguments := bi.Arguments
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
// Update Batch.Queue function comment when additional options are implemented
|
||||
switch arg := arguments[0].(type) {
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
@ -826,20 +944,26 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
var err error
|
||||
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
|
||||
}
|
||||
}
|
||||
|
||||
bi.query = sql
|
||||
bi.arguments = arguments
|
||||
bi.SQL = sql
|
||||
bi.Arguments = arguments
|
||||
}
|
||||
|
||||
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
|
||||
}
|
||||
|
||||
// All other modes use extended protocol and thus can use prepared statements.
|
||||
for _, bi := range b.queuedQueries {
|
||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if sd, ok := c.preparedStatements[bi.SQL]; ok {
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
@ -860,11 +984,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.queuedQueries {
|
||||
for i, bi := range b.QueuedQueries {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
@ -883,21 +1007,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
|
||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
sd := bi.sd
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
@ -916,24 +1040,24 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.statementCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.statementCache.Get(bi.query)
|
||||
sd := c.statementCache.Get(bi.SQL)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
Name: stmtcache.NextStatementName(),
|
||||
SQL: bi.query,
|
||||
Name: stmtcache.StatementName(bi.SQL),
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@ -948,23 +1072,23 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.descriptionCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.descriptionCache.Get(bi.query)
|
||||
sd := c.descriptionCache.Get(bi.SQL)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@ -981,13 +1105,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd := &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@ -1000,33 +1124,51 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
|
||||
pipeline := c.pgConn.StartPipeline(context.Background())
|
||||
pipeline := c.pgConn.StartPipeline(ctx)
|
||||
defer func() {
|
||||
if pbr.err != nil {
|
||||
if pbr != nil && pbr.err != nil {
|
||||
pipeline.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare any needed queries
|
||||
if len(distinctNewQueries) > 0 {
|
||||
err := func() (err error) {
|
||||
for _, sd := range distinctNewQueries {
|
||||
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
|
||||
// clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Put(sd)
|
||||
}
|
||||
}
|
||||
|
||||
// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
|
||||
defer func() {
|
||||
if err != nil && sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Invalidate(sd.SQL)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return err
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
||||
return fmt.Errorf("expected statement description, got %T", results)
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
@ -1036,27 +1178,28 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return err
|
||||
}
|
||||
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
||||
}
|
||||
return fmt.Errorf("expected sync, got %T", results)
|
||||
}
|
||||
|
||||
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Put(sd)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
for _, bi := range b.queuedQueries {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||
for _, bi := range b.QueuedQueries {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
// we wrap the error so we the user can understand which query failed inside the batch
|
||||
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
if bi.sd.Name == "" {
|
||||
@ -1068,7 +1211,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
return &pipelineBatchResults{
|
||||
@ -1100,7 +1243,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
||||
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
|
||||
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
|
||||
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
|
||||
// typeName must be one of the following:
|
||||
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
|
||||
// - A composite type name where all field types are already registered.
|
||||
// - A domain type name where the base type is already registered.
|
||||
// - An enum type name.
|
||||
// - A range type name where the element type is already registered.
|
||||
// - A multirange type name where the element type is already registered.
|
||||
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
|
||||
var oid uint32
|
||||
|
||||
@ -1110,8 +1261,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
|
||||
}
|
||||
|
||||
var typtype string
|
||||
var typbasetype uint32
|
||||
|
||||
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
|
||||
err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1136,8 +1288,39 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
|
||||
}
|
||||
|
||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
|
||||
case "d": // domain
|
||||
dt, ok := c.TypeMap().TypeForOID(typbasetype)
|
||||
if !ok {
|
||||
return nil, errors.New("domain base type OID not registered")
|
||||
}
|
||||
|
||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
|
||||
case "e": // enum
|
||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
||||
case "r": // range
|
||||
elementOID, err := c.getRangeElementOID(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dt, ok := c.TypeMap().TypeForOID(elementOID)
|
||||
if !ok {
|
||||
return nil, errors.New("range element OID not registered")
|
||||
}
|
||||
|
||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
|
||||
case "m": // multirange
|
||||
elementOID, err := c.getMultiRangeElementOID(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dt, ok := c.TypeMap().TypeForOID(elementOID)
|
||||
if !ok {
|
||||
return nil, errors.New("multirange element OID not registered")
|
||||
}
|
||||
|
||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
|
||||
default:
|
||||
return &pgtype.Type{}, errors.New("unknown typtype")
|
||||
}
|
||||
@ -1154,6 +1337,28 @@ func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, erro
|
||||
return typelem, nil
|
||||
}
|
||||
|
||||
func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
||||
var typelem uint32
|
||||
|
||||
err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return typelem, nil
|
||||
}
|
||||
|
||||
func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
||||
var typelem uint32
|
||||
|
||||
err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return typelem, nil
|
||||
}
|
||||
|
||||
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
|
||||
var typrelid uint32
|
||||
|
||||
@ -1168,6 +1373,8 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com
|
||||
rows, _ := c.Query(ctx, `select attname, atttypid
|
||||
from pg_attribute
|
||||
where attrelid=$1
|
||||
and not attisdropped
|
||||
and attnum > 0
|
||||
order by attnum`,
|
||||
typrelid,
|
||||
)
|
||||
@ -1187,17 +1394,17 @@ order by attnum`,
|
||||
}
|
||||
|
||||
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||
if c.pgConn.TxStatus() != 'I' {
|
||||
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.descriptionCache != nil {
|
||||
c.descriptionCache.HandleInvalidated()
|
||||
c.descriptionCache.RemoveInvalidated()
|
||||
}
|
||||
|
||||
var invalidatedStatements []*pgconn.StatementDescription
|
||||
if c.statementCache != nil {
|
||||
invalidatedStatements = c.statementCache.HandleInvalidated()
|
||||
invalidatedStatements = c.statementCache.GetInvalidated()
|
||||
}
|
||||
|
||||
if len(invalidatedStatements) == 0 {
|
||||
@ -1221,5 +1428,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
c.statementCache.RemoveInvalidated()
|
||||
for _, sd := range invalidatedStatements {
|
||||
delete(c.preparedStatements, sd.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
55
conn_internal_test.go
Normal file
55
conn_internal_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mustParseConfig(t testing.TB, connString string) *ConnConfig {
|
||||
config, err := ParseConfig(connString)
|
||||
require.Nil(t, err)
|
||||
return config
|
||||
}
|
||||
|
||||
func mustConnect(t testing.TB, config *ConnConfig) *Conn {
|
||||
conn, err := ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// Ensures the connection limits the size of its cached objects.
|
||||
// This test examines the internals of *Conn so must be in the same package.
|
||||
func TestStmtCacheSizeLimit(t *testing.T) {
|
||||
const cacheLimit = 16
|
||||
|
||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
connConfig.StatementCacheCapacity = cacheLimit
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer func() {
|
||||
err := conn.Close(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// run a set of unique queries that should overflow the cache
|
||||
ctx := context.Background()
|
||||
for i := 0; i < cacheLimit*2; i++ {
|
||||
uniqueString := fmt.Sprintf("unique %d", i)
|
||||
uniqueSQL := fmt.Sprintf("select '%s'", uniqueString)
|
||||
var output string
|
||||
err := conn.QueryRow(ctx, uniqueSQL).Scan(&output)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uniqueString, output)
|
||||
}
|
||||
// preparedStatements contains cacheLimit+1 because deallocation happens before the query
|
||||
assert.Len(t, conn.preparedStatements, cacheLimit+1)
|
||||
assert.Equal(t, cacheLimit, conn.statementCache.Len())
|
||||
}
|
400
conn_test.go
400
conn_test.go
@ -3,6 +3,7 @@ package pgx_test
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -197,10 +198,28 @@ func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
connString string
|
||||
expectedErrSubstring string
|
||||
}{
|
||||
{"default_query_exec_mode=does_not_exist", "does_not_exist"},
|
||||
} {
|
||||
config, err := pgx.ParseConfig(tt.connString)
|
||||
require.Nil(t, config)
|
||||
require.ErrorContains(t, err, tt.expectedErrSubstring)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
@ -236,14 +255,17 @@ type testQueryRewriter struct {
|
||||
args []any
|
||||
}
|
||||
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
return qr.sql, qr.args
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return qr.sql, qr.args, nil
|
||||
}
|
||||
|
||||
func TestExecWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||
_, err := conn.Exec(ctx, "should be replaced", &qr)
|
||||
require.NoError(t, err)
|
||||
@ -253,7 +275,10 @@ func TestExecWithQueryRewriter(t *testing.T) {
|
||||
func TestExecFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
@ -269,7 +294,10 @@ func TestExecFailure(t *testing.T) {
|
||||
func TestExecFailureWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), "selct $1;", 1)
|
||||
if err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
@ -284,8 +312,11 @@ func TestExecFailureWithArguments(t *testing.T) {
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
|
||||
@ -302,8 +333,11 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
_, err := conn.Exec(ctx, "selct;")
|
||||
@ -324,8 +358,11 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
_, err := conn.Exec(ctx, "selct $1;", 1)
|
||||
@ -424,7 +461,7 @@ func TestPrepare(t *testing.T) {
|
||||
t.Errorf("Prepared statement did not return expected value: %v", s)
|
||||
}
|
||||
|
||||
err = conn.Deallocate(context.Background(), "test")
|
||||
err = conn.DeallocateAll(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("conn.Deallocate failed: %v", err)
|
||||
}
|
||||
@ -446,9 +483,10 @@ func TestPrepareBadSQLFailure(t *testing.T) {
|
||||
func TestPrepareIdempotency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := conn.Prepare(context.Background(), "test", "select 42::integer")
|
||||
if err != nil {
|
||||
@ -471,12 +509,16 @@ func TestPrepareIdempotency(t *testing.T) {
|
||||
t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareStatementCacheModes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -487,6 +529,91 @@ func TestPrepareStatementCacheModes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareWithDigestedName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
sql := "select $1::text"
|
||||
sd, err := conn.Prepare(ctx, sql, sql)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
|
||||
|
||||
var s string
|
||||
err = conn.QueryRow(ctx, sql, "hello").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
err = conn.Deallocate(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/pull/1795
|
||||
func TestDeallocateInAbortedTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
sql := "select $1::text"
|
||||
sd, err := tx.Prepare(ctx, sql, sql)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
|
||||
|
||||
var s string
|
||||
err = tx.QueryRow(ctx, sql, "hello").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
_, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error
|
||||
require.Error(t, err)
|
||||
|
||||
err = conn.Deallocate(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Rollback(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
sd, err = conn.Prepare(ctx, sql, sql)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeallocateMissingPreparedStatementStillClearsFromPreparedStatementMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Prepare(ctx, "ps", "select $1::text")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Exec(ctx, "deallocate ps")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Deallocate(ctx, "ps")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Prepare(ctx, "ps", "select $1::text, $2::text")
|
||||
require.NoError(t, err)
|
||||
|
||||
var s1, s2 string
|
||||
err = conn.QueryRow(ctx, "ps", "hello", "world").Scan(&s1, &s2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s1)
|
||||
require.Equal(t, "world", s2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenNotify(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -526,6 +653,7 @@ func TestListenNotify(t *testing.T) {
|
||||
defer cancel()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.Nil(t, notification)
|
||||
|
||||
// listener can listen again after a timeout
|
||||
mustExec(t, notifier, "notify chat")
|
||||
@ -545,6 +673,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
|
||||
listenerDone := make(chan bool)
|
||||
notifierDone := make(chan bool)
|
||||
listening := make(chan bool)
|
||||
go func() {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
@ -553,6 +682,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
}()
|
||||
|
||||
mustExec(t, conn, "listen busysafe")
|
||||
listening <- true
|
||||
|
||||
for i := 0; i < 5000; i++ {
|
||||
var sum int32
|
||||
@ -575,7 +705,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("conn.Query failed: %v", err)
|
||||
t.Errorf("conn.Query failed: %v", rows.Err())
|
||||
return
|
||||
}
|
||||
|
||||
@ -588,8 +718,6 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
t.Errorf("Wrong number of rows: %v", rowCount)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -600,9 +728,10 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
notifierDone <- true
|
||||
}()
|
||||
|
||||
<-listening
|
||||
|
||||
for i := 0; i < 100000; i++ {
|
||||
mustExec(t, conn, "notify busysafe, 'hello'")
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -715,7 +844,10 @@ func TestFatalTxError(t *testing.T) {
|
||||
func TestInsertBoolArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
@ -730,7 +862,10 @@ func TestInsertBoolArray(t *testing.T) {
|
||||
func TestInsertTimestampArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
@ -812,7 +947,10 @@ func TestConnInitTypeMap(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
|
||||
var n uint64
|
||||
@ -828,7 +966,10 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDomainType(t *testing.T) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
|
||||
// Domain type uint64 is a PostgreSQL domain of underlying type numeric.
|
||||
@ -837,24 +978,21 @@ func TestDomainType(t *testing.T) {
|
||||
// uint64 but a result OID of the underlying numeric.
|
||||
|
||||
var s string
|
||||
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
|
||||
err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "24", s)
|
||||
|
||||
// Register type
|
||||
var uint64OID uint32
|
||||
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
|
||||
if err != nil {
|
||||
t.Fatalf("did not find uint64 OID, %v", err)
|
||||
}
|
||||
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
|
||||
uint64Type, err := conn.LoadType(ctx, "uint64")
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(uint64Type)
|
||||
|
||||
var n uint64
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
||||
err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
|
||||
// String is still an acceptable argument after registration
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
|
||||
err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -865,7 +1003,10 @@ func TestDomainType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
@ -906,6 +1047,111 @@ create type pgx_b.point as (c text);
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadCompositeType(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, "create type compositetype as (attr1 int, attr2 int)")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.Exec(ctx, "alter type compositetype drop attribute attr1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.LoadType(ctx, "compositetype")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadRangeType(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support range types")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi)")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register types
|
||||
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(newRangeType)
|
||||
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
|
||||
|
||||
var inputRangeType = pgtype.Range[float64]{
|
||||
Lower: 1.0,
|
||||
Upper: 2.0,
|
||||
LowerType: pgtype.Inclusive,
|
||||
UpperType: pgtype.Inclusive,
|
||||
Valid: true,
|
||||
}
|
||||
var outputRangeType pgtype.Range[float64]
|
||||
err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, inputRangeType, outputRangeType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadMultiRangeType(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support range types")
|
||||
pgxtest.SkipPostgreSQLVersionLessThan(t, conn, 14) // multirange data type was added in 14 postgresql
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register types
|
||||
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(newRangeType)
|
||||
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
|
||||
|
||||
newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange")
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(newMultiRangeType)
|
||||
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
|
||||
|
||||
var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
|
||||
{
|
||||
Lower: 1.0,
|
||||
Upper: 2.0,
|
||||
LowerType: pgtype.Inclusive,
|
||||
UpperType: pgtype.Inclusive,
|
||||
Valid: true,
|
||||
},
|
||||
{
|
||||
Lower: 3.0,
|
||||
Upper: 4.0,
|
||||
LowerType: pgtype.Exclusive,
|
||||
UpperType: pgtype.Exclusive,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]]
|
||||
err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, inputMultiRangeType, outputMultiRangeType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStmtCacheInvalidationConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@ -1048,7 +1294,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInsertDurationInterval(t *testing.T) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -1082,7 +1331,7 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
if bytes.Compare(original, buf) != 0 {
|
||||
if !bytes.Equal(original, buf) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -1090,3 +1339,82 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
||||
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847
|
||||
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division")
|
||||
|
||||
var n int32
|
||||
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
|
||||
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
|
||||
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
|
||||
// we could call conn.statementCache.InvalidateAll() instead.
|
||||
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
|
||||
require.Error(t, err)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
cancel2()
|
||||
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847
|
||||
func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
connString := os.Getenv("PGX_TEST_DATABASE")
|
||||
config := mustParseConfig(t, connString)
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 2
|
||||
|
||||
conn, err := pgx.ConnectConfig(ctx, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, "select $1::int + 1", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.Exec(ctx, "select $1::int + 2", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// This should invalidate the first cached statement.
|
||||
_, err = tx.Exec(ctx, "select $1::int + 3", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select $1::int + 1", 1)
|
||||
err = tx.SendBatch(ctx, batch).Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Rollback(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestErrNoRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// ensure we preserve old error message
|
||||
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
||||
|
||||
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
||||
}
|
||||
|
80
copy_from.go
80
copy_from.go
@ -64,6 +64,33 @@ func (cts *copyFromSlice) Err() error {
|
||||
return cts.err
|
||||
}
|
||||
|
||||
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
|
||||
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
|
||||
// or it returns an error. If nxtf returns an error, the copy is aborted.
|
||||
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
|
||||
return ©FromFunc{next: nxtf}
|
||||
}
|
||||
|
||||
type copyFromFunc struct {
|
||||
next func() ([]any, error)
|
||||
valueRow []any
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Next() bool {
|
||||
g.valueRow, g.err = g.next()
|
||||
// only return true if valueRow exists and no error
|
||||
return g.valueRow != nil && g.err == nil
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Values() ([]any, error) {
|
||||
return g.valueRow, g.err
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Err() error {
|
||||
return g.err
|
||||
}
|
||||
|
||||
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
|
||||
type CopyFromSource interface {
|
||||
// Next returns true if there is another row and makes the next row data
|
||||
@ -85,6 +112,7 @@ type copyFrom struct {
|
||||
columnNames []string
|
||||
rowSrc CopyFromSource
|
||||
readerErrChan chan error
|
||||
mode QueryExecMode
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
@ -105,9 +133,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
}
|
||||
quotedColumnNames := cbuf.String()
|
||||
|
||||
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
var sd *pgconn.StatementDescription
|
||||
switch ct.mode {
|
||||
case QueryExecModeExec, QueryExecModeSimpleProtocol:
|
||||
// These modes don't support the binary format. Before the inclusion of the
|
||||
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
|
||||
// the OIDs. These prepared statements were not cached.
|
||||
//
|
||||
// Since that's the same behavior provided by QueryExecModeDescribeExec,
|
||||
// we'll default to that mode.
|
||||
ct.mode = QueryExecModeDescribeExec
|
||||
fallthrough
|
||||
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
|
||||
var err error
|
||||
sd, err = ct.conn.getStatementDescription(
|
||||
ctx,
|
||||
ct.mode,
|
||||
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("statement description failed: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
|
||||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
@ -167,8 +215,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
||||
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
|
||||
lastBufLen := 0
|
||||
largestRowLen := 0
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
lastBufLen = len(buf)
|
||||
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
@ -185,7 +238,15 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) > 65536 {
|
||||
rowLen := len(buf) - lastBufLen
|
||||
if rowLen > largestRowLen {
|
||||
largestRowLen = rowLen
|
||||
}
|
||||
|
||||
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
|
||||
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
|
||||
// 13, 65531, 13, 65531, 13.
|
||||
if len(buf) > sendBufSize-largestRowLen {
|
||||
return true, buf, nil
|
||||
}
|
||||
}
|
||||
@ -193,12 +254,14 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
|
||||
return false, buf, nil
|
||||
}
|
||||
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||
// It returns the number of rows copied and an error.
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
|
||||
// an error.
|
||||
//
|
||||
// CopyFrom requires all values use the binary format. Almost all types
|
||||
// implemented by pgx use the binary format by default. Types implementing
|
||||
// Encoder can only be used if they encode to the binary format.
|
||||
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
|
||||
// for the type of each column. Almost all types implemented by pgx support the binary format.
|
||||
//
|
||||
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
|
||||
// Conn.LoadType and pgtype.Map.RegisterType.
|
||||
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||
ct := ©From{
|
||||
conn: c,
|
||||
@ -206,6 +269,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
|
||||
columnNames: columnNames,
|
||||
rowSrc: rowSrc,
|
||||
readerErrChan: make(chan error),
|
||||
mode: c.config.DefaultQueryExecMode,
|
||||
}
|
||||
|
||||
return ct.run(ctx)
|
||||
|
@ -14,9 +14,137 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnCopyWithAllQueryExecModes(t *testing.T) {
|
||||
for _, mode := range pgxtest.AllQueryExecModes {
|
||||
t.Run(mode.String(), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
cfg.DefaultQueryExecMode = mode
|
||||
conn := mustConnect(t, cfg)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d text,
|
||||
e timestamptz
|
||||
)`)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", tzedTime},
|
||||
{nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
|
||||
|
||||
for _, mode := range pgxtest.KnownOIDQueryExecModes {
|
||||
t.Run(mode.String(), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
cfg.DefaultQueryExecMode = mode
|
||||
conn := mustConnect(t, cfg)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz
|
||||
)`)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCopyFromSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -37,7 +165,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
@ -45,7 +173,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -73,6 +201,9 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||
func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -93,7 +224,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
|
||||
pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
|
||||
return inputRows[i], nil
|
||||
}))
|
||||
@ -104,7 +235,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -132,11 +263,12 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
func TestConnCopyFromLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
@ -156,7 +288,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
@ -164,7 +296,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -192,10 +324,12 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
func TestConnCopyFromEnum(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx := context.Background()
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
@ -220,7 +354,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
conn.TypeMap().RegisterType(typ)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `create table foo(
|
||||
_, err = tx.Exec(ctx, `create temporary table foo(
|
||||
a text,
|
||||
b color,
|
||||
c fruit,
|
||||
@ -235,11 +369,11 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
{nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
|
||||
copyCount, err := tx.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
rows, err := tx.Query(ctx, "select * from foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
var outputRows [][]any
|
||||
@ -255,12 +389,18 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
err = tx.Rollback(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -280,7 +420,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
@ -288,7 +428,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -338,6 +478,9 @@ func (cfs *clientFailSource) Err() error {
|
||||
func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -352,7 +495,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
{int32(3), "def"},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
@ -363,7 +506,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -414,6 +557,9 @@ func (fs *failSource) Err() error {
|
||||
func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -425,7 +571,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
@ -442,7 +588,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -491,6 +637,9 @@ func (fs *slowFailRaceSource) Err() error {
|
||||
func TestConnCopyFromSlowFailRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -499,7 +648,7 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {
|
||||
b bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
@ -516,6 +665,9 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {
|
||||
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -523,7 +675,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
@ -531,7 +683,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -576,6 +728,9 @@ func (cfs *clientFinalErrSource) Err() error {
|
||||
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@ -583,7 +738,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
@ -591,7 +746,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
@ -615,3 +770,125 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int8
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{"42"},
|
||||
{"7"},
|
||||
{8},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, _ := conn.Query(ctx, "select * from foo")
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, []int64{42, 7, 8}, nums)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/discussions/1891
|
||||
func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a numeric[]
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{[]string{"42"}},
|
||||
{[]string{"7"}},
|
||||
{[]string{"8", "9"}},
|
||||
{[][]string{{"10", "11"}, {"12", "13"}}},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
// Test reads as int64 and flattened array for simplicity.
|
||||
rows, _ := conn.Query(ctx, "select * from foo")
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestCopyFromFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int
|
||||
)`)
|
||||
|
||||
dataCh := make(chan int, 1)
|
||||
|
||||
const channelItems = 10
|
||||
go func() {
|
||||
for i := 0; i < channelItems; i++ {
|
||||
dataCh <- i
|
||||
}
|
||||
close(dataCh)
|
||||
}()
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
|
||||
pgx.CopyFromFunc(func() ([]any, error) {
|
||||
v, ok := <-dataCh
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return []any{v}, nil
|
||||
}))
|
||||
|
||||
require.ErrorIs(t, err, nil)
|
||||
require.EqualValues(t, channelItems, copyCount)
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo order by a")
|
||||
require.NoError(t, err)
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)
|
||||
|
||||
// simulate a failure
|
||||
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
|
||||
pgx.CopyFromFunc(func() func() ([]any, error) {
|
||||
x := 9
|
||||
return func() ([]any, error) {
|
||||
x++
|
||||
if x > 100 {
|
||||
return nil, fmt.Errorf("simulated error")
|
||||
}
|
||||
return []any{x}, nil
|
||||
}
|
||||
}()))
|
||||
require.NotErrorIs(t, err, nil)
|
||||
require.EqualValues(t, 0, copyCount) // no change, due to error
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
256
derived_types.go
Normal file
256
derived_types.go
Normal file
@ -0,0 +1,256 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
/*
|
||||
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
|
||||
|
||||
pgVersion: the major version of the PostgreSQL server
|
||||
typeNames: the names of the types to load. If nil, load all types.
|
||||
*/
|
||||
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
|
||||
supportsMultirange := (pgVersion >= 14)
|
||||
var typeNamesClause string
|
||||
|
||||
if typeNames == nil {
|
||||
// This should not occur; this will not return any types
|
||||
typeNamesClause = "= ''"
|
||||
} else {
|
||||
typeNamesClause = "= ANY($1)"
|
||||
}
|
||||
parts := make([]string, 0, 10)
|
||||
|
||||
// Each of the type names provided might be found in pg_class or pg_type.
|
||||
// Additionally, it may or may not include a schema portion.
|
||||
parts = append(parts, `
|
||||
WITH RECURSIVE
|
||||
-- find the OIDs in pg_class which match one of the provided type names
|
||||
selected_classes(oid,reltype) AS (
|
||||
-- this query uses the namespace search path, so will match type names without a schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_catalog.pg_class
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
|
||||
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
|
||||
AND relname `, typeNamesClause, `
|
||||
UNION ALL
|
||||
-- this query will only match type names which include the schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_class
|
||||
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
|
||||
WHERE nspname || '.' || relname `, typeNamesClause, `
|
||||
),
|
||||
selected_types(oid) AS (
|
||||
-- collect the OIDs from pg_types which correspond to the selected classes
|
||||
SELECT reltype AS oid
|
||||
FROM selected_classes
|
||||
UNION ALL
|
||||
-- as well as any other type names which match our criteria
|
||||
SELECT pg_type.oid
|
||||
FROM pg_type
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE typname `, typeNamesClause, `
|
||||
OR nspname || '.' || typname `, typeNamesClause, `
|
||||
),
|
||||
-- this builds a parent/child mapping of objects, allowing us to know
|
||||
-- all the child (ie: dependent) types that a parent (type) requires
|
||||
-- As can be seen, there are 3 ways this can occur (the last of which
|
||||
-- is due to being a composite class, where the composite fields are children)
|
||||
pc(parent, child) AS (
|
||||
SELECT parent.oid, parent.typelem
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtype = 'b' AND parent.typelem != 0
|
||||
UNION ALL
|
||||
SELECT parent.oid, parent.typbasetype
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
|
||||
UNION ALL
|
||||
SELECT pg_type.oid, atttypid
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
),
|
||||
-- Now construct a recursive query which includes a 'depth' element.
|
||||
-- This is used to ensure that the "youngest" children are registered before
|
||||
-- their parents.
|
||||
relationships(parent, child, depth) AS (
|
||||
SELECT DISTINCT 0::OID, selected_types.oid, 0
|
||||
FROM selected_types
|
||||
UNION ALL
|
||||
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
|
||||
FROM selected_classes c
|
||||
inner join pg_type ON (c.reltype = pg_type.oid)
|
||||
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
|
||||
UNION ALL
|
||||
SELECT pc.parent, pc.child, relationships.depth + 1
|
||||
FROM pc
|
||||
INNER JOIN relationships ON (pc.parent = relationships.child)
|
||||
),
|
||||
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
|
||||
composite AS (
|
||||
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
GROUP BY pg_type.oid
|
||||
)
|
||||
-- Bring together this information, showing all the information which might possibly be required
|
||||
-- to complete the registration, applying filters to only show the items which relate to the selected
|
||||
-- types/classes.
|
||||
SELECT typname,
|
||||
pg_namespace.nspname,
|
||||
typtype,
|
||||
typbasetype,
|
||||
typelem,
|
||||
pg_type.oid,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
|
||||
} else {
|
||||
parts = append(parts, `
|
||||
0 AS rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
|
||||
attnames, atttypids
|
||||
FROM relationships
|
||||
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
|
||||
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
|
||||
}
|
||||
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN composite USING (oid)
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE NOT (typtype = 'b' AND typelem = 0)`)
|
||||
parts = append(parts, `
|
||||
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
multirange.rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
attnames, atttypids
|
||||
ORDER BY MAX(depth) desc, typname;`)
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
type derivedTypeInfo struct {
|
||||
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
|
||||
TypeName, Typtype, NspName string
|
||||
Attnames []string
|
||||
Atttypids []uint32
|
||||
}
|
||||
|
||||
// LoadTypes performs a single (complex) query, returning all the required
|
||||
// information to register the named types, as well as any other types directly
|
||||
// or indirectly required to complete the registration.
|
||||
// The result of this call can be passed into RegisterTypes to complete the process.
|
||||
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
||||
m := c.TypeMap()
|
||||
if len(typeNames) == 0 {
|
||||
return nil, fmt.Errorf("No type names were supplied.")
|
||||
}
|
||||
|
||||
// Disregard server version errors. This will result in
|
||||
// the SQL not support recent structures such as multirange
|
||||
serverVersion, _ := serverVersion(c)
|
||||
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
||||
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While generating load types query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
result := make([]*pgtype.Type, 0, 100)
|
||||
for rows.Next() {
|
||||
ti := derivedTypeInfo{}
|
||||
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While scanning type information: %w", err)
|
||||
}
|
||||
var type_ *pgtype.Type
|
||||
switch ti.Typtype {
|
||||
case "b": // array
|
||||
dt, ok := m.TypeForOID(ti.Typelem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
|
||||
}
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
|
||||
case "c": // composite
|
||||
var fields []pgtype.CompositeCodecField
|
||||
for i, fieldName := range ti.Attnames {
|
||||
dt, ok := m.TypeForOID(ti.Atttypids[i])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
|
||||
}
|
||||
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
|
||||
case "d": // domain
|
||||
dt, ok := m.TypeForOID(ti.Typbasetype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
|
||||
case "e": // enum
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
|
||||
case "r": // range
|
||||
dt, ok := m.TypeForOID(ti.Rngsubtype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
|
||||
case "m": // multirange
|
||||
dt, ok := m.TypeForOID(ti.Rngtypid)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
||||
}
|
||||
|
||||
// the type_ is imposible to be null
|
||||
m.RegisterType(type_)
|
||||
if ti.NspName != "" {
|
||||
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
||||
m.RegisterType(nspType)
|
||||
result = append(result, nspType)
|
||||
}
|
||||
result = append(result, type_)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// serverVersion returns the postgresql server version.
|
||||
func serverVersion(c *Conn) (int64, error) {
|
||||
serverVersionStr := c.PgConn().ParameterStatus("server_version")
|
||||
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
|
||||
// if not PostgreSQL do nothing
|
||||
if serverVersionStr == "" {
|
||||
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
|
||||
}
|
||||
|
||||
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
|
||||
}
|
||||
return version, nil
|
||||
}
|
40
derived_types_test.go
Normal file
40
derived_types_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(ctx, `
|
||||
drop type if exists dtype_test;
|
||||
drop domain if exists anotheruint64;
|
||||
|
||||
create domain anotheruint64 as numeric(20,0);
|
||||
create type dtype_test as (
|
||||
a text,
|
||||
b int4,
|
||||
c anotheruint64,
|
||||
d anotheruint64[]
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(ctx, "drop type dtype_test")
|
||||
defer conn.Exec(ctx, "drop domain anotheruint64")
|
||||
|
||||
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, types, 6)
|
||||
require.Equal(t, types[0].Name, "public.anotheruint64")
|
||||
require.Equal(t, types[1].Name, "anotheruint64")
|
||||
require.Equal(t, types[2].Name, "public._anotheruint64")
|
||||
require.Equal(t, types[3].Name, "_anotheruint64")
|
||||
require.Equal(t, types[4].Name, "public.dtype_test")
|
||||
require.Equal(t, types[5].Name, "dtype_test")
|
||||
})
|
||||
}
|
39
doc.go
39
doc.go
@ -7,24 +7,25 @@ details.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
The primary way of establishing a connection is with `pgx.Connect`.
|
||||
The primary way of establishing a connection is with [pgx.Connect]:
|
||||
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
|
||||
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
|
||||
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
|
||||
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
|
||||
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
|
||||
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
|
||||
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
|
||||
string.
|
||||
|
||||
Connection Pool
|
||||
|
||||
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package
|
||||
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
|
||||
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
|
||||
|
||||
Query Interface
|
||||
|
||||
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
|
||||
ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
|
||||
rows.Err().
|
||||
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
|
||||
rows.Scan, and rows.Err().
|
||||
|
||||
CollectRows can be used collect all returned rows into a slice.
|
||||
|
||||
@ -40,7 +41,7 @@ directly.
|
||||
|
||||
var sum, n int32
|
||||
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
_, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error {
|
||||
_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
|
||||
sum += n
|
||||
return nil
|
||||
})
|
||||
@ -69,8 +70,9 @@ Use Exec to execute a query that does not return a result set.
|
||||
|
||||
PostgreSQL Data Types
|
||||
|
||||
The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values
|
||||
including array and composite types. See that package's documentation for details.
|
||||
pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
|
||||
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
|
||||
require type registration. See that package's documentation for details.
|
||||
|
||||
Transactions
|
||||
|
||||
@ -97,7 +99,8 @@ Transactions are started by calling Begin.
|
||||
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
|
||||
These are internally implemented with savepoints.
|
||||
|
||||
Use BeginTx to control the transaction mode.
|
||||
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
|
||||
a pseudo nested transaction.
|
||||
|
||||
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
|
||||
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
||||
@ -160,17 +163,19 @@ notification is received or the context is canceled.
|
||||
|
||||
_, err := conn.Exec(context.Background(), "listen channelname")
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
if notification, err := conn.WaitForNotification(context.Background()); err != nil {
|
||||
// do something with notification
|
||||
notification, err := conn.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// do something with notification
|
||||
|
||||
|
||||
Tracing and Logging
|
||||
|
||||
pgx supports tracing by setting ConnConfig.Tracer.
|
||||
pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.
|
||||
|
||||
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
|
||||
|
||||
@ -178,12 +183,12 @@ For debug tracing of the actual PostgreSQL wire protocol messages see github.com
|
||||
|
||||
Lower Level PostgreSQL Functionality
|
||||
|
||||
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in
|
||||
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
|
||||
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
|
||||
|
||||
PgBouncer
|
||||
|
||||
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
|
||||
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
|
||||
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
|
||||
*/
|
||||
package pgx
|
||||
|
@ -2,7 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -29,7 +29,7 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
func putUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
id := req.URL.Path
|
||||
var url string
|
||||
if body, err := ioutil.ReadAll(req.Body); err == nil {
|
||||
if body, err := io.ReadAll(req.Body); err == nil {
|
||||
url = string(body)
|
||||
} else {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
|
@ -3,7 +3,6 @@ package pgx
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
@ -22,10 +21,15 @@ type ExtendedQueryBuilder struct {
|
||||
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
|
||||
eqb.reset()
|
||||
|
||||
anynil.NormalizeSlice(args)
|
||||
|
||||
if sd == nil {
|
||||
return eqb.appendParamsForQueryExecModeExec(m, args)
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
@ -35,7 +39,7 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
|
||||
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -51,14 +55,33 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
|
||||
// must be an untyped nil.
|
||||
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
|
||||
if format == -1 {
|
||||
format = eqb.chooseParameterFormatCode(m, oid, arg)
|
||||
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
|
||||
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
|
||||
if preferredErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var otherFormat int16
|
||||
if preferredFormat == TextFormatCode {
|
||||
otherFormat = BinaryFormatCode
|
||||
} else {
|
||||
otherFormat = TextFormatCode
|
||||
}
|
||||
|
||||
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
|
||||
if otherErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return preferredErr // return the error from the preferred format
|
||||
}
|
||||
eqb.ParamFormats = append(eqb.ParamFormats, format)
|
||||
|
||||
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eqb.ParamFormats = append(eqb.ParamFormats, format)
|
||||
eqb.ParamValues = append(eqb.ParamValues, v)
|
||||
|
||||
return nil
|
||||
@ -93,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
|
||||
}
|
||||
|
||||
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
|
||||
if anynil.Is(arg) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if eqb.paramValueBytes == nil {
|
||||
eqb.paramValueBytes = make([]byte, 0, 128)
|
||||
}
|
||||
@ -125,61 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
|
||||
|
||||
return m.FormatCodeForOID(oid)
|
||||
}
|
||||
|
||||
// appendParamsForQueryExecModeExec appends the args to eqb.
|
||||
//
|
||||
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
|
||||
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
|
||||
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
|
||||
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
|
||||
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
|
||||
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
|
||||
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
|
||||
// type conversion it takes the date directly and ignores time zone (i.e. it works).
|
||||
//
|
||||
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
|
||||
// no way to safely use binary or to specify the parameter OIDs.
|
||||
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
err := eqb.appendParam(m, 0, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
dt, ok := m.TypeForValue(arg)
|
||||
if !ok {
|
||||
var tv pgtype.TextValuer
|
||||
if tv, ok = arg.(pgtype.TextValuer); ok {
|
||||
t, err := tv.TextValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = t
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
var str fmt.Stringer
|
||||
if str, ok = arg.(fmt.Stringer); ok {
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = str.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
|
||||
}
|
||||
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
17
go.mod
17
go.mod
@ -1,20 +1,21 @@
|
||||
module github.com/jackc/pgx/v5
|
||||
|
||||
go 1.18
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
|
||||
github.com/jackc/puddle/v2 v2.0.0-beta.1
|
||||
github.com/stretchr/testify v1.8.0
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
|
||||
golang.org/x/text v0.3.7
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
|
||||
github.com/jackc/puddle/v2 v2.2.2
|
||||
github.com/stretchr/testify v1.8.1
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/text v0.24.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
41
go.sum
41
go.sum
@ -1,36 +1,45 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||
github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8=
|
||||
github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4=
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -79,7 +79,7 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
t.Fatalf("conn.Query failed: %v", rows.Err())
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
|
@ -1,36 +0,0 @@
|
||||
package anynil
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
|
||||
func Is(value any) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(value)
|
||||
switch refVal.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||
return refVal.IsNil()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
|
||||
func Normalize(v any) any {
|
||||
if Is(v) {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
|
||||
// mutated in place.
|
||||
func NormalizeSlice(s []any) {
|
||||
for i := range s {
|
||||
if Is(s[i]) {
|
||||
s[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
@ -1,4 +1,7 @@
|
||||
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
|
||||
//
|
||||
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
|
||||
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
|
||||
package iobufpool
|
||||
|
||||
import "sync"
|
||||
@ -10,17 +13,27 @@ var pools [18]*sync.Pool
|
||||
func init() {
|
||||
for i := range pools {
|
||||
bufLen := 1 << (minPoolExpOf2 + i)
|
||||
pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }}
|
||||
pools[i] = &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufLen)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets a []byte of len size with cap <= size*2.
|
||||
func Get(size int) []byte {
|
||||
func Get(size int) *[]byte {
|
||||
i := getPoolIdx(size)
|
||||
if i >= len(pools) {
|
||||
return make([]byte, size)
|
||||
buf := make([]byte, size)
|
||||
return &buf
|
||||
}
|
||||
return pools[i].Get().([]byte)[:size]
|
||||
|
||||
ptrBuf := (pools[i].Get().(*[]byte))
|
||||
*ptrBuf = (*ptrBuf)[:size]
|
||||
|
||||
return ptrBuf
|
||||
}
|
||||
|
||||
func getPoolIdx(size int) int {
|
||||
@ -36,8 +49,8 @@ func getPoolIdx(size int) int {
|
||||
}
|
||||
|
||||
// Put returns buf to the pool.
|
||||
func Put(buf []byte) {
|
||||
i := putPoolIdx(cap(buf))
|
||||
func Put(buf *[]byte) {
|
||||
i := putPoolIdx(cap(*buf))
|
||||
if i < 0 {
|
||||
return
|
||||
}
|
||||
|
@ -30,15 +30,15 @@ func TestGetCap(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
buf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(*buf), "bad len for requestedLen: %d", len(*buf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(*buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
|
||||
putBuf := make([]byte, putBufSize)
|
||||
iobufpool.Put(putBuf)
|
||||
iobufpool.Put(&putBuf)
|
||||
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
@ -62,8 +62,8 @@ func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
getBuf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(*getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(*getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -73,10 +73,10 @@ func TestPutGetBufferReuse(t *testing.T) {
|
||||
// it not to be. So try many times.
|
||||
for i := 0; i < 100000; i++ {
|
||||
buf := iobufpool.Get(4)
|
||||
buf[0] = 1
|
||||
(*buf)[0] = 1
|
||||
iobufpool.Put(buf)
|
||||
buf = iobufpool.Get(4)
|
||||
if buf[0] == 1 {
|
||||
if (*buf)[0] == 1 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -1,70 +0,0 @@
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const minBufferQueueLen = 8
|
||||
|
||||
type bufferQueue struct {
|
||||
lock sync.Mutex
|
||||
queue [][]byte
|
||||
r, w int
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushBack(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
bq.queue[bq.w] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushFront(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
|
||||
bq.queue[bq.r] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) popFront() []byte {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.r == bq.w {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := bq.queue[bq.r]
|
||||
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
|
||||
bq.r++
|
||||
|
||||
if bq.r == bq.w {
|
||||
bq.r = 0
|
||||
bq.w = 0
|
||||
if len(bq.queue) > minBufferQueueLen {
|
||||
bq.queue = make([][]byte, minBufferQueueLen)
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) growQueue() {
|
||||
desiredLen := (len(bq.queue) + 1) * 3 / 2
|
||||
if desiredLen < minBufferQueueLen {
|
||||
desiredLen = minBufferQueueLen
|
||||
}
|
||||
|
||||
newQueue := make([][]byte, desiredLen)
|
||||
copy(newQueue, bq.queue)
|
||||
bq.queue = newQueue
|
||||
}
|
@ -1,536 +0,0 @@
|
||||
// Package nbconn implements a non-blocking net.Conn wrapper.
|
||||
//
|
||||
// It is designed to solve three problems.
|
||||
//
|
||||
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||
//
|
||||
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
var ErrWouldBlock = new(wouldBlockError)
|
||||
|
||||
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||
|
||||
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
|
||||
// mode.
|
||||
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
|
||||
|
||||
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
|
||||
// ignore all future calls.
|
||||
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
|
||||
|
||||
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
|
||||
type wouldBlockError struct{}
|
||||
|
||||
func (*wouldBlockError) Error() string {
|
||||
return "would block"
|
||||
}
|
||||
|
||||
func (*wouldBlockError) Timeout() bool { return true }
|
||||
func (*wouldBlockError) Temporary() bool { return true }
|
||||
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
|
||||
// the underlying connection.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
|
||||
// Flush flushes any buffered writes.
|
||||
Flush() error
|
||||
|
||||
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
|
||||
BufferReadUntilBlock() error
|
||||
}
|
||||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
type NetConn struct {
|
||||
conn net.Conn
|
||||
rawConn syscall.RawConn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
||||
readFlushLock sync.Mutex
|
||||
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
|
||||
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
|
||||
nonblockWriteBuf []byte
|
||||
nonblockWriteErr error
|
||||
nonblockWriteN int
|
||||
|
||||
readDeadlineLock sync.Mutex
|
||||
readDeadline time.Time
|
||||
readNonblocking bool
|
||||
|
||||
writeDeadlineLock sync.Mutex
|
||||
writeDeadline time.Time
|
||||
|
||||
// Only access with atomics
|
||||
closed int64 // 0 = not closed, 1 = closed
|
||||
}
|
||||
|
||||
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||
nc := &NetConn{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
if !fakeNonBlockingIO {
|
||||
if sc, ok := conn.(syscall.Conn); ok {
|
||||
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||
nc.rawConn = rawConn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *NetConn) Read(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for n < len(b) {
|
||||
buf := c.readQueue.popFront()
|
||||
if buf == nil {
|
||||
break
|
||||
}
|
||||
copiedN := copy(b[n:], buf)
|
||||
if copiedN < len(buf) {
|
||||
buf = buf[copiedN:]
|
||||
c.readQueue.pushFront(buf)
|
||||
} else {
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
n += copiedN
|
||||
}
|
||||
|
||||
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
|
||||
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var readNonblocking bool
|
||||
c.readDeadlineLock.Lock()
|
||||
readNonblocking = c.readNonblocking
|
||||
c.readDeadlineLock.Unlock()
|
||||
|
||||
var readN int
|
||||
if readNonblocking {
|
||||
readN, err = c.nonblockingRead(b[n:])
|
||||
} else {
|
||||
readN, err = c.conn.Read(b[n:])
|
||||
}
|
||||
n += readN
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||
// closed. Call Flush to actually write to the underlying connection.
|
||||
func (c *NetConn) Write(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
buf := iobufpool.Get(len(b))
|
||||
copy(buf, b)
|
||||
c.writeQueue.pushBack(buf)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *NetConn) Close() (err error) {
|
||||
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
|
||||
if !swapped {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := c.conn.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *NetConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||
func (c *NetConn) SetDeadline(t time.Time) error {
|
||||
err := c.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||
func (c *NetConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
if c.readDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.readDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
if t == NonBlockingDeadline {
|
||||
c.readNonblocking = true
|
||||
t = time.Time{}
|
||||
} else {
|
||||
c.readNonblocking = false
|
||||
}
|
||||
|
||||
c.readDeadline = t
|
||||
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
if c.writeDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
c.writeDeadline = t
|
||||
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) Flush() error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
return c.flush()
|
||||
}
|
||||
|
||||
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||
func (c *NetConn) flush() error {
|
||||
var stopChan chan struct{}
|
||||
var errChan chan error
|
||||
|
||||
defer func() {
|
||||
if stopChan != nil {
|
||||
select {
|
||||
case stopChan <- struct{}{}:
|
||||
case <-errChan:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
|
||||
remainingBuf := buf
|
||||
for len(remainingBuf) > 0 {
|
||||
n, err := c.nonblockingWrite(remainingBuf)
|
||||
remainingBuf = remainingBuf[n:]
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
buf = buf[:len(remainingBuf)]
|
||||
copy(buf, remainingBuf)
|
||||
c.writeQueue.pushFront(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Writing was blocked. Reading might unblock it.
|
||||
if stopChan == nil {
|
||||
stopChan, errChan = c.bufferNonblockingRead()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
stopChan = nil
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) BufferReadUntilBlock() error {
|
||||
for {
|
||||
buf := iobufpool.Get(8 * 1024)
|
||||
n, err := c.nonblockingRead(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
c.readQueue.pushBack(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrWouldBlock) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
stopChan = make(chan struct{})
|
||||
errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := c.BufferReadUntilBlock()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stopChan, errChan
|
||||
}
|
||||
|
||||
func (c *NetConn) isClosed() bool {
|
||||
closed := atomic.LoadInt64(&c.closed)
|
||||
return closed == 1
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
} else {
|
||||
return c.realNonblockingWrite(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||
err = c.conn.SetWriteDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetWriteDeadline(c.writeDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
|
||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.nonblockWriteBuf = b
|
||||
c.nonblockWriteN = 0
|
||||
c.nonblockWriteErr = nil
|
||||
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
|
||||
return true
|
||||
})
|
||||
n = c.nonblockWriteN
|
||||
if err == nil && c.nonblockWriteErr != nil {
|
||||
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = c.nonblockWriteErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingRead(b)
|
||||
} else {
|
||||
return c.realNonblockingRead(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
|
||||
err = c.conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetReadDeadline(c.readDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
var funcErr error
|
||||
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||
n, funcErr = syscall.Read(int(fd), b)
|
||||
return true
|
||||
})
|
||||
if err == nil && funcErr != nil {
|
||||
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = funcErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// syscall read did not return an error and 0 bytes were read means EOF.
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// syscall.Conn is interface
|
||||
|
||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||
//
|
||||
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
|
||||
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
|
||||
// *TLSConn is returned.
|
||||
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
|
||||
tc := tls.Client(conn, config)
|
||||
err := tc.Handshake()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure last written part of Handshake is actually sent.
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TLSConn{
|
||||
tlsConn: tc,
|
||||
nbConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
|
||||
// tls.Conn.
|
||||
type TLSConn struct {
|
||||
tlsConn *tls.Conn
|
||||
nbConn *NetConn
|
||||
}
|
||||
|
||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
|
||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||
|
||||
func (tc *TLSConn) Close() error {
|
||||
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
|
||||
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
|
||||
// own 5 second deadline then make all set deadlines no-op.
|
||||
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
|
||||
|
||||
return tc.tlsConn.Close()
|
||||
}
|
||||
|
||||
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
|
||||
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
|
||||
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }
|
@ -1,584 +0,0 @@
|
||||
package nbconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test keys generated with:
|
||||
//
|
||||
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
|
||||
|
||||
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
|
||||
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
|
||||
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
|
||||
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
|
||||
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
|
||||
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
|
||||
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
|
||||
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
|
||||
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
|
||||
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
|
||||
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
|
||||
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
|
||||
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
|
||||
E7ZYmaKTMOhvkg==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
|
||||
// source code.
|
||||
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
|
||||
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
|
||||
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
|
||||
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
|
||||
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
|
||||
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
|
||||
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
|
||||
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
|
||||
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
|
||||
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
|
||||
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
|
||||
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
|
||||
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
|
||||
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
|
||||
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
|
||||
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
|
||||
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
|
||||
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
|
||||
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
|
||||
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
|
||||
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
|
||||
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
|
||||
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
|
||||
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
|
||||
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
||||
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
useTLS bool
|
||||
fakeNonBlockingIO bool
|
||||
}{
|
||||
{
|
||||
name: "Pipe",
|
||||
makeConns: makePipeConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, remote := tt.makeConns(t)
|
||||
|
||||
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||
// uses remote it may be garbage collected leading to the connection being closed.
|
||||
defer local.Close()
|
||||
defer remote.Close()
|
||||
|
||||
var conn nbconn.Conn
|
||||
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||
|
||||
if tt.useTLS {
|
||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsServer := tls.Server(remote, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
serverTLSHandshakeChan := make(chan error)
|
||||
go func() {
|
||||
err := tlsServer.Handshake()
|
||||
serverTLSHandshakeChan <- err
|
||||
}()
|
||||
|
||||
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
|
||||
require.NoError(t, err)
|
||||
conn = tlsConn
|
||||
|
||||
err = <-serverTLSHandshakeChan
|
||||
require.NoError(t, err)
|
||||
remote = tlsServer
|
||||
} else {
|
||||
conn = netConn
|
||||
}
|
||||
|
||||
f(t, conn, remote)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
|
||||
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
|
||||
func makePipeConns(t *testing.T) (local, remote net.Conn) {
|
||||
local, remote = net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
local.Close()
|
||||
remote.Close()
|
||||
})
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
|
||||
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type acceptResultT struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
acceptChan := make(chan acceptResultT)
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
acceptChan <- acceptResultT{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
local, err = net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
acceptResult := <-acceptChan
|
||||
require.NoError(t, acceptResult.err)
|
||||
|
||||
remote = acceptResult.conn
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
func TestWriteIsBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := conn.Flush()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err = remote.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetWriteDeadline(time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
|
||||
_, err = remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("okay"), readBuf)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||
// large values.
|
||||
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
remoteWriteBuf := make([]byte, deadlockSize)
|
||||
_, err := remote.Write(remoteWriteBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Flush()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "i/o timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
|
||||
require.EqualValues(t, 0, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err = conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBufferNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.BufferReadUntilBlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
err = conn.BufferReadUntilBlock()
|
||||
if !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
require.Equal(t, []byte("okay"), buf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 5)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 10)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf[:n])
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 2)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, n)
|
||||
require.Equal(t, []byte("al"), readBuf)
|
||||
|
||||
readBuf = make([]byte, 3)
|
||||
n, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("pha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
flushCompleteChan := make(chan struct{})
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-flushCompleteChan
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
close(flushCompleteChan)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
|
||||
err = <-errChan
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
@ -23,7 +23,7 @@ func TestScript(t *testing.T) {
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
pgproto3.FieldDescription{
|
||||
{
|
||||
Name: []byte("?column?"),
|
||||
TableOID: 0,
|
||||
TableAttributeNumber: 0,
|
||||
@ -69,9 +69,7 @@ func TestScript(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
|
60
internal/sanitize/benchmmark.sh
Normal file
60
internal/sanitize/benchmmark.sh
Normal file
@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [ "$current_branch" == "HEAD" ]; then
|
||||
current_branch=$(git rev-parse HEAD)
|
||||
fi
|
||||
|
||||
restore_branch() {
|
||||
echo "Restoring original branch/commit: $current_branch"
|
||||
git checkout "$current_branch"
|
||||
}
|
||||
trap restore_branch EXIT
|
||||
|
||||
# Check if there are uncommitted changes
|
||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure that at least one commit argument is passed
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
commits=("$@")
|
||||
benchmarks_dir=benchmarks
|
||||
|
||||
if ! mkdir -p "${benchmarks_dir}"; then
|
||||
echo "Unable to create dir for benchmarks data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Benchmark results
|
||||
bench_files=()
|
||||
|
||||
# Run benchmark for each listed commit
|
||||
for i in "${!commits[@]}"; do
|
||||
commit="${commits[i]}"
|
||||
git checkout "$commit" || {
|
||||
echo "Failed to checkout $commit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Sanitized commmit message
|
||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||
|
||||
# Benchmark data will go there
|
||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||
|
||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||
echo "Benchmarking failed for commit $commit"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
bench_files+=("$bench_file")
|
||||
done
|
||||
|
||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||
benchstat "${bench_files[@]}"
|
@ -4,8 +4,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@ -18,44 +20,81 @@ type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
|
||||
// character. utf8.RuneError is not an error if it is also width 3.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/1380
|
||||
const replacementcharacterwidth = 3
|
||||
|
||||
const maxBufSize = 16384 // 16 Ki
|
||||
|
||||
var bufPool = &pool[*bytes.Buffer]{
|
||||
new: func() *bytes.Buffer {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
reset: func(b *bytes.Buffer) bool {
|
||||
n := b.Len()
|
||||
b.Reset()
|
||||
return n < maxBufSize
|
||||
},
|
||||
}
|
||||
|
||||
var null = []byte("null")
|
||||
|
||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
buf := bufPool.get()
|
||||
defer bufPool.put(buf)
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
buf.WriteString(part)
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
var p []byte
|
||||
if argIdx < 0 {
|
||||
return "", fmt.Errorf("first sql argument must be > 0")
|
||||
}
|
||||
|
||||
if argIdx >= len(args) {
|
||||
return "", fmt.Errorf("insufficient arguments")
|
||||
}
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
p = null
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||
case time.Time:
|
||||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
p = arg.Truncate(time.Microsecond).
|
||||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
|
||||
buf.Write(p)
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
@ -67,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
query := &Query{}
|
||||
query.init(sql)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var sqlLexerPool = &pool[*sqlLexer]{
|
||||
new: func() *sqlLexer {
|
||||
return &sqlLexer{}
|
||||
},
|
||||
reset: func(sl *sqlLexer) bool {
|
||||
*sl = sqlLexer{}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func (q *Query) init(sql string) {
|
||||
parts := q.Parts[:0]
|
||||
if parts == nil {
|
||||
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||
parts = make([]Part, 0, n)
|
||||
}
|
||||
|
||||
l := sqlLexerPool.get()
|
||||
defer sqlLexerPool.put(l)
|
||||
|
||||
l.src = sql
|
||||
l.stateFn = rawState
|
||||
l.parts = parts
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
q.Parts = l.parts
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
func QuoteString(dst []byte, str string) []byte {
|
||||
const quote = '\''
|
||||
|
||||
// Preallocate space for the worst case scenario
|
||||
dst = slices.Grow(dst, len(str)*2+2)
|
||||
|
||||
// Add opening quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
// Iterate through the string without allocating
|
||||
for i := 0; i < len(str); i++ {
|
||||
if str[i] == quote {
|
||||
dst = append(dst, quote, quote)
|
||||
} else {
|
||||
dst = append(dst, str[i])
|
||||
}
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
// Add closing quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func QuoteBytes(dst, buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return append(dst, `'\x'`...)
|
||||
}
|
||||
|
||||
// Calculate required length
|
||||
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||
|
||||
// Ensure dst has enough capacity
|
||||
if cap(dst)-len(dst) < requiredLen {
|
||||
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||
copy(newDst, dst)
|
||||
dst = newDst
|
||||
}
|
||||
|
||||
// Record original length and extend slice
|
||||
origLen := len(dst)
|
||||
dst = dst[:origLen+requiredLen]
|
||||
|
||||
// Add prefix
|
||||
dst[origLen] = '\''
|
||||
dst[origLen+1] = '\\'
|
||||
dst[origLen+2] = 'x'
|
||||
|
||||
// Encode bytes directly into dst
|
||||
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||
|
||||
// Add suffix
|
||||
dst[len(dst)-1] = '\''
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
@ -138,6 +250,7 @@ func rawState(l *sqlLexer) stateFn {
|
||||
return multilineCommentState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -146,6 +259,7 @@ func rawState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
@ -160,6 +274,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -168,6 +283,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
@ -182,6 +298,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -190,6 +307,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
@ -228,6 +346,7 @@ func escapeStringState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -236,6 +355,7 @@ func escapeStringState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
@ -249,6 +369,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
case '\n', '\r':
|
||||
return rawState
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -257,6 +378,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func multilineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
@ -283,6 +405,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
||||
l.nested--
|
||||
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
@ -291,14 +414,47 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var queryPool = &pool[*Query]{
|
||||
new: func() *Query {
|
||||
return &Query{}
|
||||
},
|
||||
reset: func(q *Query) bool {
|
||||
n := len(q.Parts)
|
||||
q.Parts = q.Parts[:0]
|
||||
return n < 64 // drop too large queries
|
||||
},
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := queryPool.get()
|
||||
query.init(sql)
|
||||
defer queryPool.put(query)
|
||||
|
||||
return query.Sanitize(args...)
|
||||
}
|
||||
|
||||
type pool[E any] struct {
|
||||
p sync.Pool
|
||||
new func() E
|
||||
reset func(E) bool
|
||||
}
|
||||
|
||||
func (pool *pool[E]) get() E {
|
||||
v, ok := pool.p.Get().(E)
|
||||
if !ok {
|
||||
v = pool.new()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *pool[E]) put(v E) {
|
||||
if p.reset(v) {
|
||||
p.p.Put(v)
|
||||
}
|
||||
}
|
||||
|
62
internal/sanitize/sanitize_bench_test.go
Normal file
62
internal/sanitize/sanitize_bench_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
// sanitize_benchmark_test.go
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
var benchmarkSanitizeResult string
|
||||
|
||||
const benchmarkQuery = "" +
|
||||
`SELECT *
|
||||
FROM "water_containers"
|
||||
WHERE NOT "id" = $1 -- int64
|
||||
AND "tags" NOT IN $2 -- nil
|
||||
AND "volume" > $3 -- float64
|
||||
AND "transportable" = $4 -- bool
|
||||
AND position($5 IN "sign") -- bytes
|
||||
AND "label" LIKE $6 -- string
|
||||
AND "created_at" > $7; -- time.Time`
|
||||
|
||||
var benchmarkArgs = []any{
|
||||
int64(12345),
|
||||
nil,
|
||||
float64(500),
|
||||
true,
|
||||
[]byte("8BADF00D"),
|
||||
"kombucha's han'dy awokowa",
|
||||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
func BenchmarkSanitize(b *testing.B) {
|
||||
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create query: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize query: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var benchmarkNewSQLResult string
|
||||
|
||||
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
55
internal/sanitize/sanitize_fuzz_test.go
Normal file
55
internal/sanitize/sanitize_fuzz_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
func FuzzQuoteString(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add("new\nline")
|
||||
f.Add("sample text")
|
||||
f.Add("sample q'u'o't'e's")
|
||||
f.Add("select 'quoted $42', $1")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzQuoteBytes(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add([]byte(nil))
|
||||
f.Add([]byte("\n"))
|
||||
f.Add([]byte("sample text"))
|
||||
f.Add([]byte("sample q'u'o't'e's"))
|
||||
f.Add([]byte("select 'quoted $42', $1"))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input []byte) {
|
||||
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -88,6 +90,16 @@ func TestNewQuery(t *testing.T) {
|
||||
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
|
||||
},
|
||||
{
|
||||
// https://github.com/jackc/pgx/issues/1380
|
||||
sql: "select 'hello w<>rld'",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w<>rld'"}},
|
||||
},
|
||||
{
|
||||
// Unterminated quoted string
|
||||
sql: "select 'hello world",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
@ -164,6 +176,16 @@ func TestQuerySanitize(t *testing.T) {
|
||||
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||
expected: `insert '2020-03-01 23:59:59.999999Z' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{int64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{float64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
@ -207,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
tc := func(name, input string) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteString(nil, input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("empty", "")
|
||||
tc("text", "abcd")
|
||||
tc("with quotes", `one's hat is always a cat`)
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
}
|
||||
|
||||
func TestQuoteBytes(t *testing.T) {
|
||||
tc := func(name string, input []byte) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteBytes(nil, input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("nil", nil)
|
||||
tc("empty", []byte{})
|
||||
tc("text", []byte("abcd"))
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
@ -34,7 +34,8 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
|
||||
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
|
||||
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
@ -44,6 +45,13 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||
return
|
||||
}
|
||||
|
||||
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
|
||||
for _, invalidSD := range c.invalidStmts {
|
||||
if invalidSD.SQL == sd.SQL {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
c.invalidateOldest()
|
||||
}
|
||||
@ -73,10 +81,16 @@ func (c *LRUCache) InvalidateAll() {
|
||||
c.l = list.New()
|
||||
}
|
||||
|
||||
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *LRUCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
|
@ -2,18 +2,17 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var stmtCounter int64
|
||||
|
||||
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
|
||||
func NextStatementName() string {
|
||||
n := atomic.AddInt64(&stmtCounter, 1)
|
||||
return "stmtcache_" + strconv.FormatInt(n, 10)
|
||||
// StatementName returns a statement name that will be stable for sql across multiple connections and program
|
||||
// executions.
|
||||
func StatementName(sql string) string {
|
||||
digest := sha256.Sum256([]byte(sql))
|
||||
return "stmtcache_" + hex.EncodeToString(digest[0:24])
|
||||
}
|
||||
|
||||
// Cache caches statement descriptions.
|
||||
@ -30,8 +29,13 @@ type Cache interface {
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
InvalidateAll()
|
||||
|
||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||
HandleInvalidated() []*pgconn.StatementDescription
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
GetInvalidated() []*pgconn.StatementDescription
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
RemoveInvalidated()
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
@ -39,19 +43,3 @@ type Cache interface {
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
}
|
||||
|
||||
func IsStatementInvalid(err error) bool {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1162
|
||||
//
|
||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||
// have so it should be safe.
|
||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||
return possibleInvalidCachedPlanError
|
||||
}
|
||||
|
@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
|
||||
c.m = make(map[string]*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *UnlimitedCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
|
@ -4,8 +4,15 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
|
||||
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
|
||||
// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
|
||||
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
|
||||
|
||||
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
|
||||
// was created.
|
||||
//
|
||||
@ -68,32 +75,65 @@ type LargeObject struct {
|
||||
|
||||
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
|
||||
func (o *LargeObject) Write(p []byte) (int, error) {
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
var n int
|
||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
|
||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return nTotal, err
|
||||
}
|
||||
|
||||
if n < 0 {
|
||||
return 0, errors.New("failed to write to large object")
|
||||
return nTotal, errors.New("failed to write to large object")
|
||||
}
|
||||
|
||||
return n, nil
|
||||
nTotal += n
|
||||
|
||||
if n < expected {
|
||||
return nTotal, errors.New("short write to large object")
|
||||
} else if n > expected {
|
||||
return nTotal, errors.New("invalid write to large object")
|
||||
}
|
||||
}
|
||||
|
||||
return nTotal, nil
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p returning the number of bytes read.
|
||||
func (o *LargeObject) Read(p []byte) (int, error) {
|
||||
var res []byte
|
||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
|
||||
copy(p, res)
|
||||
if err != nil {
|
||||
return len(res), err
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
if len(res) < len(p) {
|
||||
err = io.EOF
|
||||
res := pgtype.PreallocBytes(p[nTotal:])
|
||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
|
||||
// We compute expected so that it always fits into p, so it should never happen
|
||||
// that PreallocBytes's ScanBytes had to allocate a new slice.
|
||||
nTotal += len(res)
|
||||
if err != nil {
|
||||
return nTotal, err
|
||||
}
|
||||
return len(res), err
|
||||
|
||||
if len(res) < expected {
|
||||
return nTotal, io.EOF
|
||||
} else if len(res) > expected {
|
||||
return nTotal, errors.New("invalid read of large object")
|
||||
}
|
||||
}
|
||||
|
||||
return nTotal, nil
|
||||
}
|
||||
|
||||
// Seek moves the current location pointer to the new location specified by offset.
|
||||
|
20
large_objects_private_test.go
Normal file
20
large_objects_private_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
|
||||
// to the given length for the duration of the test.
|
||||
//
|
||||
// Tests using this helper should not use t.Parallel().
|
||||
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
|
||||
t.Helper()
|
||||
|
||||
original := maxLargeObjectMessageLength
|
||||
t.Cleanup(func() {
|
||||
maxLargeObjectMessageLength = original
|
||||
})
|
||||
|
||||
maxLargeObjectMessageLength = length
|
||||
}
|
@ -13,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
func TestLargeObjects(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
@ -34,9 +35,10 @@ func TestLargeObjects(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
@ -160,9 +162,10 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
|
||||
}
|
||||
|
||||
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
152
multitracer/tracer.go
Normal file
152
multitracer/tracer.go
Normal file
@ -0,0 +1,152 @@
|
||||
// Package multitracer provides a Tracer that can combine several tracers into one.
|
||||
package multitracer
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Tracer can combine several tracers into one.
|
||||
// You can use New to automatically split tracers by interface.
|
||||
type Tracer struct {
|
||||
QueryTracers []pgx.QueryTracer
|
||||
BatchTracers []pgx.BatchTracer
|
||||
CopyFromTracers []pgx.CopyFromTracer
|
||||
PrepareTracers []pgx.PrepareTracer
|
||||
ConnectTracers []pgx.ConnectTracer
|
||||
PoolAcquireTracers []pgxpool.AcquireTracer
|
||||
PoolReleaseTracers []pgxpool.ReleaseTracer
|
||||
}
|
||||
|
||||
// New returns new Tracer from tracers with automatically split tracers by interface.
|
||||
func New(tracers ...pgx.QueryTracer) *Tracer {
|
||||
var t Tracer
|
||||
|
||||
for _, tracer := range tracers {
|
||||
t.QueryTracers = append(t.QueryTracers, tracer)
|
||||
|
||||
if batchTracer, ok := tracer.(pgx.BatchTracer); ok {
|
||||
t.BatchTracers = append(t.BatchTracers, batchTracer)
|
||||
}
|
||||
|
||||
if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok {
|
||||
t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer)
|
||||
}
|
||||
|
||||
if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok {
|
||||
t.PrepareTracers = append(t.PrepareTracers, prepareTracer)
|
||||
}
|
||||
|
||||
if connectTracer, ok := tracer.(pgx.ConnectTracer); ok {
|
||||
t.ConnectTracers = append(t.ConnectTracers, connectTracer)
|
||||
}
|
||||
|
||||
if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok {
|
||||
t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer)
|
||||
}
|
||||
|
||||
if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok {
|
||||
t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer)
|
||||
}
|
||||
}
|
||||
|
||||
return &t
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
for _, tracer := range t.QueryTracers {
|
||||
ctx = tracer.TraceQueryStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
for _, tracer := range t.QueryTracers {
|
||||
tracer.TraceQueryEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
ctx = tracer.TraceBatchStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
tracer.TraceBatchQuery(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
tracer.TraceBatchEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
for _, tracer := range t.CopyFromTracers {
|
||||
ctx = tracer.TraceCopyFromStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
for _, tracer := range t.CopyFromTracers {
|
||||
tracer.TraceCopyFromEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
|
||||
for _, tracer := range t.PrepareTracers {
|
||||
ctx = tracer.TracePrepareStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
|
||||
for _, tracer := range t.PrepareTracers {
|
||||
tracer.TracePrepareEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
|
||||
for _, tracer := range t.ConnectTracers {
|
||||
ctx = tracer.TraceConnectStart(ctx, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
|
||||
for _, tracer := range t.ConnectTracers {
|
||||
tracer.TraceConnectEnd(ctx, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
|
||||
for _, tracer := range t.PoolAcquireTracers {
|
||||
ctx = tracer.TraceAcquireStart(ctx, pool, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
|
||||
for _, tracer := range t.PoolAcquireTracers {
|
||||
tracer.TraceAcquireEnd(ctx, pool, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
|
||||
for _, tracer := range t.PoolReleaseTracers {
|
||||
tracer.TraceRelease(pool, data)
|
||||
}
|
||||
}
|
115
multitracer/tracer_test.go
Normal file
115
multitracer/tracer_test.go
Normal file
@ -0,0 +1,115 @@
|
||||
package multitracer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/multitracer"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testFullTracer struct{}
|
||||
|
||||
func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
|
||||
}
|
||||
|
||||
type testCopyTracer struct{}
|
||||
|
||||
func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fullTracer := &testFullTracer{}
|
||||
copyTracer := &testCopyTracer{}
|
||||
|
||||
mt := multitracer.New(fullTracer, copyTracer)
|
||||
require.Equal(
|
||||
t,
|
||||
&multitracer.Tracer{
|
||||
QueryTracers: []pgx.QueryTracer{
|
||||
fullTracer,
|
||||
copyTracer,
|
||||
},
|
||||
BatchTracers: []pgx.BatchTracer{
|
||||
fullTracer,
|
||||
},
|
||||
CopyFromTracers: []pgx.CopyFromTracer{
|
||||
fullTracer,
|
||||
copyTracer,
|
||||
},
|
||||
PrepareTracers: []pgx.PrepareTracer{
|
||||
fullTracer,
|
||||
},
|
||||
ConnectTracers: []pgx.ConnectTracer{
|
||||
fullTracer,
|
||||
},
|
||||
PoolAcquireTracers: []pgxpool.AcquireTracer{
|
||||
fullTracer,
|
||||
},
|
||||
PoolReleaseTracers: []pgxpool.ReleaseTracer{
|
||||
fullTracer,
|
||||
},
|
||||
},
|
||||
mt,
|
||||
)
|
||||
}
|
@ -2,6 +2,7 @@ package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@ -12,12 +13,43 @@ import (
|
||||
//
|
||||
// For example, the following two queries are equivalent:
|
||||
//
|
||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
|
||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
|
||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
|
||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
|
||||
//
|
||||
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
|
||||
// letters, numbers, or underscores.
|
||||
type NamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(na, sql, false)
|
||||
}
|
||||
|
||||
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||
// named arguments that the sql query uses, and no extra arguments.
|
||||
type StrictNamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(sna, sql, true)
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
@ -41,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
||||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
newArgs[ordinal-1] = na[string(name)]
|
||||
var found bool
|
||||
newArgs[ordinal-1], found = na[string(name)]
|
||||
if isStrict && !found {
|
||||
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), newArgs
|
||||
if isStrict {
|
||||
for name := range na {
|
||||
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
return sb.String(), newArgs, nil
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
@ -80,7 +109,7 @@ func rawState(l *sqlLexer) stateFn {
|
||||
return doubleQuoteState
|
||||
case '@':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if isLetter(nextRune) {
|
||||
if isLetter(nextRune) || nextRune == '_' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
@ -37,10 +38,10 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
},
|
||||
{
|
||||
sql: "select @Abc::int, @b_4::text",
|
||||
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"},
|
||||
expectedSQL: "select $1::int, $2::text",
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
sql: "select @Abc::int, @b_4::text, @_c::int",
|
||||
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
|
||||
expectedSQL: "select $1::int, $2::text, $3::int",
|
||||
expectedArgs: []any{int32(42), "foo", int32(1)},
|
||||
},
|
||||
{
|
||||
sql: "at end @",
|
||||
@ -49,15 +50,15 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "ignores without letter after @ foo bar",
|
||||
sql: "ignores without valid character after @ foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "ignores without letter after @ foo bar",
|
||||
expectedSQL: "ignores without valid character after @ foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "name must start with letter @1 foo bar",
|
||||
sql: "name cannot start with number @1 foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "name must start with letter @1 foo bar",
|
||||
expectedSQL: "name cannot start with number @1 foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
@ -92,11 +93,70 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||
expectedSQL: "extra provided argument",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.NamedArgs{},
|
||||
expectedSQL: "$1 argument",
|
||||
expectedArgs: []any{nil},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
namedArgs pgx.StrictNamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
isExpectedError bool
|
||||
}{
|
||||
{
|
||||
sql: "no arguments",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
expectedSQL: "no arguments",
|
||||
expectedArgs: []any{},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "@all @matches",
|
||||
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||
expectedSQL: "$1 $2",
|
||||
expectedArgs: []any{int32(1), int32(2)},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||
isExpectedError: true,
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
isExpectedError: true,
|
||||
},
|
||||
} {
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||
if tt.isExpectedError {
|
||||
assert.Errorf(t, err, "%d", i)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -26,28 +26,4 @@ if err != nil {
|
||||
|
||||
## Testing
|
||||
|
||||
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||
environment variable handling.
|
||||
|
||||
### Example Test Environment
|
||||
|
||||
Connect to your PostgreSQL server and run:
|
||||
|
||||
```
|
||||
create database pgx_test;
|
||||
```
|
||||
|
||||
Now you can run the tests:
|
||||
|
||||
```bash
|
||||
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||
```
|
||||
|
||||
### Connection and Authentication Tests
|
||||
|
||||
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||
authentication code.
|
||||
See CONTRIBUTING.md for setup instructions.
|
||||
|
@ -42,12 +42,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
// Receive server-first-message payload in an AuthenticationSASLContinue.
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -62,12 +62,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
c.frontend.Send(saslResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
// Receive server-final-message payload in an AuthenticationSASLFinal.
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -53,7 +53,7 @@ func BenchmarkExec(b *testing.B) {
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
@ -97,7 +97,7 @@ func BenchmarkExec(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
@ -159,7 +159,7 @@ func BenchmarkExecPrepared(b *testing.B) {
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
@ -197,7 +197,7 @@ func BenchmarkExecPrepared(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
@ -238,7 +238,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||
}
|
||||
|
||||
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
|
||||
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
// require.Nil(b, err)
|
||||
// defer closeConn(b, conn)
|
||||
|
||||
|
226
pgconn/config.go
226
pgconn/config.go
@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
@ -20,6 +19,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgservicefile"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
@ -27,7 +27,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type GetSSLPasswordFunc func(ctx context.Context) string
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
|
||||
// manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
@ -40,12 +40,19 @@ type Config struct {
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
|
||||
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
|
||||
// when a context passed to a PgConn method is canceled.
|
||||
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
|
||||
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
KerberosSrvName string
|
||||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
@ -61,12 +68,17 @@ type Config struct {
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
|
||||
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
|
||||
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
|
||||
// that you close on FATAL errors by returning false.
|
||||
OnPgError PgErrorHandler
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
|
||||
type ParseConfigOptions struct {
|
||||
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
|
||||
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
|
||||
// PQsetSSLKeyPassHook_OpenSSL.
|
||||
GetSSLPassword GetSSLPasswordFunc
|
||||
}
|
||||
@ -108,6 +120,14 @@ type FallbackConfig struct {
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// connectOneConfig is the configuration for a single attempt to connect to a single host.
|
||||
type connectOneConfig struct {
|
||||
network string
|
||||
address string
|
||||
originalHostname string // original hostname before resolving
|
||||
tlsConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// isAbsolutePath checks if the provided value is an absolute path either
|
||||
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||
@ -142,11 +162,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
|
||||
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
|
||||
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
|
||||
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
|
||||
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
//
|
||||
// # Example DSN
|
||||
// # Example Keyword/Value
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
@ -165,7 +185,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||
// via database URL or DSN:
|
||||
// via database URL or keyword/value:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
@ -180,9 +200,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGSSLPASSWORD
|
||||
// PGOPTIONS
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
// PGTZ
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
@ -211,7 +233,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
//
|
||||
// In addition, ParseConfig accepts the following options:
|
||||
//
|
||||
// servicefile
|
||||
// - servicefile.
|
||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||
// part of the connection string.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
@ -229,16 +251,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
connStringSettings := make(map[string]string)
|
||||
if connString != "" {
|
||||
var err error
|
||||
// connString may be a database URL or a DSN
|
||||
// connString may be a database URL or in PostgreSQL keyword/value format
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
connStringSettings, err = parseURLSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
connStringSettings, err = parseDSNSettings(connString)
|
||||
connStringSettings, err = parseKeywordValueSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -247,7 +269,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
if service, present := settings["service"]; present {
|
||||
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
|
||||
}
|
||||
|
||||
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||
@ -262,12 +284,22 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||
return pgproto3.NewFrontend(r, w)
|
||||
},
|
||||
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
|
||||
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
|
||||
},
|
||||
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
|
||||
// we want to automatically close any fatal errors
|
||||
if strings.EqualFold(pgErr.Severity, "FATAL") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.ConnectTimeout = connectTimeout
|
||||
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||
@ -290,7 +322,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
"sslkey": {},
|
||||
"sslcert": {},
|
||||
"sslrootcert": {},
|
||||
"sslnegotiation": {},
|
||||
"sslpassword": {},
|
||||
"sslsni": {},
|
||||
"krbspn": {},
|
||||
"krbsrvname": {},
|
||||
"target_session_attrs": {},
|
||||
@ -328,7 +362,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
@ -340,7 +374,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings, host, options)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
@ -357,6 +391,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
config.SSLNegotiation = settings["sslnegotiation"]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
@ -384,7 +419,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
case "any":
|
||||
// do nothing
|
||||
default:
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
@ -417,11 +452,15 @@ func parseEnvSettings() map[string]string {
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLSNI": "sslsni",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGSSLPASSWORD": "sslpassword",
|
||||
"PGSSLNEGOTIATION": "sslnegotiation",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
"PGTZ": "timezone",
|
||||
"PGOPTIONS": "options",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
@ -437,14 +476,17 @@ func parseEnvSettings() map[string]string {
|
||||
func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
url, err := url.Parse(connString)
|
||||
parsedURL, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
if urlErr := new(url.Error); errors.As(err, &urlErr) {
|
||||
return nil, urlErr.Err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if url.User != nil {
|
||||
settings["user"] = url.User.Username()
|
||||
if password, present := url.User.Password(); present {
|
||||
if parsedURL.User != nil {
|
||||
settings["user"] = parsedURL.User.Username()
|
||||
if password, present := parsedURL.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
@ -452,7 +494,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(url.Host, ",") {
|
||||
for _, host := range strings.Split(parsedURL.Host, ",") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
@ -478,7 +520,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings["port"] = strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(url.Path, "/")
|
||||
database := strings.TrimLeft(parsedURL.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
@ -487,7 +529,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for k, v := range url.Query() {
|
||||
for k, v := range parsedURL.Query() {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
@ -504,7 +546,7 @@ func isIPOnly(host string) bool {
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
func parseDSNSettings(s string) (map[string]string, error) {
|
||||
func parseKeywordValueSettings(s string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
@ -515,7 +557,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
|
||||
var key, val string
|
||||
eqIdx := strings.IndexRune(s, '=')
|
||||
if eqIdx < 0 {
|
||||
return nil, errors.New("invalid dsn")
|
||||
return nil, errors.New("invalid keyword/value")
|
||||
}
|
||||
|
||||
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||
@ -567,7 +609,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return nil, errors.New("invalid dsn")
|
||||
return nil, errors.New("invalid keyword/value")
|
||||
}
|
||||
|
||||
settings[key] = val
|
||||
@ -612,14 +654,56 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
sslpassword := settings["sslpassword"]
|
||||
sslsni := settings["sslsni"]
|
||||
sslnegotiation := settings["sslnegotiation"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
if sslsni == "" {
|
||||
sslsni = "1"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if sslnegotiation == "direct" {
|
||||
tlsConfig.NextProtos = []string{"postgresql"}
|
||||
if sslmode == "prefer" {
|
||||
sslmode = "require"
|
||||
}
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
var caCertPool *x509.CertPool
|
||||
|
||||
if sslrootcert == "system" {
|
||||
var err error
|
||||
|
||||
caCertPool, err = x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
|
||||
}
|
||||
|
||||
sslmode = "verify-full"
|
||||
} else {
|
||||
caCertPool = x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
@ -677,33 +761,19 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
buf, err := ioutil.ReadFile(sslkey)
|
||||
buf, err := os.ReadFile(sslkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode sslkey")
|
||||
}
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
@ -738,7 +808,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
} else {
|
||||
pemKey = pem.EncodeToMemory(block)
|
||||
}
|
||||
certfile, err := ioutil.ReadFile(sslcert)
|
||||
certfile, err := os.ReadFile(sslcert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||
}
|
||||
@ -749,6 +819,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||
// or IPv6).
|
||||
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
@ -773,7 +850,8 @@ func parsePort(s string) (uint16, error) {
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
// rely on GOLANG KeepAlive settings
|
||||
return &net.Dialer{}
|
||||
}
|
||||
|
||||
func makeDefaultResolver() *net.Resolver {
|
||||
@ -797,75 +875,75 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "on" {
|
||||
if string(result[0].Rows[0][0]) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-only.
|
||||
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "on" {
|
||||
if string(result[0].Rows[0][0]) != "on" {
|
||||
return errors.New("connection is not read only")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=standby.
|
||||
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
if string(result[0].Rows[0][0]) != "t" {
|
||||
return errors.New("server is not in hot standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=primary.
|
||||
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "t" {
|
||||
if string(result[0].Rows[0][0]) == "t" {
|
||||
return errors.New("server is in standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=prefer-standby.
|
||||
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
if string(result[0].Rows[0][0]) != "t" {
|
||||
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||
}
|
||||
|
||||
|
@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -17,8 +18,25 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
func skipOnWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("FIXME: skipping on Windows, investigate why this test fails in CI environment")
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultPort(t *testing.T) uint16 {
|
||||
if envPGPORT := os.Getenv("PGPORT"); envPGPORT != "" {
|
||||
p, err := strconv.ParseUint(envPGPORT, 10, 16)
|
||||
require.NoError(t, err)
|
||||
return uint16(p)
|
||||
}
|
||||
return 5432
|
||||
}
|
||||
|
||||
func getDefaultUser(t *testing.T) string {
|
||||
if pguser := os.Getenv("PGUSER"); pguser != "" {
|
||||
return pguser
|
||||
}
|
||||
|
||||
var osUserName string
|
||||
osUser, err := user.Current()
|
||||
@ -32,10 +50,20 @@ func TestParseConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
return osUserName
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
t.Parallel()
|
||||
|
||||
config, err := pgconn.ParseConfig("")
|
||||
require.NoError(t, err)
|
||||
defaultHost := config.Host
|
||||
|
||||
defaultUser := getDefaultUser(t)
|
||||
defaultPort := getDefaultPort(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
@ -53,10 +81,11 @@ func TestParseConfig(t *testing.T) {
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
@ -89,11 +118,12 @@ func TestParseConfig(t *testing.T) {
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -111,10 +141,11 @@ func TestParseConfig(t *testing.T) {
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
@ -133,6 +164,7 @@ func TestParseConfig(t *testing.T) {
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
@ -148,6 +180,7 @@ func TestParseConfig(t *testing.T) {
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
@ -198,7 +231,7 @@ func TestParseConfig(t *testing.T) {
|
||||
name: "database url missing user and password",
|
||||
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
User: defaultUser,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
@ -223,9 +256,9 @@ func TestParseConfig(t *testing.T) {
|
||||
name: "database url unix domain socket host",
|
||||
connString: "postgres:///foo?host=/tmp",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
User: defaultUser,
|
||||
Host: "/tmp",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
@ -235,9 +268,9 @@ func TestParseConfig(t *testing.T) {
|
||||
name: "database url unix domain socket host on windows",
|
||||
connString: "postgres:///foo?host=C:\\tmp",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
User: defaultUser,
|
||||
Host: "C:\\tmp",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
@ -247,9 +280,9 @@ func TestParseConfig(t *testing.T) {
|
||||
name: "database url dbname",
|
||||
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
User: defaultUser,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
@ -297,14 +330,14 @@ func TestParseConfig(t *testing.T) {
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "2001:db8::1",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN everything",
|
||||
name: "Key/value everything",
|
||||
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
@ -321,7 +354,7 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with escaped single quote",
|
||||
name: "Key/value with escaped single quote",
|
||||
connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack's",
|
||||
@ -334,7 +367,7 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with escaped backslash",
|
||||
name: "Key/value with escaped backslash",
|
||||
connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
@ -347,48 +380,48 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with single quoted values",
|
||||
name: "Key/value with single quoted values",
|
||||
connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with single quoted value with escaped single quote",
|
||||
name: "Key/value with single quoted value with escaped single quote",
|
||||
connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'",
|
||||
config: &pgconn.Config{
|
||||
User: "jack's",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with empty single quoted value",
|
||||
name: "Key/value with empty single quoted value",
|
||||
connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN with space between key and value",
|
||||
name: "Key/value with space between key and value",
|
||||
connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
@ -401,19 +434,19 @@ func TestParseConfig(t *testing.T) {
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
@ -431,12 +464,12 @@ func TestParseConfig(t *testing.T) {
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 2,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 3,
|
||||
TLSConfig: nil,
|
||||
@ -459,7 +492,7 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN multiple hosts one port",
|
||||
name: "Key/value multiple hosts one port",
|
||||
connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
@ -470,12 +503,12 @@ func TestParseConfig(t *testing.T) {
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
@ -484,7 +517,7 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN multiple hosts multiple ports",
|
||||
name: "Key/value multiple hosts multiple ports",
|
||||
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
@ -495,12 +528,12 @@ func TestParseConfig(t *testing.T) {
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 2,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 3,
|
||||
TLSConfig: nil,
|
||||
@ -509,44 +542,47 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple hosts and fallback tsl",
|
||||
name: "multiple hosts and fallback tls",
|
||||
connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "foo",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "bar",
|
||||
}},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "baz",
|
||||
}},
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
@ -648,6 +684,82 @@ func TestParseConfig(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI is set by default",
|
||||
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "sni.test",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "sni.test",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI is not set for IPv4",
|
||||
connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "1.1.1.1",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI is not set for IPv6",
|
||||
connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "::1",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI is not set when disabled (URL-style)",
|
||||
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "sni.test",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI is not set when disabled (key/value style)",
|
||||
connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "sni.test",
|
||||
Port: defaultPort,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@ -661,18 +773,18 @@ func TestParseConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgconn/issues/47
|
||||
func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
|
||||
func TestParseConfigKVWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
|
||||
_, err := pgconn.ParseConfig("host= user= password= port= database=")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestParseConfigDSNLeadingEqual(t *testing.T) {
|
||||
func TestParseConfigKVLeadingEqual(t *testing.T) {
|
||||
_, err := pgconn.ParseConfig("= user=jack")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgconn/issues/49
|
||||
func TestParseConfigDSNTrailingBackslash(t *testing.T) {
|
||||
func TestParseConfigKVTrailingBackslash(t *testing.T) {
|
||||
_, err := pgconn.ParseConfig(`x=x\`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid backslash")
|
||||
@ -705,7 +817,7 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
|
||||
connString := os.Getenv("PGX_TEST_CONN_STRING")
|
||||
connString := os.Getenv("PGX_TEST_DATABASE")
|
||||
original, err := pgconn.ParseConfig(connString)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -820,20 +932,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||
|
||||
savedEnv := make(map[string]string)
|
||||
for _, n := range pgEnvvars {
|
||||
savedEnv[n] = os.Getenv(n)
|
||||
}
|
||||
defer func() {
|
||||
for k, v := range savedEnv {
|
||||
err := os.Setenv(k, v)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to restore environment: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI", "PGTZ", "PGOPTIONS"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -853,7 +952,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "123.123.123.123",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
@ -872,6 +971,8 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
"PGCONNECT_TIMEOUT": "10",
|
||||
"PGSSLMODE": "disable",
|
||||
"PGAPPNAME": "pgxtest",
|
||||
"PGTZ": "America/New_York",
|
||||
"PGOPTIONS": "-c search_path=myschema",
|
||||
},
|
||||
config: &pgconn.Config{
|
||||
Host: "123.123.123.123",
|
||||
@ -881,20 +982,31 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
Password: "baz",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest", "timezone": "America/New_York", "options": "-c search_path=myschema"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SNI can be disabled via environment variable",
|
||||
envvars: map[string]string{
|
||||
"PGHOST": "test.foo",
|
||||
"PGSSLMODE": "require",
|
||||
"PGSSLSNI": "0",
|
||||
},
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "test.foo",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
for _, n := range pgEnvvars {
|
||||
err := os.Unsetenv(n)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
for k, v := range tt.envvars {
|
||||
err := os.Setenv(k, v)
|
||||
require.NoError(t, err)
|
||||
for _, env := range pgEnvvars {
|
||||
t.Setenv(env, tt.envvars[env])
|
||||
}
|
||||
|
||||
config, err := pgconn.ParseConfig("")
|
||||
@ -907,18 +1019,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
t.Parallel()
|
||||
|
||||
tf, err := ioutil.TempFile("", "")
|
||||
tfName := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer tf.Close()
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
||||
require.NoError(t, err)
|
||||
|
||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
|
||||
expected := &pgconn.Config{
|
||||
User: "curly",
|
||||
Password: "nyuknyuknyuk",
|
||||
@ -936,15 +1044,12 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseConfigReadsPgServiceFile(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
t.Parallel()
|
||||
|
||||
tf, err := ioutil.TempFile("", "")
|
||||
require.NoError(t, err)
|
||||
tfName := filepath.Join(t.TempDir(), "config")
|
||||
|
||||
defer tf.Close()
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
_, err = tf.Write([]byte(`
|
||||
err := os.WriteFile(tfName, []byte(`
|
||||
[abc]
|
||||
host=abc.example.com
|
||||
port=9999
|
||||
@ -956,9 +1061,11 @@ host = def.example.com
|
||||
dbname = defdb
|
||||
user = defuser
|
||||
application_name = spaced string
|
||||
`))
|
||||
`), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
defaultPort := getDefaultPort(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
@ -966,7 +1073,7 @@ application_name = spaced string
|
||||
}{
|
||||
{
|
||||
name: "abc",
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"),
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "abc"),
|
||||
config: &pgconn.Config{
|
||||
Host: "abc.example.com",
|
||||
Database: "abcdb",
|
||||
@ -974,10 +1081,11 @@ application_name = spaced string
|
||||
Port: 9999,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "abc.example.com",
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "abc.example.com",
|
||||
Port: 9999,
|
||||
TLSConfig: nil,
|
||||
@ -987,20 +1095,21 @@ application_name = spaced string
|
||||
},
|
||||
{
|
||||
name: "def",
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"),
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "def"),
|
||||
config: &pgconn.Config{
|
||||
Host: "def.example.com",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
Database: "defdb",
|
||||
User: "defuser",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "def.example.com",
|
||||
},
|
||||
RuntimeParams: map[string]string{"application_name": "spaced string"},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
{
|
||||
Host: "def.example.com",
|
||||
Port: 5432,
|
||||
Port: defaultPort,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
@ -1008,7 +1117,7 @@ application_name = spaced string
|
||||
},
|
||||
{
|
||||
name: "conn string has precedence",
|
||||
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"),
|
||||
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tfName, "abc"),
|
||||
config: &pgconn.Config{
|
||||
Host: "other.example.com",
|
||||
Database: "abcdb",
|
||||
|
@ -8,8 +8,7 @@ import (
|
||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
onCancel func()
|
||||
onUnwatchAfterCancel func()
|
||||
handler Handler
|
||||
unwatchChan chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
@ -20,10 +19,9 @@ type ContextWatcher struct {
|
||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||
// onCancel called.
|
||||
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||
func NewContextWatcher(handler Handler) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
onCancel: onCancel,
|
||||
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||
handler: handler,
|
||||
unwatchChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.onCancel()
|
||||
cw.handler.HandleCancel(ctx)
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
|
||||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
cw.onUnwatchAfterCancel()
|
||||
cw.handler.HandleUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
|
||||
// context that was canceled.
|
||||
HandleCancel(canceledCtx context.Context)
|
||||
|
||||
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
|
||||
HandleUnwatchAfterCancel()
|
||||
}
|
@ -6,17 +6,32 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testHandler struct {
|
||||
handleCancel func(context.Context)
|
||||
handleUnwatchAfterCancel func()
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleCancel(ctx context.Context) {
|
||||
h.handleCancel(ctx)
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleUnwatchAfterCancel() {
|
||||
h.handleUnwatchAfterCancel()
|
||||
}
|
||||
|
||||
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||
canceledChan := make(chan struct{})
|
||||
cleanupCalled := false
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
canceledChan <- struct{}{}
|
||||
}, func() {
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
cleanupCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -34,11 +49,13 @@ func TestContextWatcherContextCancelled(t *testing.T) {
|
||||
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
t.Error("cancel func should not have been called")
|
||||
}, func() {
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
t.Error("cleanup func should not have been called")
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -48,11 +65,12 @@ func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
defer cw.Unwatch()
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
@ -60,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
cw.Unwatch() // unwatch when not / never watching
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -71,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
@ -87,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) {
|
||||
var cancelFuncCalls int64
|
||||
var cleanupFuncCalls int64
|
||||
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||
}, func() {
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||
},
|
||||
})
|
||||
|
||||
cycleCount := 100000
|
||||
@ -103,7 +123,9 @@ func TestContextWatcherStress(t *testing.T) {
|
||||
}
|
||||
|
||||
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
|
||||
if i%3 == 0 {
|
||||
if i%333 == 0 {
|
||||
// on Windows Sleep takes more time than expected so we try to get here less frequently to avoid
|
||||
// the CI takes a long time
|
||||
time.Sleep(time.Nanosecond)
|
||||
}
|
||||
|
||||
@ -131,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(context.Background())
|
||||
@ -140,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -151,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
@ -5,8 +5,8 @@ nearly the same level is the C library libpq.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||
libpq style environment variables.
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
|
||||
environment for libpq style environment variables.
|
||||
|
||||
Executing a Query
|
||||
|
||||
@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Pipeline Mode
|
||||
|
||||
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
|
||||
control of exactly how many and when network round trips occur.
|
||||
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
|
||||
exactly how many and when network round trips occur.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
|
||||
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
|
||||
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
|
||||
This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is
|
||||
a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
|
||||
interrupting the query in such a way as to close the connection.
|
||||
|
||||
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||
client to abort.
|
||||
|
@ -12,14 +12,15 @@ import (
|
||||
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||
return e.SafeToRetry()
|
||||
var retryableErr interface{ SafeToRetry() bool }
|
||||
if errors.As(err, &retryableErr) {
|
||||
return retryableErr.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
var timeoutErr *errTimeout
|
||||
return errors.As(err, &timeoutErr)
|
||||
@ -30,6 +31,7 @@ func Timeout(err error) bool {
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
SeverityUnlocalized string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
@ -57,22 +59,37 @@ func (pe *PgError) SQLState() string {
|
||||
return pe.Code
|
||||
}
|
||||
|
||||
type connectError struct {
|
||||
config *Config
|
||||
msg string
|
||||
// ConnectError is the error returned when a connection attempt fails.
|
||||
type ConnectError struct {
|
||||
Config *Config // The configuration that was used in the connection attempt.
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *connectError) Error() string {
|
||||
sb := &strings.Builder{}
|
||||
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||
if e.err != nil {
|
||||
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||
func (e *ConnectError) Error() string {
|
||||
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
|
||||
details := e.err.Error()
|
||||
if strings.Contains(details, "\n") {
|
||||
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
|
||||
} else {
|
||||
return prefix + " " + details
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (e *connectError) Unwrap() error {
|
||||
func (e *ConnectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type perDialConnectError struct {
|
||||
address string
|
||||
originalHostname string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *perDialConnectError) Error() string {
|
||||
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *perDialConnectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
@ -88,29 +105,47 @@ func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
type parseConfigError struct {
|
||||
connString string
|
||||
// ParseConfigError is the error returned when a connection string cannot be parsed.
|
||||
type ParseConfigError struct {
|
||||
ConnString string // The connection string that could not be parsed.
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Error() string {
|
||||
connString := redactPW(e.connString)
|
||||
func NewParseConfigError(conn, msg string, err error) error {
|
||||
return &ParseConfigError{
|
||||
ConnString: conn,
|
||||
msg: msg,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ParseConfigError) Error() string {
|
||||
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
|
||||
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
|
||||
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
|
||||
connString := redactPW(e.ConnString)
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||
}
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Unwrap() error {
|
||||
func (e *ParseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||
// true. Otherwise returns err.
|
||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||
func normalizeTimeoutError(ctx context.Context, err error) error {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
|
||||
return context.Canceled
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
return &errTimeout{err: ctx.Err()}
|
||||
} else {
|
||||
return &errTimeout{err: netErr}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -184,10 +219,10 @@ func redactPW(connString string) string {
|
||||
return redactURL(u)
|
||||
}
|
||||
}
|
||||
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
quotedKV := regexp.MustCompile(`password='[^']*'`)
|
||||
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
plainKV := regexp.MustCompile(`password=[^ ]*`)
|
||||
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||
return connString
|
||||
|
@ -19,18 +19,18 @@ func TestConfigError(t *testing.T) {
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
|
||||
},
|
||||
{
|
||||
name: "dsn with password unquoted",
|
||||
name: "keyword/value with password unquoted",
|
||||
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "dsn with password quoted",
|
||||
name: "keyword/value with password quoted",
|
||||
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "weird url",
|
||||
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
|
||||
err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
|
||||
},
|
||||
{
|
||||
|
@ -1,11 +1,3 @@
|
||||
// File export_test exports some methods for better testing.
|
||||
|
||||
package pgconn
|
||||
|
||||
func NewParseConfigError(conn, msg string, err error) error {
|
||||
return &parseConfigError{
|
||||
connString: conn,
|
||||
msg: msg,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
@ -12,19 +12,19 @@ import (
|
||||
)
|
||||
|
||||
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, conn.Close(ctx))
|
||||
select {
|
||||
case <-conn.CleanupDone():
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("Connection cleanup exceeded maximum time")
|
||||
}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||
cancel()
|
||||
|
||||
|
139
pgconn/internal/bgreader/bgreader.go
Normal file
139
pgconn/internal/bgreader/bgreader.go
Normal file
@ -0,0 +1,139 @@
|
||||
// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
|
||||
package bgreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusStopped = iota
|
||||
StatusRunning
|
||||
StatusStopping
|
||||
)
|
||||
|
||||
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
|
||||
type BGReader struct {
|
||||
r io.Reader
|
||||
|
||||
cond *sync.Cond
|
||||
status int32
|
||||
readResults []readResult
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
buf *[]byte
|
||||
err error
|
||||
}
|
||||
|
||||
// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
|
||||
// reader will stop automatically when the underlying reader returns an error.
|
||||
func (r *BGReader) Start() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.status {
|
||||
case StatusStopped:
|
||||
r.status = StatusRunning
|
||||
go r.bgRead()
|
||||
case StatusRunning:
|
||||
// no-op
|
||||
case StatusStopping:
|
||||
r.status = StatusRunning
|
||||
}
|
||||
}
|
||||
|
||||
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
|
||||
// background reader is not running.
|
||||
func (r *BGReader) Stop() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.status {
|
||||
case StatusStopped:
|
||||
// no-op
|
||||
case StatusRunning:
|
||||
r.status = StatusStopping
|
||||
case StatusStopping:
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the background reader.
|
||||
func (r *BGReader) Status() int32 {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
return r.status
|
||||
}
|
||||
|
||||
func (r *BGReader) bgRead() {
|
||||
keepReading := true
|
||||
for keepReading {
|
||||
buf := iobufpool.Get(8192)
|
||||
n, err := r.r.Read(*buf)
|
||||
*buf = (*buf)[:n]
|
||||
|
||||
r.cond.L.Lock()
|
||||
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
|
||||
if r.status == StatusStopping || err != nil {
|
||||
r.status = StatusStopped
|
||||
keepReading = false
|
||||
}
|
||||
r.cond.L.Unlock()
|
||||
r.cond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (r *BGReader) Read(p []byte) (int, error) {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
if len(r.readResults) > 0 {
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// There are no unread background read results and the background reader is stopped.
|
||||
if r.status == StatusStopped {
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
// Wait for results from the background reader
|
||||
for len(r.readResults) == 0 {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held.
|
||||
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
|
||||
buf := r.readResults[0].buf
|
||||
var err error
|
||||
|
||||
n := copy(p, *buf)
|
||||
if n == len(*buf) {
|
||||
err = r.readResults[0].err
|
||||
iobufpool.Put(buf)
|
||||
if len(r.readResults) == 1 {
|
||||
r.readResults = nil
|
||||
} else {
|
||||
r.readResults = r.readResults[1:]
|
||||
}
|
||||
} else {
|
||||
*buf = (*buf)[n:]
|
||||
r.readResults[0].buf = buf
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func New(r io.Reader) *BGReader {
|
||||
return &BGReader{
|
||||
r: r,
|
||||
cond: &sync.Cond{
|
||||
L: &sync.Mutex{},
|
||||
},
|
||||
}
|
||||
}
|
140
pgconn/internal/bgreader/bgreader_test.go
Normal file
140
pgconn/internal/bgreader/bgreader_test.go
Normal file
@ -0,0 +1,140 @@
|
||||
package bgreader_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBGReaderReadWhenStopped(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWhenStarted(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
type mockReadFunc func(p []byte) (int, error)
|
||||
|
||||
type mockReader struct {
|
||||
readFuncs []mockReadFunc
|
||||
}
|
||||
|
||||
func (r *mockReader) Read(p []byte) (int, error) {
|
||||
if len(r.readFuncs) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
fn := r.readFuncs[0]
|
||||
r.readFuncs = r.readFuncs[1:]
|
||||
|
||||
return fn(p)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
|
||||
},
|
||||
}
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf := make([]byte, 3)
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("foo"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStarted(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStopped(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
type numberReader struct {
|
||||
v uint8
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func (nr *numberReader) Read(p []byte) (int, error) {
|
||||
n := nr.rng.Intn(len(p))
|
||||
for i := 0; i < n; i++ {
|
||||
p[i] = nr.v
|
||||
nr.v++
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
|
||||
// stopping the background worker from other goroutines.
|
||||
func TestBGReaderStress(t *testing.T) {
|
||||
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
|
||||
bgr := bgreader.New(nr)
|
||||
|
||||
bytesRead := 0
|
||||
var expected uint8
|
||||
buf := make([]byte, 10_000)
|
||||
rng := rand.New(rand.NewSource(0))
|
||||
|
||||
for bytesRead < 1_000_000 {
|
||||
randomNumber := rng.Intn(100)
|
||||
switch {
|
||||
case randomNumber < 10:
|
||||
go bgr.Start()
|
||||
case randomNumber < 20:
|
||||
go bgr.Stop()
|
||||
default:
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < n; i++ {
|
||||
require.Equal(t, expected, buf[i])
|
||||
expected++
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
}
|
||||
}
|
@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
|
||||
Data: nextData,
|
||||
}
|
||||
c.frontend.Send(gssResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
1170
pgconn/pgconn.go
1170
pgconn/pgconn.go
File diff suppressed because it is too large
Load Diff
@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func TestConnStress(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
# pgproto3
|
||||
|
||||
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
|
||||
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||
|
||||
|
@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
|
@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
|
@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 12)
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||
|
||||
for _, s := range src.AuthMechanisms {
|
||||
@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||
|
@ -17,6 +17,7 @@ type Backend struct {
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
encodeError error
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
@ -38,6 +39,7 @@ type Backend struct {
|
||||
terminate Terminate
|
||||
|
||||
bodyLen int
|
||||
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
authType uint32
|
||||
@ -54,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
|
||||
return &Backend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
|
||||
// encountered will be returned from Flush.
|
||||
func (b *Backend) Send(msg BackendMessage) {
|
||||
if b.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(b.wbuf)
|
||||
b.wbuf = msg.Encode(b.wbuf)
|
||||
newBuf, err := msg.Encode(b.wbuf)
|
||||
if err != nil {
|
||||
b.encodeError = err
|
||||
return
|
||||
}
|
||||
b.wbuf = newBuf
|
||||
|
||||
if b.tracer != nil {
|
||||
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||
}
|
||||
@ -66,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
|
||||
|
||||
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||
func (b *Backend) Flush() error {
|
||||
if err := b.encodeError; err != nil {
|
||||
b.encodeError = nil
|
||||
b.wbuf = b.wbuf[:0]
|
||||
return &writeError{err: err, safeToRetry: true}
|
||||
}
|
||||
|
||||
n, err := b.w.Write(b.wbuf)
|
||||
|
||||
const maxLen = 1024
|
||||
@ -157,7 +175,16 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
msgLength := int(binary.BigEndian.Uint32(header[1:]))
|
||||
if msgLength < 4 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLength)
|
||||
}
|
||||
|
||||
b.bodyLen = msgLength - 4
|
||||
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
|
||||
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
|
||||
}
|
||||
b.partialMsg = true
|
||||
}
|
||||
|
||||
@ -196,7 +223,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||
fallthrough
|
||||
default:
|
||||
// to maintain backwards compatability
|
||||
// to maintain backwards compatibility
|
||||
msg = &PasswordMessage{}
|
||||
}
|
||||
case 'Q':
|
||||
@ -260,3 +287,13 @@ func (b *Backend) SetAuthType(authType uint32) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxBodyLen sets the maximum length of a message body in octets.
|
||||
// If a message body exceeds this length, Receive will return an error.
|
||||
// This is useful for protecting against malicious clients that send
|
||||
// large messages with the intent of causing memory exhaustion.
|
||||
// The default value is 0.
|
||||
// If maxBodyLen is 0, then no maximum is enforced.
|
||||
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
|
||||
b.maxBodyLen = maxBodyLen
|
||||
}
|
||||
|
@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
|
||||
"username": "tester",
|
||||
},
|
||||
}
|
||||
dst := []byte{}
|
||||
dst = want.Encode(dst)
|
||||
dst, err := want.Encode([]byte{})
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push(dst)
|
||||
@ -120,3 +120,21 @@ func TestStartupMessage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Q', 0, 0, 10, 10})
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
// Set max body len to 5
|
||||
backend.SetMaxBodyLen(5)
|
||||
|
||||
// Receive regular msg
|
||||
msg, err := backend.Receive()
|
||||
assert.Nil(t, msg)
|
||||
var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
|
||||
assert.ErrorAs(t, err, &invalidBodyLenErr)
|
||||
}
|
||||
|
@ -5,7 +5,9 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *Bind) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'B')
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
if len(src.ParameterFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameter format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
if len(src.Parameters) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameters")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
if len(src.ResultFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many result format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '2', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
20
pgproto3/bind_test.go
Normal file
20
pgproto3/bind_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Maximum allowed size.
|
||||
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1 byte too big
|
||||
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
|
||||
require.Error(t, err)
|
||||
}
|
@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
type chunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
buf *[]byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
minBufSize int
|
||||
@ -45,7 +45,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
|
||||
func (r *chunkReader) Next(n int) (buf []byte, err error) {
|
||||
// Reset the buffer if it is empty
|
||||
if r.rp == r.wp {
|
||||
if len(r.buf) != r.minBufSize {
|
||||
if len(*r.buf) != r.minBufSize {
|
||||
iobufpool.Put(r.buf)
|
||||
r.buf = iobufpool.Get(r.minBufSize)
|
||||
}
|
||||
@ -55,15 +55,15 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
|
||||
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// buf is smaller than requested number of bytes
|
||||
if len(r.buf) < n {
|
||||
if len(*r.buf) < n {
|
||||
bigBuf := iobufpool.Get(n)
|
||||
r.wp = copy(bigBuf, r.buf[r.rp:r.wp])
|
||||
r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
iobufpool.Put(r.buf)
|
||||
r.buf = bigBuf
|
||||
@ -71,20 +71,20 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
r.wp = copy(r.buf, r.buf[r.rp:r.wp])
|
||||
if (len(*r.buf) - r.wp) < minReadCount {
|
||||
r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
}
|
||||
|
||||
// Read at least the required number of bytes from the underlying io.Reader
|
||||
readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount)
|
||||
readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount)
|
||||
r.wp += readBytesCount
|
||||
// fmt.Println("read", n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||
if !bytes.Equal(n1, src[0:2]) {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||
}
|
||||
|
||||
@ -25,11 +25,11 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||
if !bytes.Equal(n2, src[2:4]) {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.buf[:len(src)], src) != 0 {
|
||||
if !bytes.Equal((*r.buf)[:len(src)], src) {
|
||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||
}
|
||||
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Close) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '3', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'W')
|
||||
dst = append(dst, src.OverallFormat)
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
|
||||
err := dstResp.Decode(srcBytes[5:])
|
||||
assert.NoError(t, err, "No errors on decode")
|
||||
dstBytes := []byte{}
|
||||
dstBytes = dstResp.Encode(dstBytes)
|
||||
dstBytes, err = dstResp.Encode(dstBytes)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||
}
|
||||
|
@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'd')
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||
return append(dst, 'c', 0, 0, 0, 4)
|
||||
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'c', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyFail struct {
|
||||
@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'f')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'f')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'G')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'H')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
|
||||
if len(src.Values) > math.MaxUint16 {
|
||||
return nil, errors.New("too many values")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Describe) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
//
|
||||
// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are
|
||||
// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call
|
||||
// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call
|
||||
// Flush to ensure a message has actually been sent.
|
||||
//
|
||||
// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a
|
||||
|
@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'I', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -2,7 +2,6 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
@ -111,120 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = src.appendFields(dst)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
func (src *ErrorResponse) appendFields(dst []byte) []byte {
|
||||
if src.Severity != "" {
|
||||
buf.WriteByte('S')
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'S')
|
||||
dst = append(dst, src.Severity...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SeverityUnlocalized != "" {
|
||||
buf.WriteByte('V')
|
||||
buf.WriteString(src.SeverityUnlocalized)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'V')
|
||||
dst = append(dst, src.SeverityUnlocalized...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteByte('C')
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'C')
|
||||
dst = append(dst, src.Code...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteByte('M')
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'M')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteByte('D')
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'D')
|
||||
dst = append(dst, src.Detail...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteByte('H')
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'H')
|
||||
dst = append(dst, src.Hint...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteByte('P')
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'P')
|
||||
dst = append(dst, strconv.Itoa(int(src.Position))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteByte('p')
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'p')
|
||||
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteByte('q')
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'q')
|
||||
dst = append(dst, src.InternalQuery...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteByte('W')
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'W')
|
||||
dst = append(dst, src.Where...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteByte('s')
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 's')
|
||||
dst = append(dst, src.SchemaName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteByte('t')
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 't')
|
||||
dst = append(dst, src.TableName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteByte('c')
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'c')
|
||||
dst = append(dst, src.ColumnName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteByte('d')
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'd')
|
||||
dst = append(dst, src.DataTypeName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteByte('n')
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'n')
|
||||
dst = append(dst, src.ConstraintName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteByte('F')
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'F')
|
||||
dst = append(dst, src.File...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteByte('L')
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'L')
|
||||
dst = append(dst, strconv.Itoa(int(src.Line))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteByte('R')
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'R')
|
||||
dst = append(dst, src.Routine...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, k)
|
||||
dst = append(dst, v...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
|
||||
return fmt.Errorf("error generating query response: %w", err)
|
||||
}
|
||||
|
||||
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: []byte("fortune"),
|
||||
TableOID: 0,
|
||||
@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
|
||||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
}}).Encode(nil)
|
||||
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
|
||||
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
}}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing query response: %w", err)
|
||||
@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
|
||||
|
||||
switch startupMessage.(type) {
|
||||
case *pgproto3.StartupMessage:
|
||||
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending ready for query: %w", err)
|
||||
@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
|
||||
func (p *PgFortuneBackend) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func mustEncode(buf []byte, err error) []byte {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Execute) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user