|
| 1 | +//go:build go1.18 |
| 2 | +// +build go1.18 |
| 3 | + |
| 4 | +package pq |
| 5 | + |
| 6 | +import ( |
| 7 | + "bytes" |
| 8 | + "database/sql" |
| 9 | + "database/sql/driver" |
| 10 | + "fmt" |
| 11 | + "reflect" |
| 12 | + "time" |
| 13 | + _ "unsafe" |
| 14 | +) |
| 15 | + |
| 16 | +// Range returns the optimal driver.Valuer and sql.Scanner for a range. |
| 17 | +// Check https://www.postgresql.org/docs/current/rangetypes.html for details. |
| 18 | +// |
| 19 | +// For example: |
| 20 | +// |
| 21 | +// min := time.Now() |
| 22 | +// max := time.Now().AddDate(0,-1,0) |
| 23 | +// db.Query(`SELECT * FROM reservation WHERE during && $1`, pq.NewRange(&min, &max)) |
| 24 | +// |
| 25 | +// var x pq.Range[int] |
| 26 | +// db.QueryRow(`SELECT '[1, 10)'`).Scan(&x) |
| 27 | +// |
| 28 | +// Scanning multi-dimensional range is supported using [MultiRange] |
| 29 | +type Range[T any] struct { |
| 30 | + // When nil it will be infinite value |
| 31 | + Lower, Upper *T |
| 32 | + LowerBound RangeLowerBound |
| 33 | + UpperBound RangeUpperBound |
| 34 | +} |
| 35 | + |
| 36 | +type RangeLowerBound byte |
| 37 | +type RangeUpperBound byte |
| 38 | + |
| 39 | +const ( |
| 40 | + RangeLowerBoundInclusive RangeLowerBound = '[' |
| 41 | + RangeLowerBoundExclusive RangeLowerBound = '(' |
| 42 | + RangeUpperBoundInclusive RangeUpperBound = ']' |
| 43 | + RangeUpperBoundExclusive RangeUpperBound = ')' |
| 44 | + |
| 45 | + RangeLowerBoundDefault = RangeLowerBoundInclusive |
| 46 | + RangeUpperBoundDefault = RangeUpperBoundExclusive |
| 47 | + |
| 48 | + RangeEmpty = "empty" |
| 49 | +) |
| 50 | + |
| 51 | +// NewRange create a [Range] with default bounds [RangeLowerBoundDefault] and [RangeUpperBoundDefault]. |
| 52 | +func NewRange[T any](lower, upper *T) Range[T] { |
| 53 | + return Range[T]{ |
| 54 | + Lower: lower, |
| 55 | + Upper: upper, |
| 56 | + LowerBound: RangeLowerBoundDefault, |
| 57 | + UpperBound: RangeUpperBoundDefault, |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +var ( |
| 62 | + _ sql.Scanner = (*Range[any])(nil) |
| 63 | + _ driver.Valuer = (*Range[any])(nil) |
| 64 | +) |
| 65 | + |
| 66 | +func (r *Range[T]) Scan(anySrc any) error { |
| 67 | + var src []byte |
| 68 | + switch s := anySrc.(type) { |
| 69 | + case string: |
| 70 | + src = []byte(s) |
| 71 | + case []byte: |
| 72 | + src = s |
| 73 | + default: |
| 74 | + return fmt.Errorf("pq: cannot convert %T to Range", anySrc) |
| 75 | + } |
| 76 | + |
| 77 | + src = bytes.TrimSpace(src) |
| 78 | + if len(src) == 0 { |
| 79 | + return fmt.Errorf("pq: could not parse range: range is empty") |
| 80 | + } |
| 81 | + |
| 82 | + if string(src) == RangeEmpty { |
| 83 | + return nil |
| 84 | + } |
| 85 | + |
| 86 | + // read bounds |
| 87 | + r.LowerBound = RangeLowerBound(src[0]) |
| 88 | + r.UpperBound = RangeUpperBound(src[len(src)-1]) |
| 89 | + src = src[1 : len(src)-1] |
| 90 | + if len(src) == 0 { |
| 91 | + return fmt.Errorf("pq: could not parse range: range is empty") |
| 92 | + } |
| 93 | + |
| 94 | + // read range |
| 95 | + l, u, ok := bytes.Cut(src, []byte(",")) |
| 96 | + if !ok { |
| 97 | + return fmt.Errorf("pq: could not parse range: missing comma") |
| 98 | + } |
| 99 | + |
| 100 | + convertBound := func(dest any, src []byte) error { |
| 101 | + src = bytes.Trim(src, "\"") |
| 102 | + switch d := dest.(type) { |
| 103 | + case sql.Scanner: |
| 104 | + if err := d.Scan(src); err != nil { |
| 105 | + return err |
| 106 | + } |
| 107 | + return nil |
| 108 | + case *time.Time: |
| 109 | + var err error |
| 110 | + *d, err = ParseTimestamp(nil, string(src)) |
| 111 | + if err != nil { |
| 112 | + return err |
| 113 | + } |
| 114 | + return nil |
| 115 | + } |
| 116 | + if err := convertAssign(dest, string(src)); err != nil { |
| 117 | + return err |
| 118 | + } |
| 119 | + return nil |
| 120 | + } |
| 121 | + |
| 122 | + if len(l) != 0 { |
| 123 | + r.Lower = new(T) |
| 124 | + if err := convertBound(r.Lower, l); err != nil { |
| 125 | + return err |
| 126 | + } |
| 127 | + } |
| 128 | + if len(u) != 0 { |
| 129 | + r.Upper = new(T) |
| 130 | + if err := convertBound(r.Upper, u); err != nil { |
| 131 | + return err |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + return nil |
| 136 | +} |
| 137 | + |
| 138 | +// IsEmpty return true when bounds are inclusive and range value are equal |
| 139 | +func (r Range[T]) IsEmpty() bool { |
| 140 | + if r.LowerBound == 0 && r.UpperBound == 0 { |
| 141 | + return true |
| 142 | + } |
| 143 | + if r.Lower == nil || r.Upper == nil { |
| 144 | + return false |
| 145 | + } |
| 146 | + if r.LowerBound == RangeLowerBoundExclusive || r.UpperBound == RangeUpperBoundExclusive { |
| 147 | + return false |
| 148 | + } |
| 149 | + return reflect.DeepEqual(*r.Lower, *r.Upper) |
| 150 | +} |
| 151 | + |
| 152 | +// IsZero return true when empty, used for IsZeroer interface |
| 153 | +func (r Range[T]) IsZero() bool { |
| 154 | + return r.IsEmpty() |
| 155 | +} |
| 156 | + |
| 157 | +func (r Range[T]) Value() (driver.Value, error) { |
| 158 | + if r.IsEmpty() { |
| 159 | + return RangeEmpty, nil |
| 160 | + } |
| 161 | + |
| 162 | + convertBound := func(src any) (string, error) { |
| 163 | + if reflect.ValueOf(src).IsNil() { |
| 164 | + return "", nil |
| 165 | + } |
| 166 | + |
| 167 | + switch s := src.(type) { |
| 168 | + case *time.Time: |
| 169 | + return "\"" + string(FormatTimestamp(*s)) + "\"", nil |
| 170 | + case driver.Valuer: |
| 171 | + v, err := s.Value() |
| 172 | + if err != nil { |
| 173 | + return "", err |
| 174 | + } |
| 175 | + var out string |
| 176 | + if err := convertAssign(&out, v); err != nil { |
| 177 | + return "", err |
| 178 | + } |
| 179 | + return "\"" + out + "\"", nil |
| 180 | + default: |
| 181 | + var out string |
| 182 | + if err := convertAssign(&out, reflect.ValueOf(src).Elem().Interface()); err != nil { |
| 183 | + return "", err |
| 184 | + } |
| 185 | + return "\"" + out + "\"", nil |
| 186 | + } |
| 187 | + } |
| 188 | + |
| 189 | + b, err := convertBound(r.Lower) |
| 190 | + if err != nil { |
| 191 | + return nil, err |
| 192 | + } |
| 193 | + |
| 194 | + var v string |
| 195 | + v += string(r.LowerBound) + b + "," |
| 196 | + |
| 197 | + b, err = convertBound(r.Upper) |
| 198 | + if err != nil { |
| 199 | + return nil, err |
| 200 | + } |
| 201 | + |
| 202 | + v += b + string(r.UpperBound) |
| 203 | + |
| 204 | + return v, nil |
| 205 | +} |
| 206 | + |
| 207 | +//go:linkname convertAssign database/sql.convertAssign |
| 208 | +func convertAssign(dest, src any) error |
| 209 | + |
| 210 | +// MultiRange returns the optimal driver.Valuer and sql.Scanner for a multirange. |
| 211 | +// Check https://www.postgresql.org/docs/current/rangetypes.html for details. |
| 212 | +// Scanning one-dimensional range is supported using [MultiRange] |
| 213 | +type MultiRange[T any] []Range[T] |
| 214 | + |
| 215 | +var ( |
| 216 | + _ sql.Scanner = (*MultiRange[any])(nil) |
| 217 | + _ driver.Valuer = (*MultiRange[any])(nil) |
| 218 | +) |
| 219 | + |
| 220 | +func (m *MultiRange[T]) Scan(anySrc any) error { |
| 221 | + var src []byte |
| 222 | + switch s := anySrc.(type) { |
| 223 | + case string: |
| 224 | + src = []byte(s) |
| 225 | + case []byte: |
| 226 | + src = s |
| 227 | + default: |
| 228 | + return fmt.Errorf("pq: cannot convert %T to MultiRange", anySrc) |
| 229 | + } |
| 230 | + |
| 231 | + src = bytes.TrimSpace(src) |
| 232 | + if len(src) == 0 { |
| 233 | + return fmt.Errorf("pq: could not parse multirange: multirange is empty") |
| 234 | + } |
| 235 | + |
| 236 | + if src[0] != '{' || src[len(src)-1] != '}' { |
| 237 | + return fmt.Errorf("pq: invalid multirange format: missing braces") |
| 238 | + } |
| 239 | + src = src[1 : len(src)-1] |
| 240 | + |
| 241 | + blockPos := 0 |
| 242 | + boundDepth := 0 |
| 243 | + inQuote := false |
| 244 | + isEscaping := false |
| 245 | + for i, c := range src { |
| 246 | + if isEscaping { |
| 247 | + isEscaping = false |
| 248 | + continue |
| 249 | + } |
| 250 | + |
| 251 | + switch c { |
| 252 | + case '\\': |
| 253 | + isEscaping = true |
| 254 | + case '"': |
| 255 | + inQuote = !inQuote |
| 256 | + case byte(RangeLowerBoundInclusive), byte(RangeLowerBoundExclusive): |
| 257 | + if !inQuote { |
| 258 | + boundDepth++ |
| 259 | + } |
| 260 | + case byte(RangeUpperBoundInclusive), byte(RangeUpperBoundExclusive): |
| 261 | + if !inQuote { |
| 262 | + boundDepth-- |
| 263 | + } |
| 264 | + case ',': |
| 265 | + if !inQuote && boundDepth == 0 { |
| 266 | + var r Range[T] |
| 267 | + if err := r.Scan(bytes.TrimSpace(src[blockPos:i])); err != nil { |
| 268 | + return err |
| 269 | + } |
| 270 | + *m = append(*m, r) |
| 271 | + blockPos = i + 1 |
| 272 | + } |
| 273 | + } |
| 274 | + } |
| 275 | + |
| 276 | + // parse last range if any |
| 277 | + if blockPos < len(src) { |
| 278 | + rBytes := bytes.TrimSpace(src[blockPos:]) |
| 279 | + var r Range[T] |
| 280 | + if err := r.Scan(rBytes); err != nil { |
| 281 | + return err |
| 282 | + } |
| 283 | + *m = append(*m, r) |
| 284 | + } |
| 285 | + |
| 286 | + return nil |
| 287 | +} |
| 288 | + |
| 289 | +func (m MultiRange[T]) Value() (driver.Value, error) { |
| 290 | + if len(m) == 0 { |
| 291 | + return "{}", nil |
| 292 | + } |
| 293 | + var b []byte |
| 294 | + var err error |
| 295 | + b = append(b, '{') |
| 296 | + |
| 297 | + for _, r := range m { |
| 298 | + iv, err := driver.DefaultParameterConverter.ConvertValue(r) |
| 299 | + if err != nil { |
| 300 | + return nil, err |
| 301 | + } |
| 302 | + |
| 303 | + b, err = appendValue(b, iv) |
| 304 | + if err != nil { |
| 305 | + return nil, err |
| 306 | + } |
| 307 | + b = append(b, ',') |
| 308 | + } |
| 309 | + b = append(b[:len(b)-1], '}') |
| 310 | + return b, err |
| 311 | +} |
0 commit comments