Skip to content

feat: support customizing OpenAPI auth scheme #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,19 @@ The application is configurable via environment variables.
- **Type:** boolean
- **Required:** No, defaults to `true`
- **Example:** `false`, `1`, `True`
- OpenAPI
- **`OPENAPI_SPEC_ENDPOINT`**, path of OpenAPI specification, used for augmenting spec response with auth configuration
- **Type:** string or null
- **Required:** No, defaults to `null` (disabled)
- **Example:** `/api`
- **`OPENAPI_AUTH_SCHEME_NAME`**, name of the auth scheme to use in the OpenAPI spec
- **Type:** string
- **Required:** No, defaults to `oidcAuth`
- **Example:** `jwtAuth`
- **`OPENAPI_AUTH_SCHEME_OVERRIDE`**, override for the auth scheme in the OpenAPI spec
- **Type:** JSON object
- **Required:** No, defaults to `null` (disabled)
- **Example:** `{"type": "http", "scheme": "bearer", "bearerFormat": "JWT", "description": "Paste your raw JWT here. This API uses Bearer token authorization.\n"}`
- Filtering
- **`ITEMS_FILTER_CLS`**, CQL2 expression generator for item-level filtering
- **Type:** JSON object with class configuration
Expand All @@ -139,7 +148,7 @@ The application is configurable via environment variables.
- **`ITEMS_FILTER_KWARGS`**, Keyword arguments for CQL2 expression generator
- **Type:** Dictionary of keyword arguments used to initialize the class
- **Required:** No, defaults to `{}`
- **Example:** `{ "field_name": "properties.organization" }`
- **Example:** `{"field_name": "properties.organization"}`

### Customization

Expand Down
2 changes: 2 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ async def lifespan(app: FastAPI):
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
auth_scheme_name=settings.openapi_auth_scheme_name,
auth_scheme_override=settings.openapi_auth_scheme_override,
)

if settings.items_filter:
Expand Down
7 changes: 5 additions & 2 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ class Settings(BaseSettings):
oidc_discovery_url: HttpUrl
oidc_discovery_internal_url: HttpUrl

healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
wait_for_upstream: bool = True
check_conformance: bool = True
enable_compression: bool = True
enable_authentication_extension: bool = True
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")

openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
openapi_auth_scheme_name: str = "oidcAuth"
openapi_auth_scheme_override: Optional[dict] = None

# Auth
enable_authentication_extension: bool = True
default_public: bool = False
public_endpoints: EndpointMethodsNoScope = {
r"^/api.html$": ["GET"],
Expand Down
9 changes: 5 additions & 4 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional

from starlette.datastructures import Headers
from starlette.requests import Request
Expand All @@ -23,7 +23,8 @@ class OpenApiMiddleware(JsonResponseMiddleware):
private_endpoints: EndpointMethods
public_endpoints: EndpointMethods
default_public: bool
oidc_auth_scheme_name: str = "oidcAuth"
auth_scheme_name: str = "oidcAuth"
auth_scheme_override: Optional[dict] = None

json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"

Expand All @@ -47,7 +48,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
"""Augment the OpenAPI spec with auth information."""
components = data.setdefault("components", {})
securitySchemes = components.setdefault("securitySchemes", {})
securitySchemes[self.oidc_auth_scheme_name] = {
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}
Expand All @@ -62,6 +63,6 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
)
if match.is_private:
config.setdefault("security", []).append(
{self.oidc_auth_scheme_name: match.required_scopes}
{self.auth_scheme_name: match.required_scopes}
)
return data
39 changes: 39 additions & 0 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,42 @@ def test_oidc_in_openapi_spec_public_endpoints(
assert any(
method.casefold() == m.casefold() for m in expected_auth[path]
)


def test_auth_scheme_name_override(source_api: FastAPI, source_api_server: str):
"""When auth_scheme_name is overridden, the OpenAPI spec uses the custom name."""
custom_name = "customAuth"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
openapi_auth_scheme_name=custom_name,
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
assert custom_name in security_schemes
assert "oidcAuth" not in security_schemes


def test_auth_scheme_override(source_api: FastAPI, source_api_server: str):
"""When auth_scheme_override is provided, the OpenAPI spec uses the custom scheme."""
custom_scheme = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Custom JWT authentication",
}
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
openapi_auth_scheme_override=custom_scheme,
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
assert "oidcAuth" in security_schemes
assert security_schemes["oidcAuth"] == custom_scheme