Skip to content
250 changes: 250 additions & 0 deletions router-tests/operations/pql_manifest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/wundergraph/cosmo/router-tests/testutils"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.uber.org/zap/zapcore"
Expand All @@ -38,6 +40,18 @@ func getCDNRequests(t *testing.T, cdnURL string) []string {
return requests
}

// countDataPointsWithAttribute counts histogram data points that carry the
// given attribute, matching on both key and value.
//
// The value comparison uses Type() + Emit() instead of `==` because
// attribute.Value holds an interface field for slice-valued attributes
// (e.g. attribute.StringSlice); comparing those with `==` panics at
// runtime on incomparable dynamic types. Emit() alone is not sufficient
// since distinct types can share a string form (int 1 vs string "1"),
// so the type is checked first.
func countDataPointsWithAttribute(dataPoints []metricdata.HistogramDataPoint[float64], attr attribute.KeyValue) int {
	count := 0
	for _, dp := range dataPoints {
		val, ok := dp.Attributes.Value(attr.Key)
		if !ok {
			continue
		}
		if val.Type() == attr.Value.Type() && val.Emit() == attr.Value.Emit() {
			count++
		}
	}
	return count
}

func TestPQLManifest(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -106,6 +120,47 @@ func TestPQLManifest(t *testing.T) {
})
})

t.Run("feature flag mux resolves persisted operation from manifest", func(t *testing.T) {
	t.Parallel()
	testenv.Run(t, &testenv.Config{
		RouterOptions: []core.Option{
			core.WithPersistedOperationsConfig(manifestConfig),
		},
	}, func(t *testing.T, xEnv *testenv.Environment) {
		// Route the persisted operation through the "myff" feature flag mux.
		hdr := http.Header{}
		hdr.Set("graphql-client-name", "my-client")
		hdr.Set("X-Feature-Flag", "myff")

		res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
			OperationName: []byte(`"Employees"`),
			Extensions:    []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`),
			Header:        hdr,
		})
		require.NoError(t, err)

		// The response must come from the feature flag graph and carry the
		// body of the persisted Employees operation from the manifest.
		require.Equal(t, "myff", res.Response.Header.Get("X-Feature-Flag"))
		require.Equal(t, expectedEmployeesBody, res.Body)
	})
})

t.Run("feature flag mux rejects unknown hash when manifest is authoritative", func(t *testing.T) {
	t.Parallel()
	testenv.Run(t, &testenv.Config{
		RouterOptions: []core.Option{
			core.WithPersistedOperationsConfig(manifestConfig),
		},
	}, func(t *testing.T, xEnv *testenv.Environment) {
		hdr := http.Header{}
		hdr.Set("X-Feature-Flag", "myff")

		// An all-zero hash is not present in the manifest; the feature flag
		// mux must answer with the persisted-operation-not-found error body.
		res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
			Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "0000000000000000000000000000000000000000000000000000000000000000"}}`),
			Header:     hdr,
		})
		require.Equal(t, "myff", res.Response.Header.Get("X-Feature-Flag"))
		require.Equal(t, persistedNotFoundResp, res.Body)
	})
})

t.Run("no CDN requests for individual operations", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
Expand Down Expand Up @@ -574,6 +629,35 @@ func TestPQLManifest(t *testing.T) {
})
})

