From 26c7b592713233436f1eb9d7bd9dc220efb85868 Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Wed, 28 Aug 2024 16:57:05 +0100 Subject: [PATCH 1/7] add parallel calls to call GetStats --- client/nginx.go | 393 ++++++++++++++++++++++++++++++------------- client/nginx_test.go | 5 + go.mod | 2 + go.sum | 2 + 4 files changed, 282 insertions(+), 120 deletions(-) diff --git a/client/nginx.go b/client/nginx.go index d5fff60f..021fc8a6 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -11,7 +11,10 @@ import ( "reflect" "slices" "strings" + "sync" "time" + + "golang.org/x/sync/errgroup" ) const ( @@ -118,6 +121,40 @@ func (internalError *internalError) Wrap(err string) *internalError { return internalError } +// this is an internal representation of the Stats object including endpoint and streamEndpoint lists +type extendedStats struct { + endpoints []string + streamEndpoints []string + Stats +} + +func defaultStats() *extendedStats { + return &extendedStats{ + endpoints: []string{}, + streamEndpoints: []string{}, + Stats: Stats{ + Upstreams: map[string]Upstream{}, + ServerZones: map[string]ServerZone{}, + StreamServerZones: map[string]StreamServerZone{}, + StreamUpstreams: map[string]StreamUpstream{}, + Slabs: map[string]Slab{}, + Caches: map[string]HTTPCache{}, + HTTPLimitConnections: map[string]LimitConnection{}, + StreamLimitConnections: map[string]LimitConnection{}, + HTTPLimitRequests: map[string]HTTPLimitRequest{}, + Resolvers: map[string]Resolver{}, + LocationZones: map[string]LocationZone{}, + StreamZoneSync: &StreamZoneSync{}, + Workers: []*Workers{}, + NginxInfo: NginxInfo{}, + SSL: SSL{}, + Connections: Connections{}, + HTTPRequests: HTTPRequests{}, + Processes: Processes{}, + }, + } +} + // Stats represents NGINX Plus stats fetched from the NGINX Plus API. // https://nginx.org/en/docs/http/ngx_http_api_module.html type Stats struct { @@ -1177,141 +1214,257 @@ func determineStreamUpdates(updatedServers []StreamUpstreamServer, nginxServers // GetStats gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. func (client *NginxClient) GetStats() (*Stats, error) { - endpoints, err := client.GetAvailableEndpoints() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - info, err := client.GetNginxInfo() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - caches, err := client.GetCaches() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - processes, err := client.GetProcesses() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - slabs, err := client.GetSlabs() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - cons, err := client.GetConnections() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - requests, err := client.GetHTTPRequests() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - ssl, err := client.GetSSL() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - zones, err := client.GetServerZones() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - upstreams, err := client.GetUpstreams() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - locationZones, err := client.GetLocationZones() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - resolvers, err := client.GetResolvers() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - limitReqs, err := client.GetHTTPLimitReqs() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } - - limitConnsHTTP, err := client.GetHTTPConnectionsLimit() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + var g errgroup.Group + var mu sync.Mutex + stats := defaultStats() + // Collecting initial stats + g.Go(func() error { + mu.Lock() + endpoints, err := client.GetAvailableEndpoints() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get available Endpoints: %w", err) + } + stats.endpoints = endpoints + return nil + }) + + g.Go(func() error { + mu.Lock() + nginxInfo, err := client.GetNginxInfo() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get NGINX info: %w", err) + } + stats.NginxInfo = *nginxInfo + return nil + }) + + g.Go(func() error { + mu.Lock() + caches, err := client.GetCaches() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Caches: %w", err) + } + stats.Caches = *caches + return nil + }) + + g.Go(func() error { + mu.Lock() + processes, err := client.GetProcesses() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Process information: %w", err) + } + stats.Processes = *processes + return nil + }) + + g.Go(func() error { + mu.Lock() + slabs, err := client.GetSlabs() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Slabs: %w", err) + } + stats.Slabs = *slabs + return nil + }) + + g.Go(func() error { + mu.Lock() + httpRequests, err := client.GetHTTPRequests() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get HTTP Requests: %w", err) + } + stats.HTTPRequests = *httpRequests + return nil + }) + + g.Go(func() error { + mu.Lock() + ssl, err := client.GetSSL() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get SSL: %w", err) + } + stats.SSL = *ssl + return nil + }) + + g.Go(func() error { + mu.Lock() + serverZones, err := client.GetServerZones() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Server Zones: %w", err) + } + stats.ServerZones = *serverZones + return nil + }) + + g.Go(func() error { + mu.Lock() + upstreams, err := client.GetUpstreams() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Upstreams: %w", err) + } + stats.Upstreams = *upstreams + return nil + }) + + g.Go(func() error { + mu.Lock() + locationZones, err := client.GetLocationZones() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Location Zones: %w", err) + } + stats.LocationZones = *locationZones + return nil + }) + + g.Go(func() error { + mu.Lock() + resolvers, err := client.GetResolvers() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Resolvers: %w", err) + } + stats.Resolvers = *resolvers + return nil + }) + + g.Go(func() error { + mu.Lock() + httpLimitRequests, err := client.GetHTTPLimitReqs() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get HTTPLimitRequests: %w", err) + } + stats.HTTPLimitRequests = *httpLimitRequests + return nil + }) + + g.Go(func() error { + mu.Lock() + httpLimitConnections, err := client.GetHTTPConnectionsLimit() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get HTTPLimitConnections: %w", err) + } + stats.HTTPLimitConnections = *httpLimitConnections + return nil + }) + + g.Go(func() error { + mu.Lock() + workers, err := client.GetWorkers() + mu.Unlock() + if err != nil { + return fmt.Errorf("failed to get Workers: %w", err) + } + stats.Workers = workers + return nil + }) - workers, err := client.GetWorkers() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("error returned from contacting Plus API: %w", err) } - streamZones := &StreamServerZones{} - streamUpstreams := &StreamUpstreams{} - limitConnsStream := &StreamLimitConnections{} - var streamZoneSync *StreamZoneSync - - if slices.Contains(endpoints, "stream") { - streamEndpoints, err := client.GetAvailableStreamEndpoints() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + // Process stream endpoints if they exist + if slices.Contains(stats.endpoints, "stream") { + var streamGroup errgroup.Group - if slices.Contains(streamEndpoints, "server_zones") { - streamZones, err = client.GetStreamServerZones() + streamGroup.Go(func() error { + mu.Lock() + streamEndpoints, err := client.GetAvailableStreamEndpoints() + mu.Unlock() if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) + return fmt.Errorf("failed to get available Stream Endpoints: %w", err) } + stats.streamEndpoints = streamEndpoints + return nil + }) + + if slices.Contains(stats.streamEndpoints, "server_zones") { + streamGroup.Go(func() error { + streamServerZones, err := client.GetStreamServerZones() + if err != nil { + return fmt.Errorf("failed to get streamServerZones: %w", err) + } + mu.Lock() + stats.StreamServerZones = *streamServerZones + mu.Unlock() + return nil + }) } - if slices.Contains(streamEndpoints, "upstreams") { - streamUpstreams, err = client.GetStreamUpstreams() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + if slices.Contains(stats.streamEndpoints, "upstreams") { + streamGroup.Go(func() error { + streamUpstreams, err := client.GetStreamUpstreams() + if err != nil { + return fmt.Errorf("failed to get StreamUpstreams: %w", err) + } + mu.Lock() + stats.StreamUpstreams = *streamUpstreams + mu.Unlock() + + return nil + }) } - if slices.Contains(streamEndpoints, "limit_conns") { - limitConnsStream, err = client.GetStreamConnectionsLimit() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + if slices.Contains(stats.streamEndpoints, "limit_conns") { + streamGroup.Go(func() error { + streamConnectionsLimit, err := client.GetStreamConnectionsLimit() + if err != nil { + return fmt.Errorf("failed to get StreamLimitConnections: %w", err) + } + mu.Lock() + stats.StreamLimitConnections = *streamConnectionsLimit + mu.Unlock() + return nil + }) + + streamGroup.Go(func() error { + streamZoneSync, err := client.GetStreamZoneSync() + if err != nil { + return fmt.Errorf("failed to get StreamZoneSync: %w", err) + } + mu.Lock() + stats.StreamZoneSync = streamZoneSync + mu.Unlock() + return nil + }) } - if slices.Contains(streamEndpoints, "zone_sync") { - streamZoneSync, err = client.GetStreamZoneSync() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + if err := streamGroup.Wait(); err != nil { + return nil, fmt.Errorf("no useful metrics found in stream stats: %w", err) } } - return &Stats{ - NginxInfo: *info, - Caches: *caches, - Processes: *processes, - Slabs: *slabs, - Connections: *cons, - HTTPRequests: *requests, - SSL: *ssl, - ServerZones: *zones, - StreamServerZones: *streamZones, - Upstreams: *upstreams, - StreamUpstreams: *streamUpstreams, - StreamZoneSync: streamZoneSync, - LocationZones: *locationZones, - Resolvers: *resolvers, - HTTPLimitRequests: *limitReqs, - HTTPLimitConnections: *limitConnsHTTP, - StreamLimitConnections: *limitConnsStream, - Workers: workers, - }, nil + // Report connection metrics separately so it does not influence the results + var connectionsGroup errgroup.Group + connectionsGroup.Go(func() error { + connections, err := client.GetConnections() + if err != nil { + return fmt.Errorf("failed to get connections: %w", err) + } + mu.Lock() + stats.Connections = *connections + mu.Unlock() + return nil + }) + + if err := connectionsGroup.Wait(); err != nil { + return nil, fmt.Errorf("connections metrics not found: %w", err) + } + + return &stats.Stats, nil } // GetAvailableEndpoints returns available endpoints in the API. diff --git a/client/nginx_test.go b/client/nginx_test.go index 2cb8a7e8..1a23b7ba 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -708,6 +708,11 @@ func TestGetStats_SSL(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + case strings.HasPrefix(r.RequestURI, "/8/stream"): + _, err := w.Write([]byte(`[""]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } default: _, err := w.Write([]byte(`{}`)) if err != nil { diff --git a/go.mod b/go.mod index 4fc51a45..c7a848ea 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/nginxinc/nginx-plus-go-client go 1.21.2 + +require golang.org/x/sync v0.8.0 diff --git a/go.sum b/go.sum index e69de29b..e584c1bd 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= From 858e54ff76b5577b4023acb4bf956fd927711fa2 Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Wed, 28 Aug 2024 17:07:40 +0100 Subject: [PATCH 2/7] lint issue --- client/nginx.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/nginx.go b/client/nginx.go index 314bc3b6..244ab98b 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -126,7 +126,7 @@ func (internalError *internalError) Wrap(err string) *internalError { return internalError } -// this is an internal representation of the Stats object including endpoint and streamEndpoint lists +// this is an internal representation of the Stats object including endpoint and streamEndpoint lists. type extendedStats struct { endpoints []string streamEndpoints []string From 2bb5e3c7bc63f38fe7643ec91075a2d038cd8c10 Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Wed, 28 Aug 2024 17:08:53 +0100 Subject: [PATCH 3/7] fixed unit test --- client/nginx.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/nginx.go b/client/nginx.go index 244ab98b..0c2bd225 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -149,7 +149,7 @@ func defaultStats() *extendedStats { HTTPLimitRequests: map[string]HTTPLimitRequest{}, Resolvers: map[string]Resolver{}, LocationZones: map[string]LocationZone{}, - StreamZoneSync: &StreamZoneSync{}, + StreamZoneSync: nil, Workers: []*Workers{}, NginxInfo: NginxInfo{}, SSL: SSL{}, From f9eb1cdc6241341980842634ce8b56520be9146a Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Thu, 29 Aug 2024 14:57:18 +0100 Subject: [PATCH 4/7] added context to everything in get stats --- client/nginx.go | 259 ++++++++++++++++++++++++++++++++----------- client/nginx_test.go | 110 +++++++++++++----- 2 files changed, 275 insertions(+), 94 deletions(-) diff --git a/client/nginx.go b/client/nginx.go index 0c2bd225..4388497d 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -927,7 +927,11 @@ func (client *NginxClient) getIDOfHTTPServer(upstream string, name string) (int, } func (client *NginxClient) get(path string, data interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + return client.getWithContext(context.Background(), path, data) +} + +func (client *NginxClient) getWithContext(ctx context.Context, path string, data interface{}) error { + ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) defer cancel() url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) @@ -961,7 +965,11 @@ func (client *NginxClient) get(path string, data interface{}) error { } func (client *NginxClient) post(path string, input interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + return client.postWithConext(context.Background(), path, input) +} + +func (client *NginxClient) postWithConext(ctx context.Context, path string, input interface{}) error { + ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) defer cancel() url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) @@ -993,7 +1001,11 @@ func (client *NginxClient) post(path string, input interface{}) error { } func (client *NginxClient) delete(path string, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + return client.deleteWithContext(context.Background(), path, expectedStatusCode) +} + +func (client *NginxClient) deleteWithContext(ctx context.Context, path string, expectedStatusCode int) error { + ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) defer cancel() path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) @@ -1018,7 +1030,11 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error { } func (client *NginxClient) patch(path string, input interface{}, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + return client.patchWithContext(context.Background(), path, input, expectedStatusCode) +} + +func (client *NginxClient) patchWithContext(ctx context.Context, path string, input interface{}, expectedStatusCode int) error { + ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) defer cancel() path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) @@ -1236,15 +1252,15 @@ func determineStreamUpdates(updatedServers []StreamUpstreamServer, nginxServers return } -// GetStats gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. -func (client *NginxClient) GetStats() (*Stats, error) { - var g errgroup.Group +// GetStatsWithContext gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. +func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, error) { + initialGroup, initialCtx := errgroup.WithContext(ctx) var mu sync.Mutex stats := defaultStats() // Collecting initial stats - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - endpoints, err := client.GetAvailableEndpoints() + endpoints, err := client.GetAvailableEndpointsWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get available Endpoints: %w", err) @@ -1253,9 +1269,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - nginxInfo, err := client.GetNginxInfo() + nginxInfo, err := client.GetNginxInfoWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get NGINX info: %w", err) @@ -1264,9 +1280,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - caches, err := client.GetCaches() + caches, err := client.GetCachesWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Caches: %w", err) @@ -1275,9 +1291,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - processes, err := client.GetProcesses() + processes, err := client.GetProcessesWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Process information: %w", err) @@ -1286,9 +1302,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - slabs, err := client.GetSlabs() + slabs, err := client.GetSlabsWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Slabs: %w", err) @@ -1297,9 +1313,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - httpRequests, err := client.GetHTTPRequests() + httpRequests, err := client.GetHTTPRequestsWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTP Requests: %w", err) @@ -1308,9 +1324,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - ssl, err := client.GetSSL() + ssl, err := client.GetSSLWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get SSL: %w", err) @@ -1319,9 +1335,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - serverZones, err := client.GetServerZones() + serverZones, err := client.GetServerZonesWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Server Zones: %w", err) @@ -1330,9 +1346,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - upstreams, err := client.GetUpstreams() + upstreams, err := client.GetUpstreamsWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Upstreams: %w", err) @@ -1341,9 +1357,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - locationZones, err := client.GetLocationZones() + locationZones, err := client.GetLocationZonesWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Location Zones: %w", err) @@ -1352,9 +1368,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - resolvers, err := client.GetResolvers() + resolvers, err := client.GetResolversWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Resolvers: %w", err) @@ -1363,9 +1379,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - httpLimitRequests, err := client.GetHTTPLimitReqs() + httpLimitRequests, err := client.GetHTTPLimitReqsWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTPLimitRequests: %w", err) @@ -1374,9 +1390,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - httpLimitConnections, err := client.GetHTTPConnectionsLimit() + httpLimitConnections, err := client.GetHTTPConnectionsLimitWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTPLimitConnections: %w", err) @@ -1385,9 +1401,9 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - g.Go(func() error { + initialGroup.Go(func() error { mu.Lock() - workers, err := client.GetWorkers() + workers, err := client.GetWorkersWithContext(initialCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get Workers: %w", err) @@ -1396,17 +1412,17 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) - if err := g.Wait(); err != nil { + if err := initialGroup.Wait(); err != nil { return nil, fmt.Errorf("error returned from contacting Plus API: %w", err) } // Process stream endpoints if they exist if slices.Contains(stats.endpoints, "stream") { - var streamGroup errgroup.Group + availableStreamGroup, asgCtx := errgroup.WithContext(ctx) - streamGroup.Go(func() error { + availableStreamGroup.Go(func() error { mu.Lock() - streamEndpoints, err := client.GetAvailableStreamEndpoints() + streamEndpoints, err := client.GetAvailableStreamEndpointsWithContext(asgCtx) mu.Unlock() if err != nil { return fmt.Errorf("failed to get available Stream Endpoints: %w", err) @@ -1415,9 +1431,15 @@ func (client *NginxClient) GetStats() (*Stats, error) { return nil }) + if err := availableStreamGroup.Wait(); err != nil { + return nil, fmt.Errorf("no useful metrics found in stream stats: %w", err) + } + + streamGroup, sgCtx := errgroup.WithContext(ctx) + if slices.Contains(stats.streamEndpoints, "server_zones") { streamGroup.Go(func() error { - streamServerZones, err := client.GetStreamServerZones() + streamServerZones, err := client.GetStreamServerZonesWithContext(sgCtx) if err != nil { return fmt.Errorf("failed to get streamServerZones: %w", err) } @@ -1430,7 +1452,7 @@ func (client *NginxClient) GetStats() (*Stats, error) { if slices.Contains(stats.streamEndpoints, "upstreams") { streamGroup.Go(func() error { - streamUpstreams, err := client.GetStreamUpstreams() + streamUpstreams, err := client.GetStreamUpstreamsWithContext(sgCtx) if err != nil { return fmt.Errorf("failed to get StreamUpstreams: %w", err) } @@ -1444,7 +1466,7 @@ func (client *NginxClient) GetStats() (*Stats, error) { if slices.Contains(stats.streamEndpoints, "limit_conns") { streamGroup.Go(func() error { - streamConnectionsLimit, err := client.GetStreamConnectionsLimit() + streamConnectionsLimit, err := client.GetStreamConnectionsLimitWithContext(sgCtx) if err != nil { return fmt.Errorf("failed to get StreamLimitConnections: %w", err) } @@ -1455,7 +1477,7 @@ func (client *NginxClient) GetStats() (*Stats, error) { }) streamGroup.Go(func() error { - streamZoneSync, err := client.GetStreamZoneSync() + streamZoneSync, err := client.GetStreamZoneSyncWithContext(sgCtx) if err != nil { return fmt.Errorf("failed to get StreamZoneSync: %w", err) } @@ -1472,9 +1494,11 @@ func (client *NginxClient) GetStats() (*Stats, error) { } // Report connection metrics separately so it does not influence the results - var connectionsGroup errgroup.Group + connectionsGroup, cgCtx := errgroup.WithContext(ctx) + connectionsGroup.Go(func() error { - connections, err := client.GetConnections() + // replace this call with a context specific call + connections, err := client.GetConnectionsWithContext(cgCtx) if err != nil { return fmt.Errorf("failed to get connections: %w", err) } @@ -1491,10 +1515,20 @@ func (client *NginxClient) GetStats() (*Stats, error) { return &stats.Stats, nil } +// GetStats gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. +func (client *NginxClient) GetStats() (*Stats, error) { + return client.GetStatsWithContext(context.Background()) +} + // GetAvailableEndpoints returns available endpoints in the API. func (client *NginxClient) GetAvailableEndpoints() ([]string, error) { + return client.GetAvailableEndpointsWithContext(context.Background()) +} + +// GetAvailableEndpointsWithContext returns available endpoints in the API. +func (client *NginxClient) GetAvailableEndpointsWithContext(ctx context.Context) ([]string, error) { var endpoints []string - err := client.get("", &endpoints) + err := client.getWithContext(ctx, "", &endpoints) if err != nil { return nil, fmt.Errorf("failed to get endpoints: %w", err) } @@ -1503,8 +1537,13 @@ func (client *NginxClient) GetAvailableEndpoints() ([]string, error) { // GetAvailableStreamEndpoints returns available stream endpoints in the API. func (client *NginxClient) GetAvailableStreamEndpoints() ([]string, error) { + return client.GetAvailableStreamEndpointsWithContext(context.Background()) +} + +// GetAvailableStreamEndpointsWithContext returns available stream endpoints in the API with a context. +func (client *NginxClient) GetAvailableStreamEndpointsWithContext(ctx context.Context) ([]string, error) { var endpoints []string - err := client.get("stream", &endpoints) + err := client.getWithContext(ctx, "stream", &endpoints) if err != nil { return nil, fmt.Errorf("failed to get endpoints: %w", err) } @@ -1513,8 +1552,13 @@ func (client *NginxClient) GetAvailableStreamEndpoints() ([]string, error) { // GetNginxInfo returns Nginx stats. func (client *NginxClient) GetNginxInfo() (*NginxInfo, error) { + return client.GetNginxInfoWithContext(context.Background()) +} + +// GetNginxInfoWithContext returns Nginx stats with a context. +func (client *NginxClient) GetNginxInfoWithContext(ctx context.Context) (*NginxInfo, error) { var info NginxInfo - err := client.get("nginx", &info) + err := client.getWithContext(ctx, "nginx", &info) if err != nil { return nil, fmt.Errorf("failed to get info: %w", err) } @@ -1523,8 +1567,13 @@ func (client *NginxClient) GetNginxInfo() (*NginxInfo, error) { // GetCaches returns Cache stats. func (client *NginxClient) GetCaches() (*Caches, error) { + return client.GetCachesWithContext(context.Background()) +} + +// GetCachesWithContext returns Cache stats with a context. +func (client *NginxClient) GetCachesWithContext(ctx context.Context) (*Caches, error) { var caches Caches - err := client.get("http/caches", &caches) + err := client.getWithContext(ctx, "http/caches", &caches) if err != nil { return nil, fmt.Errorf("failed to get caches: %w", err) } @@ -1533,8 +1582,13 @@ func (client *NginxClient) GetCaches() (*Caches, error) { // GetSlabs returns Slabs stats. func (client *NginxClient) GetSlabs() (*Slabs, error) { + return client.GetSlabsWithContext(context.Background()) +} + +// GetSlabsWithContext returns Slabs stats with a context. +func (client *NginxClient) GetSlabsWithContext(ctx context.Context) (*Slabs, error) { var slabs Slabs - err := client.get("slabs", &slabs) + err := client.getWithContext(ctx, "slabs", &slabs) if err != nil { return nil, fmt.Errorf("failed to get slabs: %w", err) } @@ -1543,8 +1597,13 @@ func (client *NginxClient) GetSlabs() (*Slabs, error) { // GetConnections returns Connections stats. func (client *NginxClient) GetConnections() (*Connections, error) { + return client.GetConnectionsWithContext(context.Background()) +} + +// GetConnectionsWithContext returns Connections stats with a context. +func (client *NginxClient) GetConnectionsWithContext(ctx context.Context) (*Connections, error) { var cons Connections - err := client.get("connections", &cons) + err := client.getWithContext(ctx, "connections", &cons) if err != nil { return nil, fmt.Errorf("failed to get connections: %w", err) } @@ -1553,8 +1612,13 @@ func (client *NginxClient) GetConnections() (*Connections, error) { // GetHTTPRequests returns http/requests stats. func (client *NginxClient) GetHTTPRequests() (*HTTPRequests, error) { + return client.GetHTTPRequestsWithContext(context.Background()) +} + +// GetHTTPRequestsWithContext returns http/requests stats with a context. +func (client *NginxClient) GetHTTPRequestsWithContext(ctx context.Context) (*HTTPRequests, error) { var requests HTTPRequests - err := client.get("http/requests", &requests) + err := client.getWithContext(ctx, "http/requests", &requests) if err != nil { return nil, fmt.Errorf("failed to get http requests: %w", err) } @@ -1563,8 +1627,13 @@ func (client *NginxClient) GetHTTPRequests() (*HTTPRequests, error) { // GetSSL returns SSL stats. func (client *NginxClient) GetSSL() (*SSL, error) { + return client.GetSSLWithContext(context.Background()) +} + +// GetSSLWithContext returns SSL stats with a context. +func (client *NginxClient) GetSSLWithContext(ctx context.Context) (*SSL, error) { var ssl SSL - err := client.get("ssl", &ssl) + err := client.getWithContext(ctx, "ssl", &ssl) if err != nil { return nil, fmt.Errorf("failed to get ssl: %w", err) } @@ -1573,8 +1642,13 @@ func (client *NginxClient) GetSSL() (*SSL, error) { // GetServerZones returns http/server_zones stats. func (client *NginxClient) GetServerZones() (*ServerZones, error) { + return client.GetServerZonesWithContext(context.Background()) +} + +// GetServerZonesWithContext returns http/server_zones stats with a context. +func (client *NginxClient) GetServerZonesWithContext(ctx context.Context) (*ServerZones, error) { var zones ServerZones - err := client.get("http/server_zones", &zones) + err := client.getWithContext(ctx, "http/server_zones", &zones) if err != nil { return nil, fmt.Errorf("failed to get server zones: %w", err) } @@ -1583,8 +1657,13 @@ func (client *NginxClient) GetServerZones() (*ServerZones, error) { // GetStreamServerZones returns stream/server_zones stats. func (client *NginxClient) GetStreamServerZones() (*StreamServerZones, error) { + return client.GetStreamServerZonesWithContext(context.Background()) +} + +// GetStreamServerZonesWithContext returns stream/server_zones stats with a context. +func (client *NginxClient) GetStreamServerZonesWithContext(ctx context.Context) (*StreamServerZones, error) { var zones StreamServerZones - err := client.get("stream/server_zones", &zones) + err := client.getWithContext(ctx, "stream/server_zones", &zones) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1599,8 +1678,13 @@ func (client *NginxClient) GetStreamServerZones() (*StreamServerZones, error) { // GetUpstreams returns http/upstreams stats. func (client *NginxClient) GetUpstreams() (*Upstreams, error) { + return client.GetUpstreamsWithContext(context.Background()) +} + +// GetUpstreamsWithContext returns http/upstreams stats with a context. +func (client *NginxClient) GetUpstreamsWithContext(ctx context.Context) (*Upstreams, error) { var upstreams Upstreams - err := client.get("http/upstreams", &upstreams) + err := client.getWithContext(ctx, "http/upstreams", &upstreams) if err != nil { return nil, fmt.Errorf("failed to get upstreams: %w", err) } @@ -1609,8 +1693,13 @@ func (client *NginxClient) GetUpstreams() (*Upstreams, error) { // GetStreamUpstreams returns stream/upstreams stats. func (client *NginxClient) GetStreamUpstreams() (*StreamUpstreams, error) { + return client.GetStreamUpstreamsWithContext(context.Background()) +} + +// GetStreamUpstreamsWithContext returns stream/upstreams stats with a context. +func (client *NginxClient) GetStreamUpstreamsWithContext(ctx context.Context) (*StreamUpstreams, error) { var upstreams StreamUpstreams - err := client.get("stream/upstreams", &upstreams) + err := client.getWithContext(ctx, "stream/upstreams", &upstreams) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1625,8 +1714,13 @@ func (client *NginxClient) GetStreamUpstreams() (*StreamUpstreams, error) { // GetStreamZoneSync returns stream/zone_sync stats. func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) { + return client.GetStreamZoneSyncWithContext(context.Background()) +} + +// GetStreamZoneSyncWithContext returns stream/zone_sync stats with a context. +func (client *NginxClient) GetStreamZoneSyncWithContext(ctx context.Context) (*StreamZoneSync, error) { var streamZoneSync StreamZoneSync - err := client.get("stream/zone_sync", &streamZoneSync) + err := client.getWithContext(ctx, "stream/zone_sync", &streamZoneSync) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1642,11 +1736,16 @@ func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) { // GetLocationZones returns http/location_zones stats. func (client *NginxClient) GetLocationZones() (*LocationZones, error) { + return client.GetLocationZonesWithContext(context.Background()) +} + +// GetLocationZonesWithContext returns http/location_zones stats with a context. +func (client *NginxClient) GetLocationZonesWithContext(ctx context.Context) (*LocationZones, error) { var locationZones LocationZones if client.apiVersion < 5 { return &locationZones, nil } - err := client.get("http/location_zones", &locationZones) + err := client.getWithContext(ctx, "http/location_zones", &locationZones) if err != nil { return nil, fmt.Errorf("failed to get location zones: %w", err) } @@ -1656,11 +1755,16 @@ func (client *NginxClient) GetLocationZones() (*LocationZones, error) { // GetResolvers returns Resolvers stats. func (client *NginxClient) GetResolvers() (*Resolvers, error) { + return client.GetResolversWithContext(context.Background()) +} + +// GetResolversWithContext returns Resolvers stats with a context. +func (client *NginxClient) GetResolversWithContext(ctx context.Context) (*Resolvers, error) { var resolvers Resolvers if client.apiVersion < 5 { return &resolvers, nil } - err := client.get("resolvers", &resolvers) + err := client.getWithContext(ctx, "resolvers", &resolvers) if err != nil { return nil, fmt.Errorf("failed to get resolvers: %w", err) } @@ -1670,8 +1774,13 @@ func (client *NginxClient) GetResolvers() (*Resolvers, error) { // GetProcesses returns Processes stats. func (client *NginxClient) GetProcesses() (*Processes, error) { + return client.GetProcessesWithContext(context.Background()) +} + +// GetProcessesWithContext returns Processes stats with a context. +func (client *NginxClient) GetProcessesWithContext(ctx context.Context) (*Processes, error) { var processes Processes - err := client.get("processes", &processes) + err := client.getWithContext(ctx, "processes", &processes) if err != nil { return nil, fmt.Errorf("failed to get processes: %w", err) } @@ -1901,11 +2010,16 @@ func addPortToServer(server string) string { // GetHTTPLimitReqs returns http/limit_reqs stats. func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { + return client.GetHTTPLimitReqsWithContext(context.Background()) +} + +// GetHTTPLimitReqsWithContext returns http/limit_reqs stats with a context. +func (client *NginxClient) GetHTTPLimitReqsWithContext(ctx context.Context) (*HTTPLimitRequests, error) { var limitReqs HTTPLimitRequests if client.apiVersion < 6 { return &limitReqs, nil } - err := client.get("http/limit_reqs", &limitReqs) + err := client.getWithContext(ctx, "http/limit_reqs", &limitReqs) if err != nil { return nil, fmt.Errorf("failed to get http limit requests: %w", err) } @@ -1914,11 +2028,16 @@ func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { // GetHTTPConnectionsLimit returns http/limit_conns stats. func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, error) { + return client.GetHTTPConnectionsLimitWithContext(context.Background()) +} + +// GetHTTPConnectionsLimitWithContext returns http/limit_conns stats with a context. +func (client *NginxClient) GetHTTPConnectionsLimitWithContext(ctx context.Context) (*HTTPLimitConnections, error) { var limitConns HTTPLimitConnections if client.apiVersion < 6 { return &limitConns, nil } - err := client.get("http/limit_conns", &limitConns) + err := client.getWithContext(ctx, "http/limit_conns", &limitConns) if err != nil { return nil, fmt.Errorf("failed to get http connections limit: %w", err) } @@ -1927,11 +2046,16 @@ func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, err // GetStreamConnectionsLimit returns stream/limit_conns stats. func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, error) { + return client.GetStreamConnectionsLimitWithContext(context.Background()) +} + +// GetStreamConnectionsLimitWithContext returns stream/limit_conns stats with a context. +func (client *NginxClient) GetStreamConnectionsLimitWithContext(ctx context.Context) (*StreamLimitConnections, error) { var limitConns StreamLimitConnections if client.apiVersion < 6 { return &limitConns, nil } - err := client.get("stream/limit_conns", &limitConns) + err := client.getWithContext(ctx, "stream/limit_conns", &limitConns) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1946,11 +2070,16 @@ func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, // GetWorkers returns workers stats. func (client *NginxClient) GetWorkers() ([]*Workers, error) { + return client.GetWorkersWithContext(context.Background()) +} + +// GetWorkersWithContext returns workers stats with a context. +func (client *NginxClient) GetWorkersWithContext(ctx context.Context) ([]*Workers, error) { var workers []*Workers if client.apiVersion < 9 { return workers, nil } - err := client.get("workers", &workers) + err := client.getWithContext(ctx, "workers", &workers) if err != nil { return nil, fmt.Errorf("failed to get workers: %w", err) } diff --git a/client/nginx_test.go b/client/nginx_test.go index de8070f7..e0378fef 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "net/http" "net/http/httptest" "reflect" @@ -622,23 +623,39 @@ func TestClientWithHTTPClient(t *testing.T) { } func TestGetStats_NoStreamEndpoint(t *testing.T) { + tests := []struct { + ctx context.Context + name string + }{ + { + ctx: nil, + name: "no context test", + }, + { + ctx: context.Background(), + name: "with context test", + }, + } + var err error + var client *NginxClient + t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.RequestURI == "/": - _, err := w.Write([]byte(`[4, 5, 6, 7, 8, 9]`)) + _, err = w.Write([]byte(`[4, 5, 6, 7, 8, 9]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } case r.RequestURI == "/7/": - _, err := w.Write([]byte(`["nginx","processes","connections","slabs","http","resolvers","ssl"]`)) + _, err = w.Write([]byte(`["nginx","processes","connections","slabs","http","resolvers","ssl"]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } case strings.HasPrefix(r.RequestURI, "/7/stream"): t.Fatal("Stream endpoint should not be called since it does not exist.") default: - _, err := w.Write([]byte(`{}`)) + _, err = w.Write([]byte(`{}`)) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -647,7 +664,7 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { defer ts.Close() // Test creating a new client with a supported API version on the server - client, err := NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI()) + client, err = NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -655,9 +672,19 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { t.Fatalf("client is nil") } - stats, err := client.GetStats() - if err != nil { - t.Fatalf("unexpected error: %v", err) + var stats *Stats + for _, test := range tests { + if test.ctx == nil { + stats, err = client.GetStats() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + stats, err = client.GetStatsWithContext(test.ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } } if !reflect.DeepEqual(stats.StreamServerZones, StreamServerZones{}) { @@ -675,6 +702,20 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { } func TestGetStats_SSL(t *testing.T) { + tests := []struct { + ctx context.Context + name string + }{ + { + ctx: nil, + name: "no context test", + }, + { + ctx: context.Background(), + name: "with context test", + }, + } + t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -731,30 +772,41 @@ func TestGetStats_SSL(t *testing.T) { t.Fatalf("client is nil") } - stats, err := client.GetStats() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + var stats *Stats - testStats := SSL{ - Handshakes: 79572, - HandshakesFailed: 21025, - SessionReuses: 15762, - NoCommonProtocol: 4, - NoCommonCipher: 2, - HandshakeTimeout: 0, - PeerRejectedCert: 0, - VerifyFailures: VerifyFailures{ - NoCert: 0, - ExpiredCert: 2, - RevokedCert: 1, - HostnameMismatch: 2, - Other: 1, - }, - } + for _, test := range tests { + if test.ctx == nil { + stats, err = client.GetStats() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + stats, err = client.GetStatsWithContext(test.ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } - if !reflect.DeepEqual(stats.SSL, testStats) { - t.Fatalf("SSL stats: expected %v, actual %v", testStats, stats.SSL) + testStats := SSL{ + Handshakes: 79572, + HandshakesFailed: 21025, + SessionReuses: 15762, + NoCommonProtocol: 4, + NoCommonCipher: 2, + HandshakeTimeout: 0, + PeerRejectedCert: 0, + VerifyFailures: VerifyFailures{ + NoCert: 0, + ExpiredCert: 2, + RevokedCert: 1, + HostnameMismatch: 2, + Other: 1, + }, + } + + if !reflect.DeepEqual(stats.SSL, testStats) { + t.Fatalf("SSL stats: expected %v, actual %v", testStats, stats.SSL) + } } } From 56871ef97bea09760e23f28a1588569da3067edc Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Thu, 29 Aug 2024 15:28:43 +0100 Subject: [PATCH 5/7] modified the locks and unlocks to be consistent --- client/nginx.go | 102 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 32 deletions(-) diff --git a/client/nginx.go b/client/nginx.go index 4388497d..e08a5e5a 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -965,10 +965,10 @@ func (client *NginxClient) getWithContext(ctx context.Context, path string, data } func (client *NginxClient) post(path string, input interface{}) error { - return client.postWithConext(context.Background(), path, input) + return client.postWithContext(context.Background(), path, input) } -func (client *NginxClient) postWithConext(ctx context.Context, path string, input interface{}) error { +func (client *NginxClient) postWithContext(ctx context.Context, path string, input interface{}) error { ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) defer cancel() @@ -1259,156 +1259,183 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err stats := defaultStats() // Collecting initial stats initialGroup.Go(func() error { - mu.Lock() endpoints, err := client.GetAvailableEndpointsWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get available Endpoints: %w", err) } + + mu.Lock() stats.endpoints = endpoints + mu.Unlock() return nil }) initialGroup.Go(func() error { - mu.Lock() nginxInfo, err := client.GetNginxInfoWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get NGINX info: %w", err) } + + mu.Lock() stats.NginxInfo = *nginxInfo + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() caches, err := client.GetCachesWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Caches: %w", err) } + + mu.Lock() stats.Caches = *caches + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() processes, err := client.GetProcessesWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Process information: %w", err) } + + mu.Lock() stats.Processes = *processes + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() slabs, err := client.GetSlabsWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Slabs: %w", err) } + + mu.Lock() stats.Slabs = *slabs + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() httpRequests, err := client.GetHTTPRequestsWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTP Requests: %w", err) } + + mu.Lock() stats.HTTPRequests = *httpRequests + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() ssl, err := client.GetSSLWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get SSL: %w", err) } + + mu.Lock() stats.SSL = *ssl + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() serverZones, err := client.GetServerZonesWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Server Zones: %w", err) } + + mu.Lock() stats.ServerZones = *serverZones + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() upstreams, err := client.GetUpstreamsWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Upstreams: %w", err) } + + mu.Lock() stats.Upstreams = *upstreams + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() locationZones, err := client.GetLocationZonesWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Location Zones: %w", err) } + + mu.Lock() stats.LocationZones = *locationZones + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() resolvers, err := client.GetResolversWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Resolvers: %w", err) } + + mu.Lock() stats.Resolvers = *resolvers + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() httpLimitRequests, err := client.GetHTTPLimitReqsWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTPLimitRequests: %w", err) } + + mu.Lock() stats.HTTPLimitRequests = *httpLimitRequests + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() httpLimitConnections, err := client.GetHTTPConnectionsLimitWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get HTTPLimitConnections: %w", err) } + + mu.Lock() stats.HTTPLimitConnections = *httpLimitConnections + mu.Unlock() + return nil }) initialGroup.Go(func() error { - mu.Lock() workers, err := client.GetWorkersWithContext(initialCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get Workers: %w", err) } + + mu.Lock() stats.Workers = workers + mu.Unlock() + return nil }) @@ -1421,13 +1448,15 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err availableStreamGroup, asgCtx := errgroup.WithContext(ctx) availableStreamGroup.Go(func() error { - mu.Lock() streamEndpoints, err := client.GetAvailableStreamEndpointsWithContext(asgCtx) - mu.Unlock() if err != nil { return fmt.Errorf("failed to get available Stream Endpoints: %w", err) } + + mu.Lock() stats.streamEndpoints = streamEndpoints + mu.Unlock() + return nil }) @@ -1443,9 +1472,11 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err if err != nil { return fmt.Errorf("failed to get streamServerZones: %w", err) } + mu.Lock() stats.StreamServerZones = *streamServerZones mu.Unlock() + return nil }) } @@ -1456,6 +1487,7 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err if err != nil { return fmt.Errorf("failed to get StreamUpstreams: %w", err) } + mu.Lock() stats.StreamUpstreams = *streamUpstreams mu.Unlock() @@ -1470,9 +1502,11 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err if err != nil { return fmt.Errorf("failed to get StreamLimitConnections: %w", err) } + mu.Lock() stats.StreamLimitConnections = *streamConnectionsLimit mu.Unlock() + return nil }) @@ -1481,9 +1515,11 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err if err != nil { return fmt.Errorf("failed to get StreamZoneSync: %w", err) } + mu.Lock() stats.StreamZoneSync = streamZoneSync mu.Unlock() + return nil }) } @@ -1502,9 +1538,11 @@ func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, err if err != nil { return fmt.Errorf("failed to get connections: %w", err) } + mu.Lock() stats.Connections = *connections mu.Unlock() + return nil }) From 1f1894bf03b85c07bb3b4e2c0459243e0bd95c40 Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Thu, 29 Aug 2024 15:35:45 +0100 Subject: [PATCH 6/7] fix race detection for server writes --- client/nginx_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/nginx_test.go b/client/nginx_test.go index e0378fef..0ad650c9 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "reflect" "strings" + "sync" "testing" "time" ) @@ -638,11 +639,16 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { } var err error var client *NginxClient + var writeLock sync.Mutex t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeLock.Lock() + defer writeLock.Unlock() + switch { case r.RequestURI == "/": + _, err = w.Write([]byte(`[4, 5, 6, 7, 8, 9]`)) if err != nil { t.Fatalf("unexpected error: %v", err) From c42f3e87424493ea3cdce9ea671f0e562bfd662b Mon Sep 17 00:00:00 2001 From: Oliver O'Mahony Date: Thu, 29 Aug 2024 16:42:54 +0100 Subject: [PATCH 7/7] refactored the functions to be aligned with the older way of doing things --- client/nginx.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/client/nginx.go b/client/nginx.go index e08a5e5a..7a76d41c 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -927,13 +927,13 @@ func (client *NginxClient) getIDOfHTTPServer(upstream string, name string) (int, } func (client *NginxClient) get(path string, data interface{}) error { - return client.getWithContext(context.Background(), path, data) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + defer cancel() + + return client.getWithContext(timeoutCtx, path, data) } func (client *NginxClient) getWithContext(ctx context.Context, path string, data interface{}) error { - ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) - defer cancel() - url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -965,13 +965,13 @@ func (client *NginxClient) getWithContext(ctx context.Context, path string, data } func (client *NginxClient) post(path string, input interface{}) error { - return client.postWithContext(context.Background(), path, input) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + defer cancel() + + return client.postWithContext(timeoutCtx, path, input) } func (client *NginxClient) postWithContext(ctx context.Context, path string, input interface{}) error { - ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) - defer cancel() - url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input) @@ -1001,13 +1001,13 @@ func (client *NginxClient) postWithContext(ctx context.Context, path string, inp } func (client *NginxClient) delete(path string, expectedStatusCode int) error { - return client.deleteWithContext(context.Background(), path, expectedStatusCode) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + defer cancel() + + return client.deleteWithContext(timeoutCtx, path, expectedStatusCode) } func (client *NginxClient) deleteWithContext(ctx context.Context, path string, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) - defer cancel() - path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, path, nil) @@ -1030,13 +1030,13 @@ func (client *NginxClient) deleteWithContext(ctx context.Context, path string, e } func (client *NginxClient) patch(path string, input interface{}, expectedStatusCode int) error { - return client.patchWithContext(context.Background(), path, input, expectedStatusCode) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + defer cancel() + + return client.patchWithContext(timeoutCtx, path, input, expectedStatusCode) } func (client *NginxClient) patchWithContext(ctx context.Context, path string, input interface{}, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(ctx, client.ctxTimeout) - defer cancel() - path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input)