Skip to content

Commit f648e09

Browse files
authored
feat(bedrock): automatically infer AWS Region (#974)
1 parent fa5a09b commit f648e09

File tree

3 files changed

+100
-10
lines changed

3 files changed

+100
-10
lines changed

src/anthropic/lib/bedrock/_client.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
import logging
45
import urllib.parse
56
from typing import Any, Union, Mapping, TypeVar
67
from typing_extensions import Self, override
@@ -26,6 +27,8 @@
2627
from ...resources.messages import Messages, AsyncMessages
2728
from ...resources.completions import Completions, AsyncCompletions
2829

30+
log: logging.Logger = logging.getLogger(__name__)
31+
2932
DEFAULT_VERSION = "bedrock-2023-05-31"
3033

3134
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
@@ -64,6 +67,29 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
6467
return options
6568

6669

70+
def _infer_region() -> str:
71+
"""
72+
Infer the AWS region from the environment variables or
73+
from the boto3 session if available.
74+
"""
75+
aws_region = os.environ.get("AWS_REGION")
76+
if aws_region is None:
77+
try:
78+
import boto3
79+
80+
session = boto3.Session()
81+
if session.region_name:
82+
aws_region = session.region_name
83+
except ImportError:
84+
pass
85+
86+
if aws_region is None:
87+
log.warning("No AWS region specified, defaulting to us-east-1")
88+
aws_region = "us-east-1" # fall back to legacy behavior
89+
90+
return aws_region
91+
92+
6793
class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
6894
@override
6995
def _make_status_error(
@@ -135,9 +161,7 @@ def __init__(
135161

136162
self.aws_access_key = aws_access_key
137163

138-
if aws_region is None:
139-
aws_region = os.environ.get("AWS_REGION") or "us-east-1"
140-
self.aws_region = aws_region
164+
self.aws_region = _infer_region() if aws_region is None else aws_region
141165
self.aws_profile = aws_profile
142166

143167
self.aws_session_token = aws_session_token
@@ -279,9 +303,7 @@ def __init__(
279303

280304
self.aws_access_key = aws_access_key
281305

282-
if aws_region is None:
283-
aws_region = os.environ.get("AWS_REGION") or "us-east-1"
284-
self.aws_region = aws_region
306+
self.aws_region = _infer_region() if aws_region is None else aws_region
285307
self.aws_profile = aws_profile
286308

287309
self.aws_session_token = aws_session_token

tests/lib/test_bedrock.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
2-
from typing import cast
2+
import typing as t
3+
import tempfile
4+
from typing import TypedDict, cast
35
from typing_extensions import Protocol
46

57
import httpx
@@ -24,6 +26,41 @@ class MockRequestCall(Protocol):
2426
request: httpx.Request
2527

2628

29+
class AwsConfigProfile(TypedDict):
30+
# Available regions: https://docs.aws.amazon.com/global-infrastructure/latest/regions/aws-regions.html#available-regions
31+
name: t.Union[t.Literal["default"], str]
32+
region: str
33+
34+
35+
def profile_to_ini(profile: AwsConfigProfile) -> str:
36+
"""
37+
Convert an AWS config profile to an INI format string.
38+
"""
39+
40+
profile_name = profile["name"] if profile["name"] == "default" else f"profile {profile['name']}"
41+
return f"[{profile_name}]\nregion = {profile['region']}\n"
42+
43+
44+
@pytest.fixture
45+
def profiles() -> t.List[AwsConfigProfile]:
46+
return [
47+
{"name": "default", "region": "us-east-2"},
48+
]
49+
50+
51+
@pytest.fixture
52+
def mock_aws_config(
53+
profiles: t.List[AwsConfigProfile],
54+
monkeypatch: t.Any,
55+
) -> t.Iterable[None]:
56+
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as temp_file:
57+
for profile in profiles:
58+
temp_file.write(profile_to_ini(profile))
59+
temp_file.flush()
60+
monkeypatch.setenv("AWS_CONFIG_FILE", str(temp_file.name))
61+
yield
62+
63+
2764
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
2865
@pytest.mark.respx()
2966
def test_messages_retries(respx_mock: MockRouter) -> None:
@@ -127,3 +164,34 @@ def test_application_inference_profile(respx_mock: MockRouter) -> None:
127164
calls[1].request.url
128165
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile%2Fjf2sje1c0jnb/invoke"
129166
)
167+
168+
169+
def test_region_infer_from_profile(
170+
mock_aws_config: None, # noqa: ARG001
171+
profiles: t.List[AwsConfigProfile],
172+
) -> None:
173+
client = AnthropicBedrock()
174+
assert client.aws_region == profiles[0]["region"]
175+
176+
177+
@pytest.mark.parametrize(
178+
"profiles, aws_profile",
179+
[
180+
pytest.param([{"name": "default", "region": "us-east-2"}], "default", id="default profile"),
181+
pytest.param(
182+
[{"name": "default", "region": "us-east-2"}, {"name": "custom", "region": "us-west-1"}],
183+
"custom",
184+
id="custom profile",
185+
),
186+
],
187+
)
188+
def test_region_infer_from_specified_profile(
189+
mock_aws_config: None, # noqa: ARG001
190+
profiles: t.List[AwsConfigProfile],
191+
aws_profile: str,
192+
monkeypatch: t.Any,
193+
) -> None:
194+
monkeypatch.setenv("AWS_PROFILE", aws_profile)
195+
client = AnthropicBedrock()
196+
197+
assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]

tests/lib/test_vertex.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_global_region_base_url(self) -> None:
124124
client = AnthropicVertex(region="global", project_id="test-project", access_token="fake-token")
125125
assert str(client.base_url).rstrip("/") == "https://aiplatform.googleapis.com/v1"
126126

127-
@pytest.mark.parametrize('region', ["us-central1", "europe-west1", "asia-southeast1"])
127+
@pytest.mark.parametrize("region", ["us-central1", "europe-west1", "asia-southeast1"])
128128
def test_regional_base_url(self, region: str) -> None:
129129
"""Test that regional endpoints use the correct base URL format."""
130130
client = AnthropicVertex(region=region, project_id="test-project", access_token="fake-token")
@@ -138,7 +138,7 @@ def test_env_var_base_url_override(self, monkeypatch: pytest.MonkeyPatch) -> Non
138138
monkeypatch.setenv("ANTHROPIC_VERTEX_BASE_URL", test_url)
139139

140140
client = AnthropicVertex(
141-
region="global", # we expect this to get ignored since the user is providing a base_url
141+
region="global", # we expect this to get ignored since the user is providing a base_url
142142
project_id="test-project",
143143
access_token="fake-token",
144144
base_url="https://test.googleapis.com/v1",
@@ -270,7 +270,7 @@ def test_env_var_base_url_override(self, monkeypatch: pytest.MonkeyPatch) -> Non
270270
monkeypatch.setenv("ANTHROPIC_VERTEX_BASE_URL", test_url)
271271

272272
client = AsyncAnthropicVertex(
273-
region="global", # we expect this to get ignored since the user is providing a base_url
273+
region="global", # we expect this to get ignored since the user is providing a base_url
274274
project_id="test-project",
275275
access_token="fake-token",
276276
base_url="https://test.googleapis.com/v1",

0 commit comments

Comments
 (0)