Skip to content

Commit bce4ed8

Browse files
authored
Fix: Bedrock Application Inference ARN breaks Claude Sonnet tools usage (#561)
### Description This PR fixes an issue that occurs when running an agent that uses tools with an application inference profile of Claude Sonnet 3.7 as the model input. Note that a regional inference profile works correctly; only an application inference profile (AIP) fails to return any tool-use information. ### Root Cause LangChain needs to identify the model id of the underlying foundation model, but it cannot detect the foundation model from an AIP ARN (e.g., `arn:aws:bedrock:us-east-1:111111484058:application-inference-profile/c3myu2h6fllr`), because the ARN contains only an opaque profile id. A regional inference profile id, by contrast, embeds the foundation model name (e.g., `us.anthropic.claude-3-7-sonnet-20250219-v1:0`), so the base model id can be derived from it directly. This is why the regional inference profile works while the AIP does not. ### Solution To identify the foundation model used by an AIP, we call the Bedrock `get_inference_profile` control plane [API](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/get_inference_profile.html) and parse the model id from the response. ### Issue #535
1 parent 89f5848 commit bce4ed8

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,10 @@ class BedrockBase(BaseLanguageModel, ABC):
591591
"""Base class for Bedrock models."""
592592

593593
client: Any = Field(default=None, exclude=True) #: :meta private:
594+
"""The bedrock runtime client for making data plane API calls"""
595+
596+
bedrock_client: Any = Field(default=None, exclude=True) #: :meta private:
597+
"""The bedrock client for making control plane API calls"""
594598

595599
region_name: Optional[str] = Field(default=None, alias="region")
596600
"""The aws region e.g., `us-west-2`. Falls back to AWS_REGION or AWS_DEFAULT_REGION
@@ -775,6 +779,19 @@ def validate_environment(self) -> Self:
775779
service_name="bedrock-runtime",
776780
)
777781

782+
# Create bedrock client for control plane API call
783+
if self.bedrock_client is None:
784+
self.bedrock_client = create_aws_client(
785+
region_name=self.region_name,
786+
credentials_profile_name=self.credentials_profile_name,
787+
aws_access_key_id=self.aws_access_key_id,
788+
aws_secret_access_key=self.aws_secret_access_key,
789+
aws_session_token=self.aws_session_token,
790+
endpoint_url=self.endpoint_url,
791+
config=self.config,
792+
service_name="bedrock",
793+
)
794+
778795
return self
779796

780797
@property
@@ -815,6 +832,16 @@ def _get_provider(self) -> str:
815832
)
816833

817834
def _get_base_model(self) -> str:
    """Return the id of the foundation model behind ``model_id``.

    When ``base_model_id`` is unset and ``model_id`` contains
    ``application-inference-profile``, the profile is looked up with
    ``bedrock_client.get_inference_profile`` and the last ``/``-separated
    segment of the first listed model's ``modelArn`` is stored in
    ``base_model_id``.

    Returns:
        ``base_model_id`` when it is set to a non-empty value; otherwise the
        part of ``model_id`` after its first ``.`` (the whole ``model_id``
        when it contains no ``.``).
    """
    # Application inference profile (AIP) ARN format:
    #   arn:aws:bedrock:us-east-1:<accountId>:application-inference-profile/<id>
    # The ARN does not name the foundation model, so it has to be looked up.
    needs_profile_lookup = (
        self.base_model_id is None
        and "application-inference-profile" in self.model_id
    )
    if needs_profile_lookup:
        profile = self.bedrock_client.get_inference_profile(
            inferenceProfileIdentifier=self.model_id
        )
        if "models" in profile and len(profile["models"]) > 0:
            # Foundation model ARN format:
            #   arn:aws:bedrock:region::foundation-model/provider.model-name
            first_model_arn = profile["models"][0]["modelArn"]
            self.base_model_id = first_model_arn.split("/")[-1]

    if self.base_model_id:
        return self.base_model_id
    return self.model_id.split(".", maxsplit=1)[-1]
819846

820847
@property

libs/aws/tests/unit_tests/llms/test_bedrock.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,3 +781,61 @@ def test__get_base_model():
781781
region_name="us-west-2"
782782
)
783783
assert llm._get_base_model() == "meta.llama3-8b-instruct-v1:0"
784+
785+
786+
@patch("langchain_aws.llms.bedrock.create_aws_client")
def test_bedrock_client_creation(mock_create_client):
    """Constructing the LLM creates a bedrock-runtime client, then a bedrock client."""
    runtime_stub, control_plane_stub = MagicMock(), MagicMock()
    # The first create_aws_client call yields the runtime client, the second
    # the control plane client.
    mock_create_client.side_effect = [runtime_stub, control_plane_stub]

    llm = BedrockLLM(model_id="meta.llama3-8b-instruct-v1:0", region_name="us-west-2")

    # Exactly two clients are expected.
    assert mock_create_client.call_count == 2

    # Each call must target the expected service in the configured region,
    # in creation order: runtime first, control plane second.
    expected_services = ("bedrock-runtime", "bedrock")
    recorded_calls = mock_create_client.call_args_list
    for recorded, service in zip(recorded_calls[:2], expected_services):
        assert recorded.kwargs["service_name"] == service
        assert recorded.kwargs["region_name"] == "us-west-2"

    # The created clients end up on the matching attributes.
    assert llm.client is runtime_stub
    assert llm.bedrock_client is control_plane_stub
814+
815+
816+
@patch("langchain_aws.llms.bedrock.create_aws_client")
def test_get_base_model_with_application_inference_profile(mock_create_client):
    """_get_base_model resolves an application inference profile ARN to its base model."""
    profile_arn = (
        "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/my-profile"
    )
    foundation_model_arn = (
        "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"
    )
    expected_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

    runtime_stub = MagicMock()
    control_plane_stub = MagicMock()
    # The control plane client reports one foundation model for the profile.
    control_plane_stub.get_inference_profile.return_value = {
        "models": [{"modelArn": foundation_model_arn}]
    }
    mock_create_client.side_effect = [runtime_stub, control_plane_stub]

    llm = BedrockLLM(
        model_id=profile_arn,
        provider="anthropic",
        region_name="us-west-2",
    )

    resolved = llm._get_base_model()

    # The profile is looked up once, by its full ARN, and the model id is the
    # last path segment of the returned foundation model ARN.
    control_plane_stub.get_inference_profile.assert_called_once_with(
        inferenceProfileIdentifier=profile_arn
    )
    assert resolved == expected_model_id
    assert llm.base_model_id == expected_model_id

0 commit comments

Comments
 (0)