From 6a7341f35b19e8cb3af4d89f24e17f9995812792 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Wed, 13 May 2026 18:46:13 -0700 Subject: [PATCH 1/2] fix(rate-limit): close rate-limit bypass and tighten public route limits --- apps/sim/app/api/auth/socket-token/route.ts | 12 +- apps/sim/app/api/auth/sso/providers/route.ts | 9 + .../api/chat/[identifier]/otp/route.test.ts | 9 +- .../app/api/chat/[identifier]/otp/route.ts | 24 ++- .../app/api/chat/[identifier]/sso/route.ts | 24 ++- apps/sim/app/api/telemetry/route.ts | 8 + apps/sim/app/api/templates/[id]/route.ts | 51 ++++-- .../app/api/tools/a2a/cancel-task/route.ts | 4 + .../a2a/delete-push-notification/route.ts | 7 + .../app/api/tools/a2a/get-agent-card/route.ts | 4 + .../tools/a2a/get-push-notification/route.ts | 4 + apps/sim/app/api/tools/a2a/get-task/route.ts | 4 + .../app/api/tools/a2a/resubscribe/route.ts | 4 + .../app/api/tools/a2a/send-message/route.ts | 4 + .../tools/a2a/set-push-notification/route.ts | 4 + .../users/me/settings/unsubscribe/route.ts | 13 ++ apps/sim/app/api/v1/copilot/chat/route.ts | 18 +- apps/sim/app/api/v1/middleware.ts | 1 + apps/sim/lib/api/contracts/v1/copilot.ts | 2 +- apps/sim/lib/core/rate-limiter/index.ts | 6 + .../core/rate-limiter/route-helpers.test.ts | 160 ++++++++++++++++++ .../lib/core/rate-limiter/route-helpers.ts | 74 ++++++++ 22 files changed, 393 insertions(+), 53 deletions(-) create mode 100644 apps/sim/lib/core/rate-limiter/route-helpers.test.ts create mode 100644 apps/sim/lib/core/rate-limiter/route-helpers.ts diff --git a/apps/sim/app/api/auth/socket-token/route.ts b/apps/sim/app/api/auth/socket-token/route.ts index 45151d1e496..c7b0dc618c8 100644 --- a/apps/sim/app/api/auth/socket-token/route.ts +++ b/apps/sim/app/api/auth/socket-token/route.ts @@ -1,18 +1,26 @@ import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { headers } from 'next/headers' -import { NextResponse } from 'next/server' +import { type NextRequest, NextResponse } from 'next/server' import { auth } from '@/lib/auth' import { isAuthDisabled } from '@/lib/core/config/feature-flags' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' const logger = createLogger('SocketTokenAPI') -export const POST = withRouteHandler(async () => { +export const POST = withRouteHandler(async (request: NextRequest) => { if (isAuthDisabled) { return NextResponse.json({ token: 'anonymous-socket-token' }) } + const rateLimited = await enforceIpRateLimit('socket-token', request, { + maxTokens: 30, + refillRate: 30, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + try { const hdrs = await headers() const response = await auth.api.generateOneTimeToken({ diff --git a/apps/sim/app/api/auth/sso/providers/route.ts b/apps/sim/app/api/auth/sso/providers/route.ts index 77f86734420..8428eebc1e1 100644 --- a/apps/sim/app/api/auth/sso/providers/route.ts +++ b/apps/sim/app/api/auth/sso/providers/route.ts @@ -5,6 +5,7 @@ import { type NextRequest, NextResponse } from 'next/server' import { listSsoProvidersContract } from '@/lib/api/contracts/auth' import { parseRequest } from '@/lib/api/server' import { getSession } from '@/lib/auth' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { REDACTED_MARKER } from '@/lib/core/security/redaction' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -13,6 +14,14 @@ const logger = createLogger('SSOProvidersRoute') export const GET = withRouteHandler(async (request: NextRequest) => { try { const session = await getSession() + if (!session?.user?.id) { + const rateLimited = await enforceIpRateLimit('sso-providers', request, { + maxTokens: 20, + refillRate: 20, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + } const parsed = await parseRequest(listSsoProvidersContract, request, {}) if (!parsed.success) return parsed.response const { organizationId } = parsed.data.query diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts index acd6652bf5d..547a164b069 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts @@ -423,7 +423,7 @@ describe('Chat OTP API Route', () => { expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') }) - it('skips IP rate limit when client IP is unknown', async () => { + it('folds spoofed `unknown` client IPs into a single shared bucket', async () => { requestUtilsMockFns.mockGetClientIp.mockReturnValueOnce('unknown') buildDeploymentSelect() @@ -434,8 +434,11 @@ describe('Chat OTP API Route', () => { await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) - // Only the email-scoped check should run, not the IP-scoped one - expect(mockCheckRateLimitDirect).toHaveBeenCalledTimes(1) + expect(mockCheckRateLimitDirect).toHaveBeenCalledTimes(2) + expect(mockCheckRateLimitDirect).toHaveBeenCalledWith( + expect.stringMatching(/^chat-otp:ip:.*:unknown$/), + expect.any(Object) + ) expect(mockCheckRateLimitDirect).toHaveBeenCalledWith( expect.stringContaining('chat-otp:email:'), expect.any(Object) diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index d546ef0cf77..b2e129b5fa8 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -223,20 +223,18 @@ export const POST = withRouteHandler( try { const ip = getClientIp(request) - if (ip !== 'unknown') { - const ipRateLimit = await rateLimiter.checkRateLimitDirect( - `chat-otp:ip:${identifier}:${ip}`, - OTP_IP_RATE_LIMIT + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `chat-otp:ip:${identifier}:${ip}`, + OTP_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 ) - if (!ipRateLimit.allowed) { - logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) - const retryAfter = Math.ceil( - (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 - ) - const response = createErrorResponse('Too many requests. Please try again later.', 429) - response.headers.set('Retry-After', String(retryAfter)) - return addCorsHeaders(response, request) - } + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) } const parsed = await parseRequest(requestChatEmailOtpContract, request, context, { diff --git a/apps/sim/app/api/chat/[identifier]/sso/route.ts b/apps/sim/app/api/chat/[identifier]/sso/route.ts index 812f27df5b3..c6878876a90 100644 --- a/apps/sim/app/api/chat/[identifier]/sso/route.ts +++ b/apps/sim/app/api/chat/[identifier]/sso/route.ts @@ -30,20 +30,18 @@ export const POST = withRouteHandler( const requestId = generateRequestId() const ip = getClientIp(request) - if (ip !== 'unknown') { - const ipRateLimit = await rateLimiter.checkRateLimitDirect( - `chat-sso:ip:${ip}`, - SSO_IP_RATE_LIMIT + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `chat-sso:ip:${ip}`, + SSO_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] SSO eligibility rate limit exceeded from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? SSO_IP_RATE_LIMIT.refillIntervalMs) / 1000 ) - if (!ipRateLimit.allowed) { - logger.warn(`[${requestId}] SSO eligibility rate limit exceeded from ${ip}`) - const retryAfter = Math.ceil( - (ipRateLimit.retryAfterMs ?? SSO_IP_RATE_LIMIT.refillIntervalMs) / 1000 - ) - const response = createErrorResponse('Too many requests. Please try again later.', 429) - response.headers.set('Retry-After', String(retryAfter)) - return addCorsHeaders(response, request) - } + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) } const parsed = await parseRequest(chatSSOContract, request, context) diff --git a/apps/sim/app/api/telemetry/route.ts b/apps/sim/app/api/telemetry/route.ts index bdeb0a6b109..aed019188cc 100644 --- a/apps/sim/app/api/telemetry/route.ts +++ b/apps/sim/app/api/telemetry/route.ts @@ -4,6 +4,7 @@ import { telemetryContract } from '@/lib/api/contracts/telemetry' import { parseRequest } from '@/lib/api/server' import { env } from '@/lib/core/config/env' import { isProd } from '@/lib/core/config/feature-flags' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' const logger = createLogger('TelemetryAPI') @@ -148,6 +149,13 @@ async function forwardToCollector(data: Record): Promise { + const rateLimited = await enforceIpRateLimit('telemetry', req, { + maxTokens: 60, + refillRate: 30, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + try { const parsed = await parseRequest(telemetryContract, req, {}) if (!parsed.success) return parsed.response diff --git a/apps/sim/app/api/templates/[id]/route.ts b/apps/sim/app/api/templates/[id]/route.ts index f0bbd4f0d16..bb2e4a48c97 100644 --- a/apps/sim/app/api/templates/[id]/route.ts +++ b/apps/sim/app/api/templates/[id]/route.ts @@ -7,7 +7,8 @@ import { type NextRequest, NextResponse } from 'next/server' import { templateIdParamsSchema, updateTemplateContract } from '@/lib/api/contracts/templates' import { parseRequest } from '@/lib/api/server' import { getSession } from '@/lib/auth' -import { generateRequestId } from '@/lib/core/utils/request' +import { RateLimiter } from '@/lib/core/rate-limiter' +import { generateRequestId, getClientIp } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { canAccessTemplate } from '@/lib/templates/permissions' import { @@ -18,6 +19,18 @@ import type { WorkflowState } from '@/stores/workflows/workflow/types' const logger = createLogger('TemplateByIdAPI') +const viewRateLimiter = new RateLimiter() + +/** + * Per-IP, per-template view-counter dedup bucket: one increment per 10 minutes. + * Prevents scripted inflation of `templates.views` from the public GET handler. + */ +const TEMPLATE_VIEW_DEDUP = { + maxTokens: 1, + refillRate: 1, + refillIntervalMs: 10 * 60_000, +} + export const revalidate = 0 export const GET = withRouteHandler( @@ -63,21 +76,31 @@ export const GET = withRouteHandler( isStarred = starResult.length > 0 } - const shouldIncrementView = template.status === 'approved' + let shouldIncrementView = template.status === 'approved' if (shouldIncrementView) { - try { - await db - .update(templates) - .set({ - views: sql`${templates.views} + 1`, - }) - .where(eq(templates.id, id)) - } catch (viewError) { - logger.warn( - `[${requestId}] Failed to increment view count for template: ${id}`, - viewError - ) + const viewer = session?.user?.id ?? `ip:${getClientIp(request)}` + const dedupKey = `template-view:${id}:${viewer}` + const { allowed } = await viewRateLimiter.checkRateLimitDirect( + dedupKey, + TEMPLATE_VIEW_DEDUP + ) + if (!allowed) { + shouldIncrementView = false + } else { + try { + await db + .update(templates) + .set({ + views: sql`${templates.views} + 1`, + }) + .where(eq(templates.id, id)) + } catch (viewError) { + logger.warn( + `[${requestId}] Failed to increment view count for template: ${id}`, + viewError + ) + } } } diff --git a/apps/sim/app/api/tools/a2a/cancel-task/route.ts b/apps/sim/app/api/tools/a2a/cancel-task/route.ts index f5eef3170e7..188303382a7 100644 --- a/apps/sim/app/api/tools/a2a/cancel-task/route.ts +++ b/apps/sim/app/api/tools/a2a/cancel-task/route.ts @@ -5,6 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aCancelTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -29,6 +30,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-cancel-task', authResult.userId!) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aCancelTaskContract, request, diff --git a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts index 77b46a30000..1ae6ccfde6f 100644 --- a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aDeletePushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,6 +31,12 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit( + 'a2a-delete-push-notification', + authResult.userId! + ) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A delete push notification request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts index 5a39e944d2f..513079aeaa7 100644 --- a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts +++ b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetAgentCardContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -28,6 +29,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-get-agent-card', authResult.userId!) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A get agent card request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts index 49d21b07439..4cbd2baf84b 100644 --- a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,6 +31,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-get-push-notification', authResult.userId!) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A get push notification request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-task/route.ts b/apps/sim/app/api/tools/a2a/get-task/route.ts index ac21da72537..dab4d57de98 100644 --- a/apps/sim/app/api/tools/a2a/get-task/route.ts +++ b/apps/sim/app/api/tools/a2a/get-task/route.ts @@ -5,6 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -29,6 +30,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-get-task', authResult.userId!) + if (rateLimited) return rateLimited + logger.info(`[${requestId}] Authenticated A2A get task request via ${authResult.authType}`, { userId: authResult.userId, }) diff --git a/apps/sim/app/api/tools/a2a/resubscribe/route.ts b/apps/sim/app/api/tools/a2a/resubscribe/route.ts index 69af495725b..198e0458224 100644 --- a/apps/sim/app/api/tools/a2a/resubscribe/route.ts +++ b/apps/sim/app/api/tools/a2a/resubscribe/route.ts @@ -12,6 +12,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aResubscribeContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -36,6 +37,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-resubscribe', authResult.userId!) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aResubscribeContract, request, diff --git a/apps/sim/app/api/tools/a2a/send-message/route.ts b/apps/sim/app/api/tools/a2a/send-message/route.ts index badd3267f3d..7cbc07efd44 100644 --- a/apps/sim/app/api/tools/a2a/send-message/route.ts +++ b/apps/sim/app/api/tools/a2a/send-message/route.ts @@ -7,6 +7,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aSendMessageContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -32,6 +33,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-send-message', authResult.userId!) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A send message request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts index de9c41b8ccc..46d89b0ffa2 100644 --- a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aSetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -31,6 +32,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserRateLimit('a2a-set-push-notification', authResult.userId!) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aSetPushNotificationContract, request, diff --git a/apps/sim/app/api/users/me/settings/unsubscribe/route.ts b/apps/sim/app/api/users/me/settings/unsubscribe/route.ts index b2806563ede..654324e85d4 100644 --- a/apps/sim/app/api/users/me/settings/unsubscribe/route.ts +++ b/apps/sim/app/api/users/me/settings/unsubscribe/route.ts @@ -6,6 +6,7 @@ import { unsubscribePostContract, } from '@/lib/api/contracts/user' import { parseRequest } from '@/lib/api/server' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import type { EmailType } from '@/lib/messaging/email/mailer' @@ -19,9 +20,18 @@ import { const logger = createLogger('UnsubscribeAPI') +const UNSUBSCRIBE_RATE_LIMIT = { + maxTokens: 10, + refillRate: 10, + refillIntervalMs: 60_000, +} + export const GET = withRouteHandler(async (req: NextRequest) => { const requestId = generateRequestId() + const rateLimited = await enforceIpRateLimit('unsubscribe', req, UNSUBSCRIBE_RATE_LIMIT) + if (rateLimited) return rateLimited + try { const parsed = await parseRequest( unsubscribeGetContract, @@ -70,6 +80,9 @@ export const GET = withRouteHandler(async (req: NextRequest) => { export const POST = withRouteHandler(async (req: NextRequest) => { const requestId = generateRequestId() + const rateLimited = await enforceIpRateLimit('unsubscribe', req, UNSUBSCRIBE_RATE_LIMIT) + if (rateLimited) return rateLimited + try { const contentType = req.headers.get('content-type') || '' diff --git a/apps/sim/app/api/v1/copilot/chat/route.ts b/apps/sim/app/api/v1/copilot/chat/route.ts index d513318728a..6eaa1424b77 100644 --- a/apps/sim/app/api/v1/copilot/chat/route.ts +++ b/apps/sim/app/api/v1/copilot/chat/route.ts @@ -7,7 +7,7 @@ import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { runHeadlessCopilotLifecycle } from '@/lib/copilot/request/lifecycle/headless' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { getWorkflowById, resolveWorkflowIdForUser } from '@/lib/workflows/utils' -import { authenticateV1Request } from '@/app/api/v1/auth' +import { authenticateRequest } from '@/app/api/v1/middleware' export const maxDuration = 3600 @@ -25,12 +25,16 @@ const DEFAULT_COPILOT_MODEL = 'claude-opus-4-6' */ export const POST = withRouteHandler(async (req: NextRequest) => { let messageId: string | undefined - const auth = await authenticateV1Request(req) - if (!auth.authenticated || !auth.userId) { - return NextResponse.json( - { success: false, error: auth.error || 'Unauthorized' }, - { status: 401 } - ) + const authorized = await authenticateRequest(req, 'copilot-chat') + if (authorized instanceof NextResponse) { + return authorized + } + const { userId, rateLimit } = authorized + const auth = { + authenticated: true as const, + userId, + keyType: rateLimit.keyType, + workspaceId: rateLimit.workspaceId, } try { diff --git a/apps/sim/app/api/v1/middleware.ts b/apps/sim/app/api/v1/middleware.ts index dcb427aa3a1..92aa72eb344 100644 --- a/apps/sim/app/api/v1/middleware.ts +++ b/apps/sim/app/api/v1/middleware.ts @@ -26,6 +26,7 @@ export type V1Endpoint = | 'knowledge' | 'knowledge-detail' | 'knowledge-search' + | 'copilot-chat' export interface RateLimitResult { allowed: boolean diff --git a/apps/sim/lib/api/contracts/v1/copilot.ts b/apps/sim/lib/api/contracts/v1/copilot.ts index de85c0f7409..f09626e41a2 100644 --- a/apps/sim/lib/api/contracts/v1/copilot.ts +++ b/apps/sim/lib/api/contracts/v1/copilot.ts @@ -10,7 +10,7 @@ export const v1CopilotChatBodySchema = z.object({ mode: z.enum(COPILOT_REQUEST_MODES).optional().default('agent'), model: z.string().optional(), autoExecuteTools: z.boolean().optional().default(true), - timeout: z.number().optional().default(3_600_000), + timeout: z.number().int().min(1000).max(3_600_000).optional().default(3_600_000), }) export type V1CopilotChatBody = z.output diff --git a/apps/sim/lib/core/rate-limiter/index.ts b/apps/sim/lib/core/rate-limiter/index.ts index 4ecc48f2b5a..2345b27aede 100644 --- a/apps/sim/lib/core/rate-limiter/index.ts +++ b/apps/sim/lib/core/rate-limiter/index.ts @@ -8,6 +8,12 @@ export { toTokenBucketConfig, } from './hosted-key' export { RateLimiter } from './rate-limiter' +export { + DEFAULT_PUBLIC_IP_ROUTE_LIMIT, + DEFAULT_USER_ROUTE_LIMIT, + enforceIpRateLimit, + enforceUserRateLimit, +} from './route-helpers' export type { TokenBucketConfig } from './storage' export type { SubscriptionPlan } from './types' export { getRateLimit, RATE_LIMITS, RateLimitError } from './types' diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.test.ts b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts new file mode 100644 index 00000000000..10c237b7807 --- /dev/null +++ b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts @@ -0,0 +1,160 @@ +/** + * @vitest-environment node + */ +import { createMockRequest, requestUtilsMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +const { mockAdapter } = vi.hoisted(() => ({ + mockAdapter: { + consumeTokens: vi.fn(), + getTokenStatus: vi.fn(), + resetBucket: vi.fn(), + }, +})) + +vi.mock('@/lib/core/rate-limiter/storage', async () => { + const actual = await vi.importActual( + '@/lib/core/rate-limiter/storage' + ) + return { + ...actual, + createStorageAdapter: () => mockAdapter, + } +}) + +function passThroughClientIp() { + requestUtilsMockFns.mockGetClientIp.mockImplementation( + (req: { headers: { get(name: string): string | null } }) => + req.headers.get('x-forwarded-for')?.split(',')[0]?.trim() || + req.headers.get('x-real-ip')?.trim() || + 'unknown' + ) +} + +import { enforceIpRateLimit, enforceUserRateLimit } from './route-helpers' + +const consume = mockAdapter.consumeTokens as Mock + +describe('route-helpers rate limiting', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('enforceUserRateLimit', () => { + it('returns null when the bucket has tokens left', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(Date.now() + 60_000), + }) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).toBeNull() + expect(consume).toHaveBeenCalledWith( + 'route:test-bucket:user:user-1', + 1, + expect.objectContaining({ maxTokens: 60, refillRate: 30 }) + ) + }) + + it('returns a 429 with Retry-After when the bucket is empty', async () => { + const resetAt = new Date(Date.now() + 30_000) + consume.mockResolvedValueOnce({ + allowed: false, + tokensRemaining: 0, + resetAt, + retryAfterMs: 30_000, + }) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).not.toBeNull() + expect(result?.status).toBe(429) + expect(result?.headers.get('Retry-After')).toBe('30') + expect(result?.headers.get('X-RateLimit-Reset')).toBe(resetAt.toISOString()) + + const body = await result?.json() + expect(body?.error).toBe('Rate limit exceeded') + }) + + it('keys buckets per user so different users do not share state', async () => { + consume.mockResolvedValue({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + + await enforceUserRateLimit('shared-bucket', 'user-a') + await enforceUserRateLimit('shared-bucket', 'user-b') + + const keys = consume.mock.calls.map((call) => call[0]) + expect(keys).toEqual(['route:shared-bucket:user:user-a', 'route:shared-bucket:user:user-b']) + }) + + it('fails open when the storage layer throws', async () => { + consume.mockRejectedValueOnce(new Error('redis down')) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).toBeNull() + }) + }) + + describe('enforceIpRateLimit', () => { + beforeEach(() => { + passThroughClientIp() + }) + + it('uses the X-Forwarded-For client IP in the bucket key', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 9, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { + 'x-forwarded-for': '203.0.113.7, 10.0.0.1', + }) + + await enforceIpRateLimit('public-bucket', request) + + expect(consume).toHaveBeenCalledWith( + 'route:public-bucket:ip:203.0.113.7', + 1, + expect.any(Object) + ) + }) + + it('folds spoofed `X-Forwarded-For: unknown` into a single shared bucket', async () => { + consume.mockResolvedValue({ + allowed: true, + tokensRemaining: 9, + resetAt: new Date(), + }) + + const reqA = createMockRequest('POST', undefined, { 'x-forwarded-for': 'unknown' }) + const reqB = createMockRequest('POST', undefined, { 'x-forwarded-for': 'unknown' }) + await enforceIpRateLimit('otp', reqA) + await enforceIpRateLimit('otp', reqB) + + const keys = consume.mock.calls.map((call) => call[0]) + expect(keys).toEqual(['route:otp:ip:unknown', 'route:otp:ip:unknown']) + }) + + it('returns a 429 with Retry-After on rate limit', async () => { + const resetAt = new Date(Date.now() + 60_000) + consume.mockResolvedValueOnce({ + allowed: false, + tokensRemaining: 0, + resetAt, + retryAfterMs: 60_000, + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + const result = await enforceIpRateLimit('public-bucket', request) + + expect(result?.status).toBe(429) + expect(result?.headers.get('Retry-After')).toBe('60') + }) + }) +}) diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.ts b/apps/sim/lib/core/rate-limiter/route-helpers.ts new file mode 100644 index 00000000000..124b5168953 --- /dev/null +++ b/apps/sim/lib/core/rate-limiter/route-helpers.ts @@ -0,0 +1,74 @@ +import { createLogger } from '@sim/logger' +import { type NextRequest, NextResponse } from 'next/server' +import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter' +import type { TokenBucketConfig } from '@/lib/core/rate-limiter/storage' +import { getClientIp } from '@/lib/core/utils/request' + +const logger = createLogger('RouteRateLimit') +const rateLimiter = new RateLimiter() + +/** Default per-user bucket for authenticated tool routes (60 burst, 30/min). */ +export const DEFAULT_USER_ROUTE_LIMIT: TokenBucketConfig = { + maxTokens: 60, + refillRate: 30, + refillIntervalMs: 60_000, +} + +/** Default per-IP bucket for unauthenticated public endpoints (10 burst, 5/min). */ +export const DEFAULT_PUBLIC_IP_ROUTE_LIMIT: TokenBucketConfig = { + maxTokens: 10, + refillRate: 5, + refillIntervalMs: 60_000, +} + +function buildRateLimitResponse(resetAt: Date): NextResponse { + const retryAfterSec = Math.max(1, Math.ceil((resetAt.getTime() - Date.now()) / 1000)) + return NextResponse.json( + { + success: false, + error: 'Rate limit exceeded', + retryAfter: resetAt.getTime(), + }, + { + status: 429, + headers: { + 'Retry-After': String(retryAfterSec), + 'X-RateLimit-Reset': resetAt.toISOString(), + }, + } + ) +} + +/** + * Apply a per-user token bucket to an authenticated route. + * Returns a `NextResponse` on 429, otherwise `null` so the caller can proceed. + */ +export async function enforceUserRateLimit( + bucketName: string, + userId: string, + config: TokenBucketConfig = DEFAULT_USER_ROUTE_LIMIT +): Promise { + const key = `route:${bucketName}:user:${userId}` + const { allowed, resetAt } = await rateLimiter.checkRateLimitDirect(key, config) + if (allowed) return null + logger.warn('User rate limit exceeded', { bucket: bucketName, userId }) + return buildRateLimitResponse(resetAt) +} + +/** + * Apply a per-IP token bucket to an unauthenticated route. The `unknown` IP + * fallback shares one global bucket per route so it cannot be amplified by + * `X-Forwarded-For: unknown` spoofing. + */ +export async function enforceIpRateLimit( + bucketName: string, + request: NextRequest, + config: TokenBucketConfig = DEFAULT_PUBLIC_IP_ROUTE_LIMIT +): Promise { + const ip = getClientIp(request) + const key = `route:${bucketName}:ip:${ip}` + const { allowed, resetAt } = await rateLimiter.checkRateLimitDirect(key, config) + if (allowed) return null + logger.warn('IP rate limit exceeded', { bucket: bucketName, ip }) + return buildRateLimitResponse(resetAt) +} From 2f3657cd373a60e893c28cd613588fd3d86d72cc Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Wed, 13 May 2026 19:02:53 -0700 Subject: [PATCH 2/2] =?UTF-8?q?fix(rate-limit):=20address=20PR=20review=20?= =?UTF-8?q?=E2=80=94=20drop=20success=20field=20from=20429=20body,=20fall?= =?UTF-8?q?=20back=20to=20per-IP=20when=20JWT=20auth=20lacks=20userId?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/api/tools/a2a/cancel-task/route.ts | 8 +++-- .../a2a/delete-push-notification/route.ts | 7 ++-- .../app/api/tools/a2a/get-agent-card/route.ts | 8 +++-- .../tools/a2a/get-push-notification/route.ts | 8 +++-- apps/sim/app/api/tools/a2a/get-task/route.ts | 4 +-- .../app/api/tools/a2a/resubscribe/route.ts | 8 +++-- .../app/api/tools/a2a/send-message/route.ts | 8 +++-- .../tools/a2a/set-push-notification/route.ts | 8 +++-- apps/sim/lib/core/rate-limiter/index.ts | 1 + .../core/rate-limiter/route-helpers.test.ts | 34 ++++++++++++++++++- .../lib/core/rate-limiter/route-helpers.ts | 17 +++++++++- 11 files changed, 92 insertions(+), 19 deletions(-) diff --git a/apps/sim/app/api/tools/a2a/cancel-task/route.ts b/apps/sim/app/api/tools/a2a/cancel-task/route.ts index 188303382a7..92935a001a0 100644 --- a/apps/sim/app/api/tools/a2a/cancel-task/route.ts +++ b/apps/sim/app/api/tools/a2a/cancel-task/route.ts @@ -5,7 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aCancelTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,7 +30,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-cancel-task', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-cancel-task', + authResult.userId, + request + ) if (rateLimited) return rateLimited const parsed = await parseRequest( diff --git a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts index 1ae6ccfde6f..cf93e9e2f36 100644 --- a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts @@ -4,7 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aDeletePushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -31,9 +31,10 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit( + const rateLimited = await enforceUserOrIpRateLimit( 'a2a-delete-push-notification', - authResult.userId! + authResult.userId, + request ) if (rateLimited) return rateLimited diff --git a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts index 513079aeaa7..fed318b8330 100644 --- a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts +++ b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts @@ -4,7 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetAgentCardContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -29,7 +29,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-get-agent-card', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-get-agent-card', + authResult.userId, + request + ) if (rateLimited) return rateLimited logger.info( diff --git a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts index 4cbd2baf84b..6c48da2648c 100644 --- a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts @@ -4,7 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -31,7 +31,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-get-push-notification', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-get-push-notification', + authResult.userId, + request + ) if (rateLimited) return rateLimited logger.info( diff --git a/apps/sim/app/api/tools/a2a/get-task/route.ts b/apps/sim/app/api/tools/a2a/get-task/route.ts index dab4d57de98..3e38b82f80c 100644 --- a/apps/sim/app/api/tools/a2a/get-task/route.ts +++ b/apps/sim/app/api/tools/a2a/get-task/route.ts @@ -5,7 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,7 +30,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-get-task', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit('a2a-get-task', authResult.userId, request) if (rateLimited) return rateLimited logger.info(`[${requestId}] Authenticated A2A get task request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/resubscribe/route.ts b/apps/sim/app/api/tools/a2a/resubscribe/route.ts index 198e0458224..bd4bdebabc7 100644 --- a/apps/sim/app/api/tools/a2a/resubscribe/route.ts +++ b/apps/sim/app/api/tools/a2a/resubscribe/route.ts @@ -12,7 +12,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aResubscribeContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -37,7 +37,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-resubscribe', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-resubscribe', + authResult.userId, + request + ) if (rateLimited) return rateLimited const parsed = await parseRequest( diff --git a/apps/sim/app/api/tools/a2a/send-message/route.ts b/apps/sim/app/api/tools/a2a/send-message/route.ts index 7cbc07efd44..708863a8715 100644 --- a/apps/sim/app/api/tools/a2a/send-message/route.ts +++ b/apps/sim/app/api/tools/a2a/send-message/route.ts @@ -7,7 +7,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aSendMessageContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -33,7 +33,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-send-message', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-send-message', + authResult.userId, + request + ) if (rateLimited) return rateLimited logger.info( diff --git a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts index 46d89b0ffa2..5511da2d2cc 100644 --- a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts @@ -4,7 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aSetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' -import { enforceUserRateLimit } from '@/lib/core/rate-limiter' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -32,7 +32,11 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const rateLimited = await enforceUserRateLimit('a2a-set-push-notification', authResult.userId!) + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-set-push-notification', + authResult.userId, + request + ) if (rateLimited) return rateLimited const parsed = await parseRequest( diff --git a/apps/sim/lib/core/rate-limiter/index.ts b/apps/sim/lib/core/rate-limiter/index.ts index 2345b27aede..9afd3cb231f 100644 --- a/apps/sim/lib/core/rate-limiter/index.ts +++ b/apps/sim/lib/core/rate-limiter/index.ts @@ -12,6 +12,7 @@ export { DEFAULT_PUBLIC_IP_ROUTE_LIMIT, DEFAULT_USER_ROUTE_LIMIT, enforceIpRateLimit, + enforceUserOrIpRateLimit, enforceUserRateLimit, } from './route-helpers' export type { TokenBucketConfig } from './storage' diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.test.ts b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts index 10c237b7807..0f895e81e1a 100644 --- a/apps/sim/lib/core/rate-limiter/route-helpers.test.ts +++ b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts @@ -31,7 +31,7 @@ function passThroughClientIp() { ) } -import { enforceIpRateLimit, enforceUserRateLimit } from './route-helpers' +import { enforceIpRateLimit, enforceUserOrIpRateLimit, enforceUserRateLimit } from './route-helpers' const consume = mockAdapter.consumeTokens as Mock @@ -157,4 +157,36 @@ describe('route-helpers rate limiting', () => { expect(result?.headers.get('Retry-After')).toBe('60') }) }) + + describe('enforceUserOrIpRateLimit', () => { + beforeEach(() => { + passThroughClientIp() + }) + + it('keys per-user when userId is present', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + await enforceUserOrIpRateLimit('a2a-test', 'user-1', request) + + expect(consume).toHaveBeenCalledWith('route:a2a-test:user:user-1', 1, expect.any(Object)) + }) + + it('falls back to per-IP when userId is undefined', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + await enforceUserOrIpRateLimit('a2a-test', undefined, request) + + expect(consume).toHaveBeenCalledWith('route:a2a-test:ip:203.0.113.7', 1, expect.any(Object)) + }) + }) }) diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.ts b/apps/sim/lib/core/rate-limiter/route-helpers.ts index 124b5168953..8b7cae9dc1a 100644 --- a/apps/sim/lib/core/rate-limiter/route-helpers.ts +++ b/apps/sim/lib/core/rate-limiter/route-helpers.ts @@ -25,7 +25,6 @@ function buildRateLimitResponse(resetAt: Date): NextResponse { const retryAfterSec = Math.max(1, Math.ceil((resetAt.getTime() - Date.now()) / 1000)) return NextResponse.json( { - success: false, error: 'Rate limit exceeded', retryAfter: resetAt.getTime(), }, @@ -72,3 +71,19 @@ export async function enforceIpRateLimit( logger.warn('IP rate limit exceeded', { bucket: bucketName, ip }) return buildRateLimitResponse(resetAt) } + +/** + * Apply a per-user limit when a userId is present, else fall back to per-IP. + * Use for routes whose auth path may legitimately resolve without a userId + * (e.g. internal JWT calls with `requireWorkflowId: false`) so missing-userId + * traffic is still throttled per-IP rather than sharing one global bucket. + */ +export async function enforceUserOrIpRateLimit( + bucketName: string, + userId: string | undefined, + request: NextRequest, + config: TokenBucketConfig = DEFAULT_USER_ROUTE_LIMIT +): Promise { + if (userId) return enforceUserRateLimit(bucketName, userId, config) + return enforceIpRateLimit(bucketName, request, config) +}