Skip to content

Commit 4051279

Browse files
Fix null issue when parsing search request attributes (#2312)
- We allow a few attributes of the search request to be parsed as `json.RawMessage`, namely:
  - `KNN Params`
  - `KNN Filter`
  - `PresearchData`
- The JSON body may contain the keyword `null` for these attributes, especially after a marshal-unmarshal cycle.
- This `null` value is passed downstream as the literal four bytes `null`, so every `len(attribute) > 0` check succeeds and the code takes the path meant for a valid attribute even though the attribute is effectively `nil`.
- Fix this by adding a custom type that handles `null` specially and sets the attribute to `nil` when required.
---------
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 0afa561 commit 4051279

4 files changed

Lines changed: 131 additions & 58 deletions

File tree

search.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
package bleve
1616

1717
import (
18+
"bytes"
19+
"encoding/json"
1820
"fmt"
1921
"reflect"
2022
"regexp"
23+
"slices"
2124
"sort"
2225
"strconv"
2326
"strings"
@@ -830,3 +833,22 @@ func ParseParams(r *SearchRequest, input []byte) (*RequestParams, error) {
830833

831834
return params, nil
832835
}
836+
837+
// OptionalRawMessage is a wrapper around json.RawMessage that treats empty or
// `null` JSON as nil.
//
// A plain json.RawMessage keeps the literal bytes `null`, so a downstream
// `len(x) > 0` check cannot tell an absent attribute from a present one.
// Decoding into OptionalRawMessage leaves the value nil in that case.
type OptionalRawMessage json.RawMessage

// UnmarshalJSON implements json.Unmarshaler.
//
// Input that is empty, whitespace-only, or the JSON literal `null` (with or
// without surrounding whitespace) sets the receiver to nil. Any other input is
// copied verbatim so the receiver never aliases the caller's buffer.
// An error is returned when called on a nil pointer, mirroring
// json.RawMessage, instead of panicking on the dereference.
func (n *OptionalRawMessage) UnmarshalJSON(data []byte) error {
	if n == nil {
		return fmt.Errorf("bleve.OptionalRawMessage: UnmarshalJSON on nil pointer")
	}
	// encoding/json hands over the bare literal, but direct callers may pass
	// padded input; only the null/empty decision looks at the trimmed form.
	trimmed := bytes.TrimSpace(data)
	if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
		*n = nil
		return nil
	}
	*n = slices.Clone(data)
	return nil
}

// MarshalJSON implements json.Marshaler.
//
// A nil or empty value is encoded as the JSON literal `null`; otherwise the
// stored bytes are returned as-is.
func (n OptionalRawMessage) MarshalJSON() ([]byte, error) {
	if len(n) == 0 {
		return []byte("null"), nil
	}
	return n, nil
}

search_knn.go

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/blevesearch/bleve/v2/search"
2828
"github.com/blevesearch/bleve/v2/search/collector"
2929
"github.com/blevesearch/bleve/v2/search/query"
30+
"github.com/blevesearch/bleve/v2/util"
3031
index "github.com/blevesearch/bleve_index_api"
3132
)
3233

