Skip to content

feat: proxy headers, override host #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ The application is configurable via environment variables.
- **Type:** string
- **Required:** No, defaults to `/healthz`
- **Example:** `''` (disabled)
- **`OVERRIDE_HOST`**, override the host header for the upstream API
- **Type:** boolean
- **Required:** No, defaults to `true`
- **Example:** `false`, `1`, `True`
- Authentication
- **`OIDC_DISCOVERY_URL`**, OpenID Connect discovery document URL
- **Type:** HTTP(S) URL
Expand Down
5 changes: 4 additions & 1 deletion src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ async def lifespan(app: FastAPI):

app.add_api_route(
"/{path:path}",
ReverseProxyHandler(upstream=str(settings.upstream_url)).proxy_request,
ReverseProxyHandler(
upstream=str(settings.upstream_url),
override_host=settings.override_host,
).proxy_request,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)

Expand Down
1 change: 1 addition & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Settings(BaseSettings):
oidc_discovery_url: HttpUrl
oidc_discovery_internal_url: HttpUrl

override_host: bool = True
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
wait_for_upstream: bool = True
check_conformance: bool = True
Expand Down
33 changes: 30 additions & 3 deletions src/stac_auth_proxy/handlers/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class ReverseProxyHandler:
client: httpx.AsyncClient = None
timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))

proxy_name: str = "stac-auth-proxy"
override_host: bool = True
legacy_forwarded_headers: bool = False

def __post_init__(self):
"""Initialize the HTTP client."""
self.client = self.client or httpx.AsyncClient(
Expand All @@ -28,11 +32,34 @@ def __post_init__(self):
http2=True,
)

def _prepare_headers(self, request: Request) -> MutableHeaders:
"""Prepare headers for the proxied request."""
headers = MutableHeaders(request.headers)
headers.setdefault("Via", f"1.1 {self.proxy_name}")

proxy_client = request.client.host if request.client else "unknown"
proxy_proto = request.url.scheme
proxy_host = request.url.netloc
proxy_path = request.base_url.path
headers.setdefault(
"Forwarded",
f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
)
if self.legacy_forwarded_headers:
headers.setdefault("X-Forwarded-For", proxy_client)
headers.setdefault("X-Forwarded-Host", proxy_host)
headers.setdefault("X-Forwarded-Path", proxy_path)
headers.setdefault("X-Forwarded-Proto", proxy_proto)

# Set host to the upstream host
if self.override_host:
headers["Host"] = self.client.base_url.netloc.decode("utf-8")

return headers

async def proxy_request(self, request: Request) -> Response:
"""Proxy a request to the upstream STAC API."""
headers = MutableHeaders(request.headers)
headers.setdefault("X-Forwarded-For", request.client.host)
headers.setdefault("X-Forwarded-Host", request.url.hostname)
headers = self._prepare_headers(request)

# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
rp_req = self.client.build_request(
Expand Down
173 changes: 173 additions & 0 deletions tests/test_reverse_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Tests for the reverse proxy handler's header functionality."""

import pytest
from fastapi import Request

from stac_auth_proxy.handlers.reverse_proxy import ReverseProxyHandler


@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
scope = {
"type": "http",
"method": "GET",
"path": "/test",
"headers": [
(b"host", b"localhost:8000"),
(b"user-agent", b"test-agent"),
(b"accept", b"application/json"),
],
}
return Request(scope)


@pytest.fixture
def reverse_proxy_handler():
"""Create a reverse proxy handler instance."""
return ReverseProxyHandler(upstream="http://upstream-api.com")


@pytest.mark.asyncio
async def test_basic_headers(mock_request, reverse_proxy_handler):
"""Test that basic headers are properly set."""
headers = reverse_proxy_handler._prepare_headers(mock_request)

# Check standard headers
assert headers["Host"] == "upstream-api.com"
assert headers["User-Agent"] == "test-agent"
assert headers["Accept"] == "application/json"

# Check modern forwarded header
assert "Forwarded" in headers
forwarded = headers["Forwarded"]
assert "for=unknown" in forwarded
assert "host=localhost:8000" in forwarded
assert "proto=http" in forwarded
assert "path=/" in forwarded

# Check Via header
assert headers["Via"] == "1.1 stac-auth-proxy"

# Legacy headers should not be present by default
assert "X-Forwarded-For" not in headers
assert "X-Forwarded-Host" not in headers
assert "X-Forwarded-Proto" not in headers
assert "X-Forwarded-Path" not in headers


@pytest.mark.asyncio
async def test_legacy_forwarded_headers(mock_request):
"""Test that legacy X-Forwarded-* headers are set when enabled."""
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=True
)
headers = handler._prepare_headers(mock_request)

