Skip to content

Commit e452499

Browse files
thecampagnardsKonstantin Sidorenko
authored andcommitted
feat: manage range and multirange
1 parent b7ffbd3 commit e452499

File tree

2 files changed

+491
-0
lines changed

2 files changed

+491
-0
lines changed

range.go

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
package pq
5+
6+
import (
7+
"bytes"
8+
"database/sql"
9+
"database/sql/driver"
10+
"fmt"
11+
"reflect"
12+
"time"
13+
_ "unsafe"
14+
)
15+
16+
// Range returns the optimal driver.Valuer and sql.Scanner for a range.
17+
// Check https://www.postgresql.org/docs/current/rangetypes.html for details.
18+
//
19+
// For example:
20+
//
21+
// min := time.Now()
22+
// max := time.Now().AddDate(0,-1,0)
23+
// db.Query(`SELECT * FROM reservation WHERE during && $1`, pq.NewRange(&min, &max))
24+
//
25+
// var x pq.Range[int]
26+
// db.QueryRow(`SELECT '[1, 10)'`).Scan(&x)
27+
//
28+
// Scanning multi-dimensional range is supported using [MultiRange]
29+
type Range[T any] struct {
30+
// When nil it will be infinite value
31+
Lower, Upper *T
32+
LowerBound RangeLowerBound
33+
UpperBound RangeUpperBound
34+
}
35+
36+
type RangeLowerBound byte
37+
type RangeUpperBound byte
38+
39+
const (
40+
RangeLowerBoundInclusive RangeLowerBound = '['
41+
RangeLowerBoundExclusive RangeLowerBound = '('
42+
RangeUpperBoundInclusive RangeUpperBound = ']'
43+
RangeUpperBoundExclusive RangeUpperBound = ')'
44+
45+
RangeLowerBoundDefault = RangeLowerBoundInclusive
46+
RangeUpperBoundDefault = RangeUpperBoundExclusive
47+
48+
RangeEmpty = "empty"
49+
)
50+
51+
// NewRange create a [Range] with default bounds [RangeLowerBoundDefault] and [RangeUpperBoundDefault].
52+
func NewRange[T any](lower, upper *T) Range[T] {
53+
return Range[T]{
54+
Lower: lower,
55+
Upper: upper,
56+
LowerBound: RangeLowerBoundDefault,
57+
UpperBound: RangeUpperBoundDefault,
58+
}
59+
}
60+
61+
var (
62+
_ sql.Scanner = (*Range[any])(nil)
63+
_ driver.Valuer = (*Range[any])(nil)
64+
)
65+
66+
func (r *Range[T]) Scan(anySrc any) error {
67+
var src []byte
68+
switch s := anySrc.(type) {
69+
case string:
70+
src = []byte(s)
71+
case []byte:
72+
src = s
73+
default:
74+
return fmt.Errorf("pq: cannot convert %T to Range", anySrc)
75+
}
76+
77+
src = bytes.TrimSpace(src)
78+
if len(src) == 0 {
79+
return fmt.Errorf("pq: could not parse range: range is empty")
80+
}
81+
82+
if string(src) == RangeEmpty {
83+
return nil
84+
}
85+
86+
// read bounds
87+
r.LowerBound = RangeLowerBound(src[0])
88+
r.UpperBound = RangeUpperBound(src[len(src)-1])
89+
src = src[1 : len(src)-1]
90+
if len(src) == 0 {
91+
return fmt.Errorf("pq: could not parse range: range is empty")
92+
}
93+
94+
// read range
95+
l, u, ok := bytes.Cut(src, []byte(","))
96+
if !ok {
97+
return fmt.Errorf("pq: could not parse range: missing comma")
98+
}
99+
100+
convertBound := func(dest any, src []byte) error {
101+
src = bytes.Trim(src, "\"")
102+
switch d := dest.(type) {
103+
case sql.Scanner:
104+
if err := d.Scan(src); err != nil {
105+
return err
106+
}
107+
return nil
108+
case *time.Time:
109+
var err error
110+
*d, err = ParseTimestamp(nil, string(src))
111+
if err != nil {
112+
return err
113+
}
114+
return nil
115+
}
116+
if err := convertAssign(dest, string(src)); err != nil {
117+
return err
118+
}
119+
return nil
120+
}
121+
122+
if len(l) != 0 {
123+
r.Lower = new(T)
124+
if err := convertBound(r.Lower, l); err != nil {
125+
return err
126+
}
127+
}
128+
if len(u) != 0 {
129+
r.Upper = new(T)
130+
if err := convertBound(r.Upper, u); err != nil {
131+
return err
132+
}
133+
}
134+
135+
return nil
136+
}
137+
138+
// IsEmpty return true when bounds are inclusive and range value are equal
139+
func (r Range[T]) IsEmpty() bool {
140+
if r.LowerBound == 0 && r.UpperBound == 0 {
141+
return true
142+
}
143+
if r.Lower == nil || r.Upper == nil {
144+
return false
145+
}
146+
if r.LowerBound == RangeLowerBoundExclusive || r.UpperBound == RangeUpperBoundExclusive {
147+
return false
148+
}
149+
return reflect.DeepEqual(*r.Lower, *r.Upper)
150+
}
151+
152+
// IsZero return true when empty, used for IsZeroer interface
153+
func (r Range[T]) IsZero() bool {
154+
return r.IsEmpty()
155+
}
156+
157+
func (r Range[T]) Value() (driver.Value, error) {
158+
if r.IsEmpty() {
159+
return RangeEmpty, nil
160+
}
161+
162+
convertBound := func(src any) (string, error) {
163+
if reflect.ValueOf(src).IsNil() {
164+
return "", nil
165+
}
166+
167+
switch s := src.(type) {
168+
case *time.Time:
169+
return "\"" + string(FormatTimestamp(*s)) + "\"", nil
170+
case driver.Valuer:
171+
v, err := s.Value()
172+
if err != nil {
173+
return "", err
174+
}
175+
var out string
176+
if err := convertAssign(&out, v); err != nil {
177+
return "", err
178+
}
179+
return "\"" + out + "\"", nil
180+
default:
181+
var out string
182+
if err := convertAssign(&out, reflect.ValueOf(src).Elem().Interface()); err != nil {
183+
return "", err
184+
}
185+
return "\"" + out + "\"", nil
186+
}
187+
}
188+
189+
b, err := convertBound(r.Lower)
190+
if err != nil {
191+
return nil, err
192+
}
193+
194+
var v string
195+
v += string(r.LowerBound) + b + ","
196+
197+
b, err = convertBound(r.Upper)
198+
if err != nil {
199+
return nil, err
200+
}
201+
202+
v += b + string(r.UpperBound)
203+
204+
return v, nil
205+
}
206+
207+
//go:linkname convertAssign database/sql.convertAssign
208+
func convertAssign(dest, src any) error
209+
210+
// MultiRange returns the optimal driver.Valuer and sql.Scanner for a multirange.
211+
// Check https://www.postgresql.org/docs/current/rangetypes.html for details.
212+
// Scanning one-dimensional range is supported using [MultiRange]
213+
type MultiRange[T any] []Range[T]
214+
215+
var (
216+
_ sql.Scanner = (*MultiRange[any])(nil)
217+
_ driver.Valuer = (*MultiRange[any])(nil)
218+
)
219+
220+
func (m *MultiRange[T]) Scan(anySrc any) error {
221+
var src []byte
222+
switch s := anySrc.(type) {
223+
case string:
224+
src = []byte(s)
225+
case []byte:
226+
src = s
227+
default:
228+
return fmt.Errorf("pq: cannot convert %T to MultiRange", anySrc)
229+
}
230+
231+
src = bytes.TrimSpace(src)
232+
if len(src) == 0 {
233+
return fmt.Errorf("pq: could not parse multirange: multirange is empty")
234+
}
235+
236+
if src[0] != '{' || src[len(src)-1] != '}' {
237+
return fmt.Errorf("pq: invalid multirange format: missing braces")
238+
}
239+
src = src[1 : len(src)-1]
240+
241+
blockPos := 0
242+
boundDepth := 0
243+
inQuote := false
244+
isEscaping := false
245+
for i, c := range src {
246+
if isEscaping {
247+
isEscaping = false
248+
continue
249+
}
250+
251+
switch c {
252+
case '\\':
253+
isEscaping = true
254+
case '"':
255+
inQuote = !inQuote
256+
case byte(RangeLowerBoundInclusive), byte(RangeLowerBoundExclusive):
257+
if !inQuote {
258+
boundDepth++
259+
}
260+
case byte(RangeUpperBoundInclusive), byte(RangeUpperBoundExclusive):
261+
if !inQuote {
262+
boundDepth--
263+
}
264+
case ',':
265+
if !inQuote && boundDepth == 0 {
266+
var r Range[T]
267+
if err := r.Scan(bytes.TrimSpace(src[blockPos:i])); err != nil {
268+
return err
269+
}
270+
*m = append(*m, r)
271+
blockPos = i + 1
272+
}
273+
}
274+
}
275+
276+
// parse last range if any
277+
if blockPos < len(src) {
278+
rBytes := bytes.TrimSpace(src[blockPos:])
279+
var r Range[T]
280+
if err := r.Scan(rBytes); err != nil {
281+
return err
282+
}
283+
*m = append(*m, r)
284+
}
285+
286+
return nil
287+
}
288+
289+
func (m MultiRange[T]) Value() (driver.Value, error) {
290+
if len(m) == 0 {
291+
return "{}", nil
292+
}
293+
var b []byte
294+
var err error
295+
b = append(b, '{')
296+
297+
for _, r := range m {
298+
iv, err := driver.DefaultParameterConverter.ConvertValue(r)
299+
if err != nil {
300+
return nil, err
301+
}
302+
303+
b, err = appendValue(b, iv)
304+
if err != nil {
305+
return nil, err
306+
}
307+
b = append(b, ',')
308+
}
309+
b = append(b[:len(b)-1], '}')
310+
return b, err
311+
}

0 commit comments

Comments
 (0)