@@ -42,18 +43,18 @@ type SearchRequest struct {
4243
Query query.Query `json:"query"`
4344
Size int `json:"size"`
4445
From int `json:"from"`
45-
Highlight *HighlightRequest `json:"highlight"`
46-
Fields []string `json:"fields"`
47-
Facets FacetsRequest `json:"facets"`
46+
Highlight *HighlightRequest `json:"highlight,omitempty"`
47+
Fields []string `json:"fields,omitempty"`
48+
Facets FacetsRequest `json:"facets,omitempty"`
4849
Explain bool `json:"explain"`
4950
Sort search.SortOrder `json:"sort"`
5051
IncludeLocations bool `json:"includeLocations"`
5152
Score string `json:"score,omitempty"`
52-
SearchAfter []string `json:"search_after"`
53-
SearchBefore []string `json:"search_before"`
53+
SearchAfter []string `json:"search_after,omitempty"`
54+
SearchBefore []string `json:"search_before,omitempty"`
5455

55-
KNN []*KNNRequest `json:"knn"`
56-
KNNOperator knnOperator `json:"knn_operator"`
56+
KNN []*KNNRequest `json:"knn,omitempty"`
57+
KNNOperator knnOperator `json:"knn_operator,omitempty"`
5758

5859
// PreSearchData will be a map that will be used
5960
// in the second phase of any 2-phase search, to provide additional
@@ -125,35 +126,35 @@ func (r *SearchRequest) AddKNNOperator(operator knnOperator) {
125126
// a SearchRequest
126127
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
127128
type tempKNNReq struct {
128-
Field string `json:"field"`
129-
Vector []float32 `json:"vector"`
130-
VectorBase64 string `json:"vector_base64"`
131-
K int64 `json:"k"`
132-
Boost *query.Boost `json:"boost,omitempty"`
133-
Params json.RawMessage `json:"params"`
134-
FilterQuery json.RawMessage `json:"filter,omitempty"`
129+
Field string `json:"field"`
130+
Vector []float32 `json:"vector"`
131+
VectorBase64 string `json:"vector_base64"`
132+
K int64 `json:"k"`
133+
Boost *query.Boost `json:"boost,omitempty"`
134+
Params OptionalRawMessage `json:"params"`
135+
FilterQuery OptionalRawMessage `json:"filter,omitempty"`
135136
}
136137

137138
var temp struct {
138-
Q json.RawMessage `json:"query"`
139-
Size *int `json:"size"`
140-
From int `json:"from"`
141-
Highlight *HighlightRequest `json:"highlight"`
142-
Fields []string `json:"fields"`
143-
Facets FacetsRequest `json:"facets"`
144-
Explain bool `json:"explain"`
145-
Sort []json.RawMessage `json:"sort"`
146-
IncludeLocations bool `json:"includeLocations"`
147-
Score string `json:"score"`
148-
SearchAfter []string `json:"search_after"`
149-
SearchBefore []string `json:"search_before"`
150-
KNN []*tempKNNReq `json:"knn"`
151-
KNNOperator knnOperator `json:"knn_operator"`
152-
PreSearchData json.RawMessage `json:"pre_search_data"`
153-
Params json.RawMessage `json:"params"`
154-
}
155-
156-
err := json.Unmarshal(input, &temp)
139+
Q json.RawMessage `json:"query"`
140+
Size *int `json:"size"`
141+
From int `json:"from"`
142+
Highlight *HighlightRequest `json:"highlight"`
143+
Fields []string `json:"fields"`
144+
Facets FacetsRequest `json:"facets"`
145+
Explain bool `json:"explain"`
146+
Sort []json.RawMessage `json:"sort"`
147+
IncludeLocations bool `json:"includeLocations"`
148+
Score string `json:"score"`
149+
SearchAfter []string `json:"search_after"`
150+
SearchBefore []string `json:"search_before"`
151+
KNN []*tempKNNReq `json:"knn"`
152+
KNNOperator knnOperator `json:"knn_operator"`
153+
PreSearchData OptionalRawMessage `json:"pre_search_data"`
154+
Params OptionalRawMessage `json:"params"`
155+
}
156+
157+
err := util.UnmarshalJSON(input, &temp)
157158
if err != nil {
158159
return err
159160
}
@@ -216,11 +217,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
216217
r.KNN[i].VectorBase64 = temp.KNN[i].VectorBase64
217218
r.KNN[i].K = temp.KNN[i].K
218219
r.KNN[i].Boost = temp.KNN[i].Boost
219-
r.KNN[i].Params = temp.KNN[i].Params
220-
if len(knnReq.FilterQuery) == 0 {
221-
// Setting this to nil to avoid ParseQuery() setting it to a match none
222-
r.KNN[i].FilterQuery = nil
223-
} else {
220+
if len(temp.KNN[i].Params) > 0 {
221+
r.KNN[i].Params = json.RawMessage(temp.KNN[i].Params)
222+
}
223+
if len(temp.KNN[i].FilterQuery) > 0 {
224224
r.KNN[i].FilterQuery, err = query.ParseQuery(knnReq.FilterQuery)
225225
if err != nil {
226226
return err

search_knn_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2936,3 +2936,53 @@ func TestHierarchicalNestedVectorSearch(t *testing.T) {
29362936
}
29372937
}
29382938
}
2939+
2940+
func TestKNNNullParams(t *testing.T) {
2941+
queries := []struct {
2942+
query []byte
2943+
}{
2944+
{query: []byte(`{"knn": [{"field": "emb", "vector": [1, 2], "k": 3}]}`)},
2945+
{query: []byte(`{"knn": [{"field": "emb", "params": null, "vector": [1, 2], "k": 3}]}`)},
2946+
{query: []byte(`{"knn": [{"field": "emb","vector": [1, 2], "k": 3, "filter": null}]}`)},
2947+
{query: []byte(`{"pre_search_data": null, "knn": [{"field": "emb", "vector": [1, 2], "k": 3}]}`)},
2948+
}
2949+
2950+
for _, q := range queries {
2951+
var searchReq SearchRequest
2952+
err := json.Unmarshal(q.query, &searchReq)
2953+
if err != nil {
2954+
t.Fatalf("failed to parse query: %v", err)
2955+
}
2956+
if len(searchReq.PreSearchData) > 0 {
2957+
t.Fatalf("expected no pre_search_data for query: %s, got %v", q.query, searchReq.PreSearchData)
2958+
}
2959+
for _, req := range searchReq.KNN {
2960+
if len(req.Params) > 0 {
2961+
t.Fatalf("expected no params for query: %s, got %v", q.query, req.Params)
2962+
}
2963+
if req.FilterQuery != nil {
2964+
t.Fatalf("expected no filter for query: %s, got %v", q.query, req.FilterQuery)
2965+
}
2966+
}
2967+
marshalled, err := json.Marshal(searchReq)
2968+
if err != nil {
2969+
t.Fatalf("failed to marshal search request: %v", err)
2970+
}
2971+
var unmarshalled SearchRequest
2972+
err = json.Unmarshal(marshalled, &unmarshalled)
2973+
if err != nil {
2974+
t.Fatalf("failed to unmarshal marshalled search request: %v", err)
2975+
}
2976+
if len(unmarshalled.PreSearchData) > 0 {
2977+
t.Fatalf("expected no pre_search_data after marshal/unmarshal for query: %s, got %v", q.query, unmarshalled.PreSearchData)
2978+
}
2979+
for _, req := range unmarshalled.KNN {
2980+
if len(req.Params) > 0 {
2981+
t.Fatalf("expected no params after marshal/unmarshal for query: %s, got %v", q.query, req.Params)
2982+
}
2983+
if req.FilterQuery != nil {
2984+
t.Fatalf("expected no filter after marshal/unmarshal for query: %s, got %v", q.query, req.FilterQuery)
2985+
}
2986+
}
2987+
}
2988+
}

search_no_knn.go

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/blevesearch/bleve/v2/search"
2626
"github.com/blevesearch/bleve/v2/search/collector"
2727
"github.com/blevesearch/bleve/v2/search/query"
28+
"github.com/blevesearch/bleve/v2/util"
2829
index "github.com/blevesearch/bleve_index_api"
2930
)
3031

@@ -55,15 +56,15 @@ type SearchRequest struct {
5556
Query query.Query `json:"query"`
5657
Size int `json:"size"`
5758
From int `json:"from"`
58-
Highlight *HighlightRequest `json:"highlight"`
59-
Fields []string `json:"fields"`
60-
Facets FacetsRequest `json:"facets"`
59+
Highlight *HighlightRequest `json:"highlight,omitempty"`
60+
Fields []string `json:"fields,omitempty"`
61+
Facets FacetsRequest `json:"facets,omitempty"`
6162
Explain bool `json:"explain"`
6263
Sort search.SortOrder `json:"sort"`
6364
IncludeLocations bool `json:"includeLocations"`
6465
Score string `json:"score,omitempty"`
65-
SearchAfter []string `json:"search_after"`
66-
SearchBefore []string `json:"search_before"`
66+
SearchAfter []string `json:"search_after,omitempty"`
67+
SearchBefore []string `json:"search_before,omitempty"`
6768

6869
// PreSearchData will be a map that will be used
6970
// in the second phase of any 2-phase search, to provide additional
@@ -86,23 +87,23 @@ type SearchRequest struct {
8687
// a SearchRequest
8788
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
8889
var temp struct {
89-
Q json.RawMessage `json:"query"`
90-
Size *int `json:"size"`
91-
From int `json:"from"`
92-
Highlight *HighlightRequest `json:"highlight"`
93-
Fields []string `json:"fields"`
94-
Facets FacetsRequest `json:"facets"`
95-
Explain bool `json:"explain"`
96-
Sort []json.RawMessage `json:"sort"`
97-
IncludeLocations bool `json:"includeLocations"`
98-
Score string `json:"score"`
99-
SearchAfter []string `json:"search_after"`
100-
SearchBefore []string `json:"search_before"`
101-
PreSearchData json.RawMessage `json:"pre_search_data"`
102-
Params json.RawMessage `json:"params"`
90+
Q json.RawMessage `json:"query"`
91+
Size *int `json:"size"`
92+
From int `json:"from"`
93+
Highlight *HighlightRequest `json:"highlight"`
94+
Fields []string `json:"fields"`
95+
Facets FacetsRequest `json:"facets"`
96+
Explain bool `json:"explain"`
97+
Sort []json.RawMessage `json:"sort"`
98+
IncludeLocations bool `json:"includeLocations"`
99+
Score string `json:"score"`
100+
SearchAfter []string `json:"search_after"`
101+
SearchBefore []string `json:"search_before"`
102+
PreSearchData OptionalRawMessage `json:"pre_search_data"`
103+
Params OptionalRawMessage `json:"params"`
103104
}
104105

105-
err := json.Unmarshal(input, &temp)
106+
err := util.UnmarshalJSON(input, &temp)
106107
if err != nil {
107108
return err
108109
}

0 commit comments

Comments
 (0)