# Check legacy headers
assert headers["X-Forwarded-For"] == "unknown"
assert headers["X-Forwarded-Host"] == "localhost:8000"
assert headers["X-Forwarded-Proto"] == "http"
assert headers["X-Forwarded-Path"] == "/"

# Modern Forwarded header should still be present
assert "Forwarded" in headers


@pytest.mark.asyncio
async def test_override_host_disabled(mock_request):
"""Test that host override can be disabled."""
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", override_host=False
)
headers = handler._prepare_headers(mock_request)
assert headers["Host"] == "localhost:8000"


@pytest.mark.asyncio
async def test_custom_proxy_name(mock_request):
"""Test that custom proxy name is used in Via header."""
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", proxy_name="custom-proxy"
)
headers = handler._prepare_headers(mock_request)
assert headers["Via"] == "1.1 custom-proxy"


@pytest.mark.asyncio
async def test_forwarded_headers_with_client(mock_request):
"""Test forwarded headers when client information is available."""
# Add client information to the request
mock_request.scope["client"] = ("192.168.1.1", 12345)
handler = ReverseProxyHandler(upstream="http://upstream-api.com")
headers = handler._prepare_headers(mock_request)

# Check modern Forwarded header
forwarded = headers["Forwarded"]
assert "for=192.168.1.1" in forwarded
assert "host=localhost:8000" in forwarded
assert "proto=http" in forwarded
assert "path=/" in forwarded

# Legacy headers should not be present by default
assert "X-Forwarded-For" not in headers
assert "X-Forwarded-Host" not in headers
assert "X-Forwarded-Proto" not in headers
assert "X-Forwarded-Path" not in headers


@pytest.mark.asyncio
async def test_legacy_forwarded_headers_with_client(mock_request):
"""Test legacy forwarded headers when client information is available."""
mock_request.scope["client"] = ("192.168.1.1", 12345)
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=True
)
headers = handler._prepare_headers(mock_request)

# Check legacy headers
assert headers["X-Forwarded-For"] == "192.168.1.1"
assert headers["X-Forwarded-Host"] == "localhost:8000"
assert headers["X-Forwarded-Proto"] == "http"
assert headers["X-Forwarded-Path"] == "/"

# Modern Forwarded header should still be present
assert "Forwarded" in headers


@pytest.mark.asyncio
async def test_https_proto(mock_request):
"""Test that X-Forwarded-Proto is set correctly for HTTPS."""
mock_request.scope["scheme"] = "https"
handler = ReverseProxyHandler(upstream="http://upstream-api.com")
headers = handler._prepare_headers(mock_request)

# Check modern Forwarded header
assert "proto=https" in headers["Forwarded"]

# Legacy headers should not be present by default
assert "X-Forwarded-Proto" not in headers


@pytest.mark.asyncio
async def test_https_proto_legacy(mock_request):
"""Test that X-Forwarded-Proto is set correctly for HTTPS with legacy headers."""
mock_request.scope["scheme"] = "https"
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=True
)
headers = handler._prepare_headers(mock_request)
assert headers["X-Forwarded-Proto"] == "https"
assert "proto=https" in headers["Forwarded"]


@pytest.mark.asyncio
async def test_non_standard_port(mock_request):
"""Test handling of non-standard ports in host header."""
mock_request.scope["headers"] = [
(b"host", b"localhost:8080"),
(b"user-agent", b"test-agent"),
]
handler = ReverseProxyHandler(upstream="http://upstream-api.com:8080")
headers = handler._prepare_headers(mock_request)
assert headers["Host"] == "upstream-api.com:8080"