diff --git a/README.md b/README.md index e0f8113..a390905 100644 --- a/README.md +++ b/README.md @@ -123,10 +123,19 @@ The application is configurable via environment variables. - **Type:** boolean - **Required:** No, defaults to `true` - **Example:** `false`, `1`, `True` +- OpenAPI - **`OPENAPI_SPEC_ENDPOINT`**, path of OpenAPI specification, used for augmenting spec response with auth configuration - **Type:** string or null - **Required:** No, defaults to `null` (disabled) - **Example:** `/api` + - **`OPENAPI_AUTH_SCHEME_NAME`**, name of the auth scheme to use in the OpenAPI spec + - **Type:** string + - **Required:** No, defaults to `oidcAuth` + - **Example:** `jwtAuth` + - **`OPENAPI_AUTH_SCHEME_OVERRIDE`**, override for the auth scheme in the OpenAPI spec + - **Type:** JSON object + - **Required:** No, defaults to `null` (disabled) + - **Example:** `{"type": "http", "scheme": "bearer", "bearerFormat": "JWT", "description": "Paste your raw JWT here. This API uses Bearer token authorization.\n"}` - Filtering - **`ITEMS_FILTER_CLS`**, CQL2 expression generator for item-level filtering - **Type:** JSON object with class configuration @@ -139,7 +148,7 @@ The application is configurable via environment variables. - **`ITEMS_FILTER_KWARGS`**, Keyword arguments for CQL2 expression generator - **Type:** Dictionary of keyword arguments used to initialize the class - **Required:** No, defaults to `{}` - - **Example:** `{ "field_name": "properties.organization" }` + - **Example:** `{"field_name": "properties.organization"}` ### Customization diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index ded0a73..03c26be 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -103,6 +103,8 @@ async def lifespan(app: FastAPI): public_endpoints=settings.public_endpoints, private_endpoints=settings.private_endpoints, default_public=settings.default_public, + auth_scheme_name=settings.openapi_auth_scheme_name, + auth_scheme_override=settings.openapi_auth_scheme_override, ) if settings.items_filter: diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 00b0b66..aa4690f 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -38,14 +38,17 @@ class Settings(BaseSettings): oidc_discovery_url: HttpUrl oidc_discovery_internal_url: HttpUrl + healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") wait_for_upstream: bool = True check_conformance: bool = True enable_compression: bool = True - enable_authentication_extension: bool = True - healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") + openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None) + openapi_auth_scheme_name: str = "oidcAuth" + openapi_auth_scheme_override: Optional[dict] = None # Auth + enable_authentication_extension: bool = True default_public: bool = False public_endpoints: EndpointMethodsNoScope = { r"^/api.html$": ["GET"], diff --git a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py index fbbbfa0..16c1240 100644 --- a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +++ b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py @@ -2,7 +2,7 @@ import re from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from starlette.datastructures import Headers from starlette.requests import Request @@ -23,7 +23,8 @@ class OpenApiMiddleware(JsonResponseMiddleware): private_endpoints: EndpointMethods public_endpoints: EndpointMethods default_public: bool - oidc_auth_scheme_name: str = "oidcAuth" + auth_scheme_name: str = "oidcAuth" + auth_scheme_override: Optional[dict] = None json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)" @@ -47,7 +48,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An """Augment the OpenAPI spec with auth information.""" components = data.setdefault("components", {}) securitySchemes = components.setdefault("securitySchemes", {}) - securitySchemes[self.oidc_auth_scheme_name] = { + securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or { "type": "openIdConnect", "openIdConnectUrl": self.oidc_config_url, } @@ -62,6 +63,6 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An ) if match.is_private: config.setdefault("security", []).append( - {self.oidc_auth_scheme_name: match.required_scopes} + {self.auth_scheme_name: match.required_scopes} ) return data diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 183f318..e0ad990 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -151,3 +151,42 @@ def test_oidc_in_openapi_spec_public_endpoints( assert any( method.casefold() == m.casefold() for m in expected_auth[path] ) + + +def test_auth_scheme_name_override(source_api: FastAPI, source_api_server: str): + """When auth_scheme_name is overridden, the OpenAPI spec uses the custom name.""" + custom_name = "customAuth" + app = app_factory( + upstream_url=source_api_server, + openapi_spec_endpoint=source_api.openapi_url, + openapi_auth_scheme_name=custom_name, + ) + client = TestClient(app) + response = client.get(source_api.openapi_url) + assert response.status_code == 200 + openapi = response.json() + security_schemes = openapi.get("components", {}).get("securitySchemes", {}) + assert custom_name in security_schemes + assert "oidcAuth" not in security_schemes + + +def test_auth_scheme_override(source_api: FastAPI, source_api_server: str): + """When auth_scheme_override is provided, the OpenAPI spec uses the custom scheme.""" + custom_scheme = { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": "Custom JWT authentication", + } + app = app_factory( + upstream_url=source_api_server, + openapi_spec_endpoint=source_api.openapi_url, + openapi_auth_scheme_override=custom_scheme, + ) + client = TestClient(app) + response = client.get(source_api.openapi_url) + assert response.status_code == 200 + openapi = response.json() + security_schemes = openapi.get("components", {}).get("securitySchemes", {}) + assert "oidcAuth" in security_schemes + assert security_schemes["oidcAuth"] == custom_scheme