t.Run("manifest warmup serves first feature flag request from cache", func(t *testing.T) {
	t.Parallel()
	testenv.Run(t, &testenv.Config{
		RouterOptions: []core.Option{
			core.WithPersistedOperationsConfig(manifestConfigWithWarmup),
		},
	}, func(t *testing.T, xEnv *testenv.Environment) {
		hdr := http.Header{}
		hdr.Set("graphql-client-name", "my-client")
		hdr.Set("X-Feature-Flag", "myff")

		// Manifest warmup pre-processes operations for each mux
		// independently, so even the very first request through the
		// feature flag mux should hit every cache.
		res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
			OperationName: []byte(`"Employees"`),
			Extensions:    []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`),
			Header:        hdr,
		})
		require.NoError(t, err)
		require.Equal(t, "myff", res.Response.Header.Get("X-Feature-Flag"))
		require.Equal(t, expectedEmployeesBody, res.Body)

		// Every pipeline stage must report a cache hit on the first request.
		for _, cacheHeader := range []string{
			core.PersistedOperationCacheHeader,
			core.NormalizationCacheHeader,
			core.VariablesNormalizationCacheHeader,
			core.VariablesRemappingCacheHeader,
			core.ExecutionPlanCacheHeader,
		} {
			require.Equal(t, "HIT", res.Response.Header.Get(cacheHeader))
		}
	})
})

t.Run("manifest warmup cache hit is independent of client name", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
Expand Down Expand Up @@ -943,6 +1027,172 @@ func TestPQLManifest(t *testing.T) {
})
})

// Verifies that manifest warmup runs once per mux (base + each feature flag)
// and tags its planning_time metrics with the mux's wg.feature_flag value.
t.Run("manifest warmup emits correct feature flag attribute", func(t *testing.T) {
t.Parallel()

// Manual reader lets the test pull metrics on demand instead of waiting
// for a periodic export.
metricReader := metric.NewManualReader()

testenv.Run(t, &testenv.Config{
MetricReader: metricReader,
RouterOptions: []core.Option{
core.WithPersistedOperationsConfig(manifestConfigWithWarmup),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Collect metrics emitted by manifest warmup at startup (no requests needed).
rm := metricdata.ResourceMetrics{}
err := metricReader.Collect(context.Background(), &rm)
require.NoError(t, err)

metricScope := testutils.GetMetricScopeByName(rm.ScopeMetrics, "cosmo.router")
require.NotNil(t, metricScope)

m := testutils.GetMetricByName(metricScope, "router.graphql.operation.planning_time")
require.NotNil(t, m, "planning_time metric should be emitted during manifest warmup")

// Unchecked type assertion: panics (failing the test) if the metric is
// not a float64 histogram, which is acceptable in a test.
dataPoints := m.Data.(metricdata.Histogram[float64]).DataPoints
require.NotEmpty(t, dataPoints)

// The base mux warmup should emit data points with wg.feature_flag = ""
require.True(t,
testutils.HasHistogramDataPointWithAttribute(dataPoints, otel.WgFeatureFlag.String("")),
"expected planning_time data point with wg.feature_flag=\"\" for base mux warmup",
)

// The feature flag mux warmup should emit data points with wg.feature_flag = "myff"
require.True(t,
testutils.HasHistogramDataPointWithAttribute(dataPoints, otel.WgFeatureFlag.String("myff")),
"expected planning_time data point with wg.feature_flag=\"myff\" for feature flag mux warmup",
)
})
})

// Verifies that when the manifest poller picks up a new manifest revision,
// the re-warm runs for feature flag muxes too (not just the base mux), by
// watching for new planning_time data points tagged wg.feature_flag="myff".
t.Run("manifest re-warm fires for feature flag muxes", func(t *testing.T) {
t.Parallel()

employeesHash := "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"
employeesQuery := "query Employees {\n employees {\n id\n }\n}"

// Marshal errors are ignored: the inputs are static map literals that
// cannot fail to encode.
manifestV1, _ := json.Marshal(map[string]interface{}{
"version": 1,
"revision": "rev-v1",
"generatedAt": "2024-01-01T00:00:00Z",
"operations": map[string]string{
employeesHash: employeesQuery,
},
})
// v2 adds a second operation — the re-warm should plan it for all muxes
typenameHash := "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"
manifestV2, _ := json.Marshal(map[string]interface{}{
"version": 1,
"revision": "rev-v2",
"generatedAt": "2024-01-02T00:00:00Z",
"operations": map[string]string{
employeesHash: employeesQuery,
typenameHash: "{__typename}",
},
})

// atomic.Value lets the test swap manifests while the CDN handler below
// reads concurrently from the router's poller goroutine.
var currentManifest atomic.Value
currentManifest.Store(manifestV1)

// Counts only full (200) manifest responses, not 304s — see below.
var manifestFetchCount atomic.Int32

// Fake CDN implementing the ETag/If-None-Match handshake the poller uses:
// a matching revision gets 304 Not Modified; otherwise the manifest body
// is served with its revision as the ETag.
cdnServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/operations/manifest.json") {
manifest := currentManifest.Load().([]byte)

// Only the revision is needed to answer conditional requests.
var m struct {
Revision string `json:"revision"`
}
_ = json.Unmarshal(manifest, &m)

ifNoneMatch := r.Header.Get("If-None-Match")
if ifNoneMatch == `"`+m.Revision+`"` {
w.Header().Set("ETag", ifNoneMatch)
w.WriteHeader(http.StatusNotModified)
return
}

// A full body is being served: this is a "real" fetch.
manifestFetchCount.Add(1)
w.Header().Set("Content-Type", "application/json")
w.Header().Set("ETag", `"`+m.Revision+`"`)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(manifest)
return
}

w.WriteHeader(http.StatusNotFound)
}))
defer cdnServer.Close()

metricReader := metric.NewManualReader()

testenv.Run(t, &testenv.Config{
// NOTE(review): "CdnSever" matches the testenv.Config field name as
// written — confirm whether the upstream field is intentionally
// spelled this way before "fixing" it here.
CdnSever: cdnServer,
MetricReader: metricReader,
RouterOptions: []core.Option{
core.WithPersistedOperationsConfig(config.PersistedOperationsConfig{
Manifest: config.PQLManifestConfig{
Enabled: true,
// Short interval/jitter so the poller notices the manifest
// swap quickly within the test's Eventually windows.
PollInterval: 100 * time.Millisecond,
PollJitter: 5 * time.Millisecond,
Warmup: config.PQLManifestWarmupConfig{
Enabled: true,
Workers: 4,
Timeout: 30 * time.Second,
},
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// 1. Count initial warmup planning_time data points for the FF mux
initialRM := metricdata.ResourceMetrics{}
err := metricReader.Collect(context.Background(), &initialRM)
require.NoError(t, err)

initialScope := testutils.GetMetricScopeByName(initialRM.ScopeMetrics, "cosmo.router")
require.NotNil(t, initialScope)
initialMetric := testutils.GetMetricByName(initialScope, "router.graphql.operation.planning_time")
require.NotNil(t, initialMetric)

initialFFDataPoints := countDataPointsWithAttribute(
initialMetric.Data.(metricdata.Histogram[float64]).DataPoints,
otel.WgFeatureFlag.String("myff"),
)

// 2. Swap to manifest v2 which adds {__typename}
currentManifest.Store(manifestV2)

// 3. Wait for the poller to pick up the new manifest
// (>= 2 full fetches: the initial one plus the post-swap one).
require.Eventually(t, func() bool {
return manifestFetchCount.Load() >= 2
}, 5*time.Second, 50*time.Millisecond)

// 4. Wait for the re-warm to produce new planning_time metrics for the FF mux.
// This proves the warmup callback ran for the FF mux — no request-based caching involved.
require.Eventually(t, func() bool {
rm := metricdata.ResourceMetrics{}
if collectErr := metricReader.Collect(context.Background(), &rm); collectErr != nil {
return false
}
scope := testutils.GetMetricScopeByName(rm.ScopeMetrics, "cosmo.router")
if scope == nil {
return false
}
m := testutils.GetMetricByName(scope, "router.graphql.operation.planning_time")
if m == nil {
return false
}
currentFFDataPoints := countDataPointsWithAttribute(
m.Data.(metricdata.Histogram[float64]).DataPoints,
otel.WgFeatureFlag.String("myff"),
)
// More FF-tagged data points than at startup means the re-warm
// planned operations for the feature flag mux.
return currentFFDataPoints > initialFFDataPoints
}, 5*time.Second, 100*time.Millisecond,
"feature flag mux should have emitted new planning_time metrics after manifest re-warm")
})
})

t.Run("fails to start when initial CDN manifest fetch fails", func(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 11 additions & 3 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ type (
connector *grpcconnector.Connector
circuitBreakerManager *circuit.Manager
headerPropagation *HeaderPropagation
shutdownStarted chan struct{}
}
)

Expand Down Expand Up @@ -195,6 +196,7 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
},
storageProviders: &r.storageProviders,
headerPropagation: r.headerPropagation,
shutdownStarted: make(chan struct{}),
}

baseOtelAttributes := []attribute.KeyValue{
Expand Down Expand Up @@ -1529,9 +1531,10 @@ func (s *graphServer) buildGraphMux(
}

// Re-warm when the manifest is updated by the poller.
// The store handles coalescing — if a re-warm is already running or pending,
// new signals are dropped. The worker always reads the latest manifest.
pqlStore.SetOnUpdate(func() {
// Each mux registers its own listener on the shared store, so all muxes
// (base + feature flags) get re-warmed independently with their own coalescing.
// The listener is removed when shutdownStarted is closed at the start of Shutdown.
pqlStore.AddListener(s.shutdownStarted, func() {
rewarmConfig := &CacheWarmupConfig{
Log: s.logger,
Processor: manifestProcessor,
Expand Down Expand Up @@ -1929,6 +1932,10 @@ func (s *graphServer) wait(ctx context.Context) error {
// After all requests are done, it will shut down the metric store and runtime metrics.
// Shutdown does cancel the context after all non-hijacked requests such as WebSockets has been handled.
func (s *graphServer) Shutdown(ctx context.Context) error {
// Stop manifest warmup listeners immediately — the old execution config
// is being replaced, so any new warmup work would be wasted.
close(s.shutdownStarted)

// Cancel the context after the graceful shutdown is done
// to clean up resources like websocket connections, pools, etc.
defer s.cancelFunc()
Expand Down Expand Up @@ -1983,6 +1990,7 @@ func (s *graphServer) Shutdown(ctx context.Context) error {
}

// Shutdown all graphs muxes to release resources
// Each mux removes its own manifest listener during shutdown.
// e.g. planner cache
s.graphMuxListLock.Lock()
defer s.graphMuxListLock.Unlock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,10 @@ func TestStorageFetcherPollingLifecycle(t *testing.T) {
poller := NewPoller(storageFetcher, store, 50*time.Millisecond, 1*time.Millisecond, zap.NewNop())

// Track store update callbacks.
done := make(chan struct{})
defer close(done)
var updateCallbackCount atomic.Int32
store.SetOnUpdate(func() {
store.AddListener(done, func() {
updateCallbackCount.Add(1)
})

Expand Down
Loading
Loading