Skip to content

Commit 4d4984d

Browse files
add KNN unit test
1 parent d4fdbbc commit 4d4984d

1 file changed

Lines changed: 244 additions & 0 deletions

File tree

search_knn_test.go

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"testing"
3434
"time"
3535

36+
"github.com/blevesearch/bleve/v2/analysis/analyzer/keyword"
3637
"github.com/blevesearch/bleve/v2/analysis/lang/en"
3738
"github.com/blevesearch/bleve/v2/index/scorch"
3839
"github.com/blevesearch/bleve/v2/mapping"
@@ -2416,3 +2417,246 @@ func TestIndexInsightsCentroidCardinalities(t *testing.T) {
24162417
}
24172418
}
24182419
}
2420+
2421+
func TestHierarchicalNestedVectorSearch(t *testing.T) {
2422+
tmpIndexPath := createTmpIndexPath(t)
2423+
defer cleanupTmpIndexPath(t, tmpIndexPath)
2424+
2425+
dataset := `
2426+
[
2427+
{
2428+
"id": "doc1",
2429+
"items": [
2430+
{
2431+
"description": "I like trains",
2432+
"embedding_vector": [
2433+
1,
2434+
0,
2435+
0
2436+
],
2437+
"type": "transport"
2438+
},
2439+
{
2440+
"description": "I love pizza",
2441+
"embedding_vector": [
2442+
0,
2443+
1,
2444+
0
2445+
],
2446+
"type": "food"
2447+
}
2448+
]
2449+
},
2450+
{
2451+
"id": "doc2",
2452+
"items": [
2453+
{
2454+
"description": "I go to school by bus",
2455+
"embedding_vector": [
2456+
0.9,
2457+
0.1,
2458+
0
2459+
],
2460+
"type": "transport"
2461+
},
2462+
{
2463+
"description": "Sushi is delicious",
2464+
"embedding_vector": [
2465+
0,
2466+
1,
2467+
0
2468+
],
2469+
"type": "food"
2470+
}
2471+
]
2472+
},
2473+
{
2474+
"id": "doc3",
2475+
"items": [
2476+
{
2477+
"description": "Hamburgers are tasty",
2478+
"embedding_vector": [
2479+
0,
2480+
0.8,
2481+
0.2
2482+
],
2483+
"type": "food"
2484+
},
2485+
{
2486+
"description": "I enjoy biking",
2487+
"embedding_vector": [
2488+
0.7,
2489+
0,
2490+
0.3
2491+
],
2492+
"type": "transport"
2493+
}
2494+
]
2495+
}
2496+
]`
2497+
var documents []map[string]interface{}
2498+
err := json.Unmarshal([]byte(dataset), &documents)
2499+
if err != nil {
2500+
t.Fatalf("failed to unmarshal dataset: %v", err)
2501+
}
2502+
indexMapping := NewIndexMapping()
2503+
vecFieldMapping := mapping.NewVectorFieldMapping()
2504+
vecFieldMapping.Dims = 3
2505+
vecFieldMapping.Similarity = index.CosineSimilarity
2506+
2507+
typeMapping := mapping.NewTextFieldMapping()
2508+
typeMapping.Analyzer = keyword.Name
2509+
2510+
descMapping := mapping.NewTextFieldMapping()
2511+
descMapping.Analyzer = en.AnalyzerName
2512+
2513+
// items is NOT nested
2514+
itemsMapping := mapping.NewDocumentMapping()
2515+
itemsMapping.AddFieldMappingsAt("embedding_vector", vecFieldMapping)
2516+
itemsMapping.AddFieldMappingsAt("type", typeMapping)
2517+
itemsMapping.AddFieldMappingsAt("description", descMapping)
2518+
2519+
indexMapping.DefaultMapping.AddSubDocumentMapping("items", itemsMapping)
2520+
idx, err := New(tmpIndexPath, indexMapping)
2521+
if err != nil {
2522+
t.Fatalf("failed to create index: %v", err)
2523+
}
2524+
defer func() {
2525+
if err := idx.Close(); err != nil {
2526+
t.Fatalf("failed to close index: %v", err)
2527+
}
2528+
}()
2529+
2530+
batch := idx.NewBatch()
2531+
for _, doc := range documents {
2532+
err := batch.Index(doc["id"].(string), doc)
2533+
if err != nil {
2534+
t.Fatalf("failed to index document %s: %v", doc["id"], err)
2535+
}
2536+
}
2537+
err = idx.Batch(batch)
2538+
if err != nil {
2539+
t.Fatalf("failed to batch index documents: %v", err)
2540+
}
2541+
2542+
// Plain vector search
2543+
searchReq := NewSearchRequest(query.NewMatchNoneQuery())
2544+
searchReq.AddKNN("items.embedding_vector", []float32{0, 1, 0}, 5, 1.0)
2545+
searchReq.SortBy([]string{"-_score", "_id"})
2546+
2547+
res, err := idx.Search(searchReq)
2548+
if err != nil {
2549+
t.Fatalf("failed to execute search: %v", err)
2550+
}
2551+
2552+
expectedOrder := []string{"doc1", "doc2", "doc3"}
2553+
expectedScores := []float64{1.0, 1.0, 0.970}
2554+
if len(res.Hits) != len(expectedOrder) {
2555+
t.Fatalf("expected %d hits, got %d", len(expectedOrder), len(res.Hits))
2556+
}
2557+
for i, expectedID := range expectedOrder {
2558+
if res.Hits[i].ID != expectedID {
2559+
t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expectedID, res.Hits[i].ID)
2560+
}
2561+
if math.Abs(res.Hits[i].Score-expectedScores[i]) > 0.01 {
2562+
t.Fatalf("at rank %d, expected score %.3f, got %.3f", i+1, expectedScores[i], res.Hits[i].Score)
2563+
}
2564+
}
2565+
2566+
// Filtered vector search - should match output of plain vector search in non-nested case
2567+
filterQuery := NewTermQuery("transport")
2568+
filterQuery.SetField("items.type")
2569+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
2570+
searchReq.AddKNNWithFilter("items.embedding_vector", []float32{0, 1, 0}, 5, 1.0, filterQuery)
2571+
searchReq.SortBy([]string{"-_score", "_id"})
2572+
res, err = idx.Search(searchReq)
2573+
if err != nil {
2574+
t.Fatalf("failed to execute filtered search: %v", err)
2575+
}
2576+
if len(res.Hits) != len(expectedOrder) {
2577+
t.Fatalf("expected %d hits, got %d", len(expectedOrder), len(res.Hits))
2578+
}
2579+
for i, expectedID := range expectedOrder {
2580+
if res.Hits[i].ID != expectedID {
2581+
t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expectedID, res.Hits[i].ID)
2582+
}
2583+
if math.Abs(res.Hits[i].Score-expectedScores[i]) > 0.01 {
2584+
t.Fatalf("at rank %d, expected score %.3f, got %.3f", i+1, expectedScores[i], res.Hits[i].Score)
2585+
}
2586+
}
2587+
2588+
// items IS nested
2589+
nestedItemsMapping := mapping.NewNestedDocumentMapping()
2590+
nestedItemsMapping.AddFieldMappingsAt("embedding_vector", vecFieldMapping)
2591+
nestedItemsMapping.AddFieldMappingsAt("type", typeMapping)
2592+
nestedItemsMapping.AddFieldMappingsAt("description", descMapping)
2593+
2594+
indexMappingNested := NewIndexMapping()
2595+
indexMappingNested.DefaultMapping.AddSubDocumentMapping("items", nestedItemsMapping)
2596+
idxNested, err := New(tmpIndexPath+"_nested", indexMappingNested)
2597+
if err != nil {
2598+
t.Fatalf("failed to create nested index: %v", err)
2599+
}
2600+
defer func() {
2601+
if err := idxNested.Close(); err != nil {
2602+
t.Fatalf("failed to close nested index: %v", err)
2603+
}
2604+
}()
2605+
2606+
batch = idxNested.NewBatch()
2607+
for _, doc := range documents {
2608+
err := batch.Index(doc["id"].(string), doc)
2609+
if err != nil {
2610+
t.Fatalf("failed to index document %s in nested index: %v", doc["id"], err)
2611+
}
2612+
}
2613+
err = idxNested.Batch(batch)
2614+
if err != nil {
2615+
t.Fatalf("failed to batch index documents in nested index: %v", err)
2616+
}
2617+
// Plain vector search on nested index
2618+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
2619+
searchReq.AddKNN("items.embedding_vector", []float32{0, 1, 0}, 5, 1.0)
2620+
searchReq.SortBy([]string{"-_score", "_id"})
2621+
2622+
res, err = idxNested.Search(searchReq)
2623+
if err != nil {
2624+
t.Fatalf("failed to execute search on nested index: %v", err)
2625+
}
2626+
// Exact same behavior as non-nested in this case
2627+
if len(res.Hits) != len(expectedOrder) {
2628+
t.Fatalf("expected %d hits, got %d", len(expectedOrder), len(res.Hits))
2629+
}
2630+
for i, expectedID := range expectedOrder {
2631+
if res.Hits[i].ID != expectedID {
2632+
t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expectedID, res.Hits[i].ID)
2633+
}
2634+
if math.Abs(res.Hits[i].Score-expectedScores[i]) > 0.01 {
2635+
t.Fatalf("at rank %d, expected score %.3f, got %.3f", i+1, expectedScores[i], res.Hits[i].Score)
2636+
}
2637+
}
2638+
2639+
// Filtered vector search on nested index - should NOT match output of plain vector search in nested case
2640+
filterQuery = NewTermQuery("transport")
2641+
filterQuery.SetField("items.type")
2642+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
2643+
searchReq.AddKNNWithFilter("items.embedding_vector", []float32{0, 1, 0}, 5, 1.0, filterQuery)
2644+
searchReq.SortBy([]string{"-_score", "_id"})
2645+
res, err = idxNested.Search(searchReq)
2646+
if err != nil {
2647+
t.Fatalf("failed to execute filtered search on nested index: %v", err)
2648+
}
2649+
expectedNestedOrder := []string{"doc2", "doc1", "doc3"}
2650+
expectedNestedScores := []float64{0.110, 0, 0}
2651+
if len(res.Hits) != len(expectedNestedOrder) {
2652+
t.Fatalf("expected %d hits, got %d", len(expectedNestedOrder), len(res.Hits))
2653+
}
2654+
for i, expectedID := range expectedNestedOrder {
2655+
if res.Hits[i].ID != expectedID {
2656+
t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expectedID, res.Hits[i].ID)
2657+
}
2658+
if math.Abs(res.Hits[i].Score-expectedNestedScores[i]) > 0.01 {
2659+
t.Fatalf("at rank %d, expected score %.3f, got %.3f", i+1, expectedNestedScores[i], res.Hits[i].Score)
2660+
}
2661+
}
2662+
}

0 commit comments

Comments
 (0)