Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
package bleve

import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"regexp"
"slices"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -830,3 +833,22 @@ func ParseParams(r *SearchRequest, input []byte) (*RequestParams, error) {

return params, nil
}

// OptionalRawMessage is a wrapper around json.RawMessage that treats empty or `null` JSON as nil.
type OptionalRawMessage json.RawMessage

func (n *OptionalRawMessage) UnmarshalJSON(data []byte) error {
if len(data) == 0 || bytes.Equal(data, []byte("null")) {
*n = nil
return nil
}
*n = slices.Clone(data)
return nil
}

func (n OptionalRawMessage) MarshalJSON() ([]byte, error) {
if len(n) == 0 {
return []byte("null"), nil
}
return n, nil
}
76 changes: 38 additions & 38 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/blevesearch/bleve/v2/search"
"github.com/blevesearch/bleve/v2/search/collector"
"github.com/blevesearch/bleve/v2/search/query"
"github.com/blevesearch/bleve/v2/util"
index "github.com/blevesearch/bleve_index_api"
)

Expand All @@ -42,18 +43,18 @@ type SearchRequest struct {
Query query.Query `json:"query"`
Size int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Highlight *HighlightRequest `json:"highlight,omitempty"`
Fields []string `json:"fields,omitempty"`
Facets FacetsRequest `json:"facets,omitempty"`
Explain bool `json:"explain"`
Sort search.SortOrder `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score,omitempty"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
SearchAfter []string `json:"search_after,omitempty"`
SearchBefore []string `json:"search_before,omitempty"`

KNN []*KNNRequest `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`
KNN []*KNNRequest `json:"knn,omitempty"`
KNNOperator knnOperator `json:"knn_operator,omitempty"`
Comment thread
CascadingRadium marked this conversation as resolved.

// PreSearchData will be a map that will be used
// in the second phase of any 2-phase search, to provide additional
Expand Down Expand Up @@ -125,35 +126,35 @@ func (r *SearchRequest) AddKNNOperator(operator knnOperator) {
// a SearchRequest
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
type tempKNNReq struct {
Field string `json:"field"`
Vector []float32 `json:"vector"`
VectorBase64 string `json:"vector_base64"`
K int64 `json:"k"`
Boost *query.Boost `json:"boost,omitempty"`
Params json.RawMessage `json:"params"`
FilterQuery json.RawMessage `json:"filter,omitempty"`
Field string `json:"field"`
Vector []float32 `json:"vector"`
VectorBase64 string `json:"vector_base64"`
K int64 `json:"k"`
Boost *query.Boost `json:"boost,omitempty"`
Params OptionalRawMessage `json:"params"`
FilterQuery OptionalRawMessage `json:"filter,omitempty"`
}

var temp struct {
Q json.RawMessage `json:"query"`
Size *int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Explain bool `json:"explain"`
Sort []json.RawMessage `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
KNN []*tempKNNReq `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`
PreSearchData json.RawMessage `json:"pre_search_data"`
Params json.RawMessage `json:"params"`
}

err := json.Unmarshal(input, &temp)
Q json.RawMessage `json:"query"`
Size *int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Explain bool `json:"explain"`
Sort []json.RawMessage `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
KNN []*tempKNNReq `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`
PreSearchData OptionalRawMessage `json:"pre_search_data"`
Params OptionalRawMessage `json:"params"`
}

err := util.UnmarshalJSON(input, &temp)
if err != nil {
return err
}
Expand Down Expand Up @@ -216,11 +217,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
r.KNN[i].VectorBase64 = temp.KNN[i].VectorBase64
r.KNN[i].K = temp.KNN[i].K
r.KNN[i].Boost = temp.KNN[i].Boost
r.KNN[i].Params = temp.KNN[i].Params
if len(knnReq.FilterQuery) == 0 {
// Setting this to nil to avoid ParseQuery() setting it to a match none
r.KNN[i].FilterQuery = nil
} else {
if len(temp.KNN[i].Params) > 0 {
r.KNN[i].Params = json.RawMessage(temp.KNN[i].Params)
}
if len(temp.KNN[i].FilterQuery) > 0 {
r.KNN[i].FilterQuery, err = query.ParseQuery(knnReq.FilterQuery)
if err != nil {
return err
Expand Down
50 changes: 50 additions & 0 deletions search_knn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2936,3 +2936,53 @@ func TestHierarchicalNestedVectorSearch(t *testing.T) {
}
}
}

func TestKNNNullParams(t *testing.T) {
queries := []struct {
query []byte
}{
{query: []byte(`{"knn": [{"field": "emb", "vector": [1, 2], "k": 3}]}`)},
{query: []byte(`{"knn": [{"field": "emb", "params": null, "vector": [1, 2], "k": 3}]}`)},
{query: []byte(`{"knn": [{"field": "emb","vector": [1, 2], "k": 3, "filter": null}]}`)},
{query: []byte(`{"pre_search_data": null, "knn": [{"field": "emb", "vector": [1, 2], "k": 3}]}`)},
}
Comment thread
CascadingRadium marked this conversation as resolved.

for _, q := range queries {
var searchReq SearchRequest
err := json.Unmarshal(q.query, &searchReq)
if err != nil {
t.Fatalf("failed to parse query: %v", err)
}
if len(searchReq.PreSearchData) > 0 {
t.Fatalf("expected no pre_search_data for query: %s, got %v", q.query, searchReq.PreSearchData)
}
for _, req := range searchReq.KNN {
if len(req.Params) > 0 {
t.Fatalf("expected no params for query: %s, got %v", q.query, req.Params)
}
if req.FilterQuery != nil {
t.Fatalf("expected no filter for query: %s, got %v", q.query, req.FilterQuery)
}
}
marshalled, err := json.Marshal(searchReq)
if err != nil {
t.Fatalf("failed to marshal search request: %v", err)
}
var unmarshalled SearchRequest
err = json.Unmarshal(marshalled, &unmarshalled)
if err != nil {
t.Fatalf("failed to unmarshal marshalled search request: %v", err)
}
if len(unmarshalled.PreSearchData) > 0 {
t.Fatalf("expected no pre_search_data after marshal/unmarshal for query: %s, got %v", q.query, unmarshalled.PreSearchData)
}
for _, req := range unmarshalled.KNN {
if len(req.Params) > 0 {
t.Fatalf("expected no params after marshal/unmarshal for query: %s, got %v", q.query, req.Params)
}
if req.FilterQuery != nil {
t.Fatalf("expected no filter after marshal/unmarshal for query: %s, got %v", q.query, req.FilterQuery)
}
}
}
}
41 changes: 21 additions & 20 deletions search_no_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/blevesearch/bleve/v2/search"
"github.com/blevesearch/bleve/v2/search/collector"
"github.com/blevesearch/bleve/v2/search/query"
"github.com/blevesearch/bleve/v2/util"
index "github.com/blevesearch/bleve_index_api"
)

Expand Down Expand Up @@ -55,15 +56,15 @@ type SearchRequest struct {
Query query.Query `json:"query"`
Size int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Highlight *HighlightRequest `json:"highlight,omitempty"`
Fields []string `json:"fields,omitempty"`
Facets FacetsRequest `json:"facets,omitempty"`
Explain bool `json:"explain"`
Sort search.SortOrder `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score,omitempty"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
SearchAfter []string `json:"search_after,omitempty"`
SearchBefore []string `json:"search_before,omitempty"`
Comment thread
CascadingRadium marked this conversation as resolved.

// PreSearchData will be a map that will be used
// in the second phase of any 2-phase search, to provide additional
Expand All @@ -86,23 +87,23 @@ type SearchRequest struct {
// a SearchRequest
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
var temp struct {
Q json.RawMessage `json:"query"`
Size *int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Explain bool `json:"explain"`
Sort []json.RawMessage `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
PreSearchData json.RawMessage `json:"pre_search_data"`
Params json.RawMessage `json:"params"`
Q json.RawMessage `json:"query"`
Size *int `json:"size"`
From int `json:"from"`
Highlight *HighlightRequest `json:"highlight"`
Fields []string `json:"fields"`
Facets FacetsRequest `json:"facets"`
Explain bool `json:"explain"`
Sort []json.RawMessage `json:"sort"`
IncludeLocations bool `json:"includeLocations"`
Score string `json:"score"`
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
PreSearchData OptionalRawMessage `json:"pre_search_data"`
Params OptionalRawMessage `json:"params"`
}

err := json.Unmarshal(input, &temp)
err := util.UnmarshalJSON(input, &temp)
if err != nil {
return err
}
Expand Down
Loading