diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f577..b49680e92 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -151,7 +151,7 @@ def get_resource_url(self) -> str: # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) + prm_resource = str(self.protected_resource_metadata.resource).rstrip("/") if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): resource = prm_resource @@ -442,10 +442,6 @@ async def _refresh_token(self) -> httpx.Request: "client_id": self.context.client_info.client_id, } - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - # Prepare authentication based on preferred method headers = {"Content-Type": "application/x-www-form-urlencoded"} refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..a1993e646 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -762,7 +762,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ expected_resource = quote(oauth_provider.context.get_resource_url(), safe="") assert f"resource={expected_resource}" in content - # Test in refresh token + # Refresh grants don't carry resource; some OAuth providers reject it. oauth_provider.context.current_tokens = OAuthToken( access_token="test_access", token_type="Bearer", @@ -770,7 +770,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ ) refresh_request = await oauth_provider._refresh_token() refresh_content = refresh_request.content.decode() - assert "resource=" in refresh_content + assert "resource=" not in refresh_content @pytest.mark.anyio async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): @@ -800,7 +800,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro @pytest.mark.anyio async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): - """Test resource parameter is always included when protected resource metadata exists.""" + """Test resource parameter is included when protected resource metadata exists.""" # Set old protocol version but with protected resource metadata oauth_provider.context.protocol_version = "2025-03-26" oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata( @@ -818,6 +818,15 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa content = request.content.decode() assert "resource=" in content + oauth_provider.context.current_tokens = OAuthToken( + access_token="test_access", + token_type="Bearer", + refresh_token="test_refresh", + ) + refresh_request = await oauth_provider._refresh_token() + refresh_content = refresh_request.content.decode() + assert "resource=" not in refresh_content + @pytest.mark.anyio async def test_validate_resource_rejects_mismatched_resource( @@ -949,6 +958,27 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches( assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp") +@pytest.mark.anyio +async def test_get_resource_url_omits_pydantic_root_slash( + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> None: + """Bare-domain PRM resources should not inherit Pydantic's trailing slash.""" + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + + assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/" + assert provider.context.get_resource_url() == "https://api.example.com" + + class TestRegistrationResponse: """Test client registration response handling."""