From 568a552d673aa685d1370eacf63172d0b6063a32 Mon Sep 17 00:00:00 2001 From: Waleed Date: Wed, 13 May 2026 19:16:57 -0700 Subject: [PATCH 1/5] fix(rate-limit): close rate-limit bypass and tighten public route limits (#4591) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(rate-limit): close rate-limit bypass and tighten public route limits * fix(rate-limit): address PR review — drop success field from 429 body, fall back to per-IP when JWT auth lacks userId --- apps/sim/app/api/auth/socket-token/route.ts | 12 +- apps/sim/app/api/auth/sso/providers/route.ts | 9 + .../api/chat/[identifier]/otp/route.test.ts | 9 +- .../app/api/chat/[identifier]/otp/route.ts | 24 +-- .../app/api/chat/[identifier]/sso/route.ts | 24 +-- apps/sim/app/api/telemetry/route.ts | 8 + apps/sim/app/api/templates/[id]/route.ts | 51 +++-- .../app/api/tools/a2a/cancel-task/route.ts | 8 + .../a2a/delete-push-notification/route.ts | 8 + .../app/api/tools/a2a/get-agent-card/route.ts | 8 + .../tools/a2a/get-push-notification/route.ts | 8 + apps/sim/app/api/tools/a2a/get-task/route.ts | 4 + .../app/api/tools/a2a/resubscribe/route.ts | 8 + .../app/api/tools/a2a/send-message/route.ts | 8 + .../tools/a2a/set-push-notification/route.ts | 8 + .../users/me/settings/unsubscribe/route.ts | 13 ++ apps/sim/app/api/v1/copilot/chat/route.ts | 18 +- apps/sim/app/api/v1/middleware.ts | 1 + apps/sim/lib/api/contracts/v1/copilot.ts | 2 +- apps/sim/lib/core/rate-limiter/index.ts | 7 + .../core/rate-limiter/route-helpers.test.ts | 192 ++++++++++++++++++ .../lib/core/rate-limiter/route-helpers.ts | 89 ++++++++ 22 files changed, 466 insertions(+), 53 deletions(-) create mode 100644 apps/sim/lib/core/rate-limiter/route-helpers.test.ts create mode 100644 apps/sim/lib/core/rate-limiter/route-helpers.ts diff --git a/apps/sim/app/api/auth/socket-token/route.ts b/apps/sim/app/api/auth/socket-token/route.ts index 45151d1e496..c7b0dc618c8 100644 --- a/apps/sim/app/api/auth/socket-token/route.ts +++ b/apps/sim/app/api/auth/socket-token/route.ts @@ -1,18 +1,26 @@ import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { headers } from 'next/headers' -import { NextResponse } from 'next/server' +import { type NextRequest, NextResponse } from 'next/server' import { auth } from '@/lib/auth' import { isAuthDisabled } from '@/lib/core/config/feature-flags' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' const logger = createLogger('SocketTokenAPI') -export const POST = withRouteHandler(async () => { +export const POST = withRouteHandler(async (request: NextRequest) => { if (isAuthDisabled) { return NextResponse.json({ token: 'anonymous-socket-token' }) } + const rateLimited = await enforceIpRateLimit('socket-token', request, { + maxTokens: 30, + refillRate: 30, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + try { const hdrs = await headers() const response = await auth.api.generateOneTimeToken({ diff --git a/apps/sim/app/api/auth/sso/providers/route.ts b/apps/sim/app/api/auth/sso/providers/route.ts index 77f86734420..8428eebc1e1 100644 --- a/apps/sim/app/api/auth/sso/providers/route.ts +++ b/apps/sim/app/api/auth/sso/providers/route.ts @@ -5,6 +5,7 @@ import { type NextRequest, NextResponse } from 'next/server' import { listSsoProvidersContract } from '@/lib/api/contracts/auth' import { parseRequest } from '@/lib/api/server' import { getSession } from '@/lib/auth' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { REDACTED_MARKER } from '@/lib/core/security/redaction' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -13,6 +14,14 @@ const logger = createLogger('SSOProvidersRoute') export const GET = withRouteHandler(async (request: NextRequest) => { try { const session = await getSession() + if (!session?.user?.id) { + const rateLimited = await enforceIpRateLimit('sso-providers', request, { + maxTokens: 20, + refillRate: 20, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + } const parsed = await parseRequest(listSsoProvidersContract, request, {}) if (!parsed.success) return parsed.response const { organizationId } = parsed.data.query diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts index acd6652bf5d..547a164b069 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts @@ -423,7 +423,7 @@ describe('Chat OTP API Route', () => { expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') }) - it('skips IP rate limit when client IP is unknown', async () => { + it('folds spoofed `unknown` client IPs into a single shared bucket', async () => { requestUtilsMockFns.mockGetClientIp.mockReturnValueOnce('unknown') buildDeploymentSelect() @@ -434,8 +434,11 @@ describe('Chat OTP API Route', () => { await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) - // Only the email-scoped check should run, not the IP-scoped one - expect(mockCheckRateLimitDirect).toHaveBeenCalledTimes(1) + expect(mockCheckRateLimitDirect).toHaveBeenCalledTimes(2) + expect(mockCheckRateLimitDirect).toHaveBeenCalledWith( + expect.stringMatching(/^chat-otp:ip:.*:unknown$/), + expect.any(Object) + ) expect(mockCheckRateLimitDirect).toHaveBeenCalledWith( expect.stringContaining('chat-otp:email:'), expect.any(Object) diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index d546ef0cf77..b2e129b5fa8 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -223,20 +223,18 @@ export const POST = withRouteHandler( try { const ip = getClientIp(request) - if (ip !== 'unknown') { - const ipRateLimit = await rateLimiter.checkRateLimitDirect( - `chat-otp:ip:${identifier}:${ip}`, - OTP_IP_RATE_LIMIT + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `chat-otp:ip:${identifier}:${ip}`, + OTP_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 ) - if (!ipRateLimit.allowed) { - logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) - const retryAfter = Math.ceil( - (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 - ) - const response = createErrorResponse('Too many requests. Please try again later.', 429) - response.headers.set('Retry-After', String(retryAfter)) - return addCorsHeaders(response, request) - } + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) } const parsed = await parseRequest(requestChatEmailOtpContract, request, context, { diff --git a/apps/sim/app/api/chat/[identifier]/sso/route.ts b/apps/sim/app/api/chat/[identifier]/sso/route.ts index 812f27df5b3..c6878876a90 100644 --- a/apps/sim/app/api/chat/[identifier]/sso/route.ts +++ b/apps/sim/app/api/chat/[identifier]/sso/route.ts @@ -30,20 +30,18 @@ export const POST = withRouteHandler( const requestId = generateRequestId() const ip = getClientIp(request) - if (ip !== 'unknown') { - const ipRateLimit = await rateLimiter.checkRateLimitDirect( - `chat-sso:ip:${ip}`, - SSO_IP_RATE_LIMIT + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `chat-sso:ip:${ip}`, + SSO_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] SSO eligibility rate limit exceeded from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? SSO_IP_RATE_LIMIT.refillIntervalMs) / 1000 ) - if (!ipRateLimit.allowed) { - logger.warn(`[${requestId}] SSO eligibility rate limit exceeded from ${ip}`) - const retryAfter = Math.ceil( - (ipRateLimit.retryAfterMs ?? SSO_IP_RATE_LIMIT.refillIntervalMs) / 1000 - ) - const response = createErrorResponse('Too many requests. Please try again later.', 429) - response.headers.set('Retry-After', String(retryAfter)) - return addCorsHeaders(response, request) - } + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) } const parsed = await parseRequest(chatSSOContract, request, context) diff --git a/apps/sim/app/api/telemetry/route.ts b/apps/sim/app/api/telemetry/route.ts index bdeb0a6b109..aed019188cc 100644 --- a/apps/sim/app/api/telemetry/route.ts +++ b/apps/sim/app/api/telemetry/route.ts @@ -4,6 +4,7 @@ import { telemetryContract } from '@/lib/api/contracts/telemetry' import { parseRequest } from '@/lib/api/server' import { env } from '@/lib/core/config/env' import { isProd } from '@/lib/core/config/feature-flags' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' const logger = createLogger('TelemetryAPI') @@ -148,6 +149,13 @@ async function forwardToCollector(data: Record): Promise { + const rateLimited = await enforceIpRateLimit('telemetry', req, { + maxTokens: 60, + refillRate: 30, + refillIntervalMs: 60_000, + }) + if (rateLimited) return rateLimited + try { const parsed = await parseRequest(telemetryContract, req, {}) if (!parsed.success) return parsed.response diff --git a/apps/sim/app/api/templates/[id]/route.ts b/apps/sim/app/api/templates/[id]/route.ts index f0bbd4f0d16..bb2e4a48c97 100644 --- a/apps/sim/app/api/templates/[id]/route.ts +++ b/apps/sim/app/api/templates/[id]/route.ts @@ -7,7 +7,8 @@ import { type NextRequest, NextResponse } from 'next/server' import { templateIdParamsSchema, updateTemplateContract } from '@/lib/api/contracts/templates' import { parseRequest } from '@/lib/api/server' import { getSession } from '@/lib/auth' -import { generateRequestId } from '@/lib/core/utils/request' +import { RateLimiter } from '@/lib/core/rate-limiter' +import { generateRequestId, getClientIp } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { canAccessTemplate } from '@/lib/templates/permissions' import { @@ -18,6 +19,18 @@ import type { WorkflowState } from '@/stores/workflows/workflow/types' const logger = createLogger('TemplateByIdAPI') +const viewRateLimiter = new RateLimiter() + +/** + * Per-IP, per-template view-counter dedup bucket: one increment per 10 minutes. + * Prevents scripted inflation of `templates.views` from the public GET handler. + */ +const TEMPLATE_VIEW_DEDUP = { + maxTokens: 1, + refillRate: 1, + refillIntervalMs: 10 * 60_000, +} + export const revalidate = 0 export const GET = withRouteHandler( @@ -63,21 +76,31 @@ export const GET = withRouteHandler( isStarred = starResult.length > 0 } - const shouldIncrementView = template.status === 'approved' + let shouldIncrementView = template.status === 'approved' if (shouldIncrementView) { - try { - await db - .update(templates) - .set({ - views: sql`${templates.views} + 1`, - }) - .where(eq(templates.id, id)) - } catch (viewError) { - logger.warn( - `[${requestId}] Failed to increment view count for template: ${id}`, - viewError - ) + const viewer = session?.user?.id ?? `ip:${getClientIp(request)}` + const dedupKey = `template-view:${id}:${viewer}` + const { allowed } = await viewRateLimiter.checkRateLimitDirect( + dedupKey, + TEMPLATE_VIEW_DEDUP + ) + if (!allowed) { + shouldIncrementView = false + } else { + try { + await db + .update(templates) + .set({ + views: sql`${templates.views} + 1`, + }) + .where(eq(templates.id, id)) + } catch (viewError) { + logger.warn( + `[${requestId}] Failed to increment view count for template: ${id}`, + viewError + ) + } } } diff --git a/apps/sim/app/api/tools/a2a/cancel-task/route.ts b/apps/sim/app/api/tools/a2a/cancel-task/route.ts index f5eef3170e7..92935a001a0 100644 --- a/apps/sim/app/api/tools/a2a/cancel-task/route.ts +++ b/apps/sim/app/api/tools/a2a/cancel-task/route.ts @@ -5,6 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aCancelTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -29,6 +30,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-cancel-task', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aCancelTaskContract, request, diff --git a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts index 77b46a30000..cf93e9e2f36 100644 --- a/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/delete-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aDeletePushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,6 +31,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-delete-push-notification', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A delete push notification request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts index 5a39e944d2f..fed318b8330 100644 --- a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts +++ b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetAgentCardContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -28,6 +29,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-get-agent-card', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A get agent card request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts index 49d21b07439..6c48da2648c 100644 --- a/apps/sim/app/api/tools/a2a/get-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/get-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -30,6 +31,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-get-push-notification', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A get push notification request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/get-task/route.ts b/apps/sim/app/api/tools/a2a/get-task/route.ts index ac21da72537..3e38b82f80c 100644 --- a/apps/sim/app/api/tools/a2a/get-task/route.ts +++ b/apps/sim/app/api/tools/a2a/get-task/route.ts @@ -5,6 +5,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aGetTaskContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -29,6 +30,9 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit('a2a-get-task', authResult.userId, request) + if (rateLimited) return rateLimited + logger.info(`[${requestId}] Authenticated A2A get task request via ${authResult.authType}`, { userId: authResult.userId, }) diff --git a/apps/sim/app/api/tools/a2a/resubscribe/route.ts b/apps/sim/app/api/tools/a2a/resubscribe/route.ts index 69af495725b..bd4bdebabc7 100644 --- a/apps/sim/app/api/tools/a2a/resubscribe/route.ts +++ b/apps/sim/app/api/tools/a2a/resubscribe/route.ts @@ -12,6 +12,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aResubscribeContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -36,6 +37,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-resubscribe', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aResubscribeContract, request, diff --git a/apps/sim/app/api/tools/a2a/send-message/route.ts b/apps/sim/app/api/tools/a2a/send-message/route.ts index badd3267f3d..708863a8715 100644 --- a/apps/sim/app/api/tools/a2a/send-message/route.ts +++ b/apps/sim/app/api/tools/a2a/send-message/route.ts @@ -7,6 +7,7 @@ import { createA2AClient, extractTextContent, isTerminalState } from '@/lib/a2a/ import { a2aSendMessageContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -32,6 +33,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-send-message', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + logger.info( `[${requestId}] Authenticated A2A send message request via ${authResult.authType}`, { diff --git a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts index de9c41b8ccc..5511da2d2cc 100644 --- a/apps/sim/app/api/tools/a2a/set-push-notification/route.ts +++ b/apps/sim/app/api/tools/a2a/set-push-notification/route.ts @@ -4,6 +4,7 @@ import { createA2AClient } from '@/lib/a2a/utils' import { a2aSetPushNotificationContract } from '@/lib/api/contracts/tools/a2a' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' +import { enforceUserOrIpRateLimit } from '@/lib/core/rate-limiter' import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' @@ -31,6 +32,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const rateLimited = await enforceUserOrIpRateLimit( + 'a2a-set-push-notification', + authResult.userId, + request + ) + if (rateLimited) return rateLimited + const parsed = await parseRequest( a2aSetPushNotificationContract, request, diff --git a/apps/sim/app/api/users/me/settings/unsubscribe/route.ts b/apps/sim/app/api/users/me/settings/unsubscribe/route.ts index b2806563ede..654324e85d4 100644 --- a/apps/sim/app/api/users/me/settings/unsubscribe/route.ts +++ b/apps/sim/app/api/users/me/settings/unsubscribe/route.ts @@ -6,6 +6,7 @@ import { unsubscribePostContract, } from '@/lib/api/contracts/user' import { parseRequest } from '@/lib/api/server' +import { enforceIpRateLimit } from '@/lib/core/rate-limiter' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import type { EmailType } from '@/lib/messaging/email/mailer' @@ -19,9 +20,18 @@ import { const logger = createLogger('UnsubscribeAPI') +const UNSUBSCRIBE_RATE_LIMIT = { + maxTokens: 10, + refillRate: 10, + refillIntervalMs: 60_000, +} + export const GET = withRouteHandler(async (req: NextRequest) => { const requestId = generateRequestId() + const rateLimited = await enforceIpRateLimit('unsubscribe', req, UNSUBSCRIBE_RATE_LIMIT) + if (rateLimited) return rateLimited + try { const parsed = await parseRequest( unsubscribeGetContract, @@ -70,6 +80,9 @@ export const GET = withRouteHandler(async (req: NextRequest) => { export const POST = withRouteHandler(async (req: NextRequest) => { const requestId = generateRequestId() + const rateLimited = await enforceIpRateLimit('unsubscribe', req, UNSUBSCRIBE_RATE_LIMIT) + if (rateLimited) return rateLimited + try { const contentType = req.headers.get('content-type') || '' diff --git a/apps/sim/app/api/v1/copilot/chat/route.ts b/apps/sim/app/api/v1/copilot/chat/route.ts index d513318728a..6eaa1424b77 100644 --- a/apps/sim/app/api/v1/copilot/chat/route.ts +++ b/apps/sim/app/api/v1/copilot/chat/route.ts @@ -7,7 +7,7 @@ import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { runHeadlessCopilotLifecycle } from '@/lib/copilot/request/lifecycle/headless' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { getWorkflowById, resolveWorkflowIdForUser } from '@/lib/workflows/utils' -import { authenticateV1Request } from '@/app/api/v1/auth' +import { authenticateRequest } from '@/app/api/v1/middleware' export const maxDuration = 3600 @@ -25,12 +25,16 @@ const DEFAULT_COPILOT_MODEL = 'claude-opus-4-6' */ export const POST = withRouteHandler(async (req: NextRequest) => { let messageId: string | undefined - const auth = await authenticateV1Request(req) - if (!auth.authenticated || !auth.userId) { - return NextResponse.json( - { success: false, error: auth.error || 'Unauthorized' }, - { status: 401 } - ) + const authorized = await authenticateRequest(req, 'copilot-chat') + if (authorized instanceof NextResponse) { + return authorized + } + const { userId, rateLimit } = authorized + const auth = { + authenticated: true as const, + userId, + keyType: rateLimit.keyType, + workspaceId: rateLimit.workspaceId, } try { diff --git a/apps/sim/app/api/v1/middleware.ts b/apps/sim/app/api/v1/middleware.ts index dcb427aa3a1..92aa72eb344 100644 --- a/apps/sim/app/api/v1/middleware.ts +++ b/apps/sim/app/api/v1/middleware.ts @@ -26,6 +26,7 @@ export type V1Endpoint = | 'knowledge' | 'knowledge-detail' | 'knowledge-search' + | 'copilot-chat' export interface RateLimitResult { allowed: boolean diff --git a/apps/sim/lib/api/contracts/v1/copilot.ts b/apps/sim/lib/api/contracts/v1/copilot.ts index de85c0f7409..f09626e41a2 100644 --- a/apps/sim/lib/api/contracts/v1/copilot.ts +++ b/apps/sim/lib/api/contracts/v1/copilot.ts @@ -10,7 +10,7 @@ export const v1CopilotChatBodySchema = z.object({ mode: z.enum(COPILOT_REQUEST_MODES).optional().default('agent'), model: z.string().optional(), autoExecuteTools: z.boolean().optional().default(true), - timeout: z.number().optional().default(3_600_000), + timeout: z.number().int().min(1000).max(3_600_000).optional().default(3_600_000), }) export type V1CopilotChatBody = z.output diff --git a/apps/sim/lib/core/rate-limiter/index.ts b/apps/sim/lib/core/rate-limiter/index.ts index 4ecc48f2b5a..9afd3cb231f 100644 --- a/apps/sim/lib/core/rate-limiter/index.ts +++ b/apps/sim/lib/core/rate-limiter/index.ts @@ -8,6 +8,13 @@ export { toTokenBucketConfig, } from './hosted-key' export { RateLimiter } from './rate-limiter' +export { + DEFAULT_PUBLIC_IP_ROUTE_LIMIT, + DEFAULT_USER_ROUTE_LIMIT, + enforceIpRateLimit, + enforceUserOrIpRateLimit, + enforceUserRateLimit, +} from './route-helpers' export type { TokenBucketConfig } from './storage' export type { SubscriptionPlan } from './types' export { getRateLimit, RATE_LIMITS, RateLimitError } from './types' diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.test.ts b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts new file mode 100644 index 00000000000..0f895e81e1a --- /dev/null +++ b/apps/sim/lib/core/rate-limiter/route-helpers.test.ts @@ -0,0 +1,192 @@ +/** + * @vitest-environment node + */ +import { createMockRequest, requestUtilsMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +const { mockAdapter } = vi.hoisted(() => ({ + mockAdapter: { + consumeTokens: vi.fn(), + getTokenStatus: vi.fn(), + resetBucket: vi.fn(), + }, +})) + +vi.mock('@/lib/core/rate-limiter/storage', async () => { + const actual = await vi.importActual( + '@/lib/core/rate-limiter/storage' + ) + return { + ...actual, + createStorageAdapter: () => mockAdapter, + } +}) + +function passThroughClientIp() { + requestUtilsMockFns.mockGetClientIp.mockImplementation( + (req: { headers: { get(name: string): string | null } }) => + req.headers.get('x-forwarded-for')?.split(',')[0]?.trim() || + req.headers.get('x-real-ip')?.trim() || + 'unknown' + ) +} + +import { enforceIpRateLimit, enforceUserOrIpRateLimit, enforceUserRateLimit } from './route-helpers' + +const consume = mockAdapter.consumeTokens as Mock + +describe('route-helpers rate limiting', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('enforceUserRateLimit', () => { + it('returns null when the bucket has tokens left', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(Date.now() + 60_000), + }) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).toBeNull() + expect(consume).toHaveBeenCalledWith( + 'route:test-bucket:user:user-1', + 1, + expect.objectContaining({ maxTokens: 60, refillRate: 30 }) + ) + }) + + it('returns a 429 with Retry-After when the bucket is empty', async () => { + const resetAt = new Date(Date.now() + 30_000) + consume.mockResolvedValueOnce({ + allowed: false, + tokensRemaining: 0, + resetAt, + retryAfterMs: 30_000, + }) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).not.toBeNull() + expect(result?.status).toBe(429) + expect(result?.headers.get('Retry-After')).toBe('30') + expect(result?.headers.get('X-RateLimit-Reset')).toBe(resetAt.toISOString()) + + const body = await result?.json() + expect(body?.error).toBe('Rate limit exceeded') + }) + + it('keys buckets per user so different users do not share state', async () => { + consume.mockResolvedValue({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + + await enforceUserRateLimit('shared-bucket', 'user-a') + await enforceUserRateLimit('shared-bucket', 'user-b') + + const keys = consume.mock.calls.map((call) => call[0]) + expect(keys).toEqual(['route:shared-bucket:user:user-a', 'route:shared-bucket:user:user-b']) + }) + + it('fails open when the storage layer throws', async () => { + consume.mockRejectedValueOnce(new Error('redis down')) + + const result = await enforceUserRateLimit('test-bucket', 'user-1') + + expect(result).toBeNull() + }) + }) + + describe('enforceIpRateLimit', () => { + beforeEach(() => { + passThroughClientIp() + }) + + it('uses the X-Forwarded-For client IP in the bucket key', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 9, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { + 'x-forwarded-for': '203.0.113.7, 10.0.0.1', + }) + + await enforceIpRateLimit('public-bucket', request) + + expect(consume).toHaveBeenCalledWith( + 'route:public-bucket:ip:203.0.113.7', + 1, + expect.any(Object) + ) + }) + + it('folds spoofed `X-Forwarded-For: unknown` into a single shared bucket', async () => { + consume.mockResolvedValue({ + allowed: true, + tokensRemaining: 9, + resetAt: new Date(), + }) + + const reqA = createMockRequest('POST', undefined, { 'x-forwarded-for': 'unknown' }) + const reqB = createMockRequest('POST', undefined, { 'x-forwarded-for': 'unknown' }) + await enforceIpRateLimit('otp', reqA) + await enforceIpRateLimit('otp', reqB) + + const keys = consume.mock.calls.map((call) => call[0]) + expect(keys).toEqual(['route:otp:ip:unknown', 'route:otp:ip:unknown']) + }) + + it('returns a 429 with Retry-After on rate limit', async () => { + const resetAt = new Date(Date.now() + 60_000) + consume.mockResolvedValueOnce({ + allowed: false, + tokensRemaining: 0, + resetAt, + retryAfterMs: 60_000, + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + const result = await enforceIpRateLimit('public-bucket', request) + + expect(result?.status).toBe(429) + expect(result?.headers.get('Retry-After')).toBe('60') + }) + }) + + describe('enforceUserOrIpRateLimit', () => { + beforeEach(() => { + passThroughClientIp() + }) + + it('keys per-user when userId is present', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + await enforceUserOrIpRateLimit('a2a-test', 'user-1', request) + + expect(consume).toHaveBeenCalledWith('route:a2a-test:user:user-1', 1, expect.any(Object)) + }) + + it('falls back to per-IP when userId is undefined', async () => { + consume.mockResolvedValueOnce({ + allowed: true, + tokensRemaining: 59, + resetAt: new Date(), + }) + const request = createMockRequest('POST', undefined, { 'x-forwarded-for': '203.0.113.7' }) + + await enforceUserOrIpRateLimit('a2a-test', undefined, request) + + expect(consume).toHaveBeenCalledWith('route:a2a-test:ip:203.0.113.7', 1, expect.any(Object)) + }) + }) +}) diff --git a/apps/sim/lib/core/rate-limiter/route-helpers.ts b/apps/sim/lib/core/rate-limiter/route-helpers.ts new file mode 100644 index 00000000000..8b7cae9dc1a --- /dev/null +++ b/apps/sim/lib/core/rate-limiter/route-helpers.ts @@ -0,0 +1,89 @@ +import { createLogger } from '@sim/logger' +import { type NextRequest, NextResponse } from 'next/server' +import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter' +import type { TokenBucketConfig } from '@/lib/core/rate-limiter/storage' +import { getClientIp } from '@/lib/core/utils/request' + +const logger = createLogger('RouteRateLimit') +const rateLimiter = new RateLimiter() + +/** Default per-user bucket for authenticated tool routes (60 burst, 30/min). */ +export const DEFAULT_USER_ROUTE_LIMIT: TokenBucketConfig = { + maxTokens: 60, + refillRate: 30, + refillIntervalMs: 60_000, +} + +/** Default per-IP bucket for unauthenticated public endpoints (10 burst, 5/min). */ +export const DEFAULT_PUBLIC_IP_ROUTE_LIMIT: TokenBucketConfig = { + maxTokens: 10, + refillRate: 5, + refillIntervalMs: 60_000, +} + +function buildRateLimitResponse(resetAt: Date): NextResponse { + const retryAfterSec = Math.max(1, Math.ceil((resetAt.getTime() - Date.now()) / 1000)) + return NextResponse.json( + { + error: 'Rate limit exceeded', + retryAfter: resetAt.getTime(), + }, + { + status: 429, + headers: { + 'Retry-After': String(retryAfterSec), + 'X-RateLimit-Reset': resetAt.toISOString(), + }, + } + ) +} + +/** + * Apply a per-user token bucket to an authenticated route. + * Returns a `NextResponse` on 429, otherwise `null` so the caller can proceed. + */ +export async function enforceUserRateLimit( + bucketName: string, + userId: string, + config: TokenBucketConfig = DEFAULT_USER_ROUTE_LIMIT +): Promise { + const key = `route:${bucketName}:user:${userId}` + const { allowed, resetAt } = await rateLimiter.checkRateLimitDirect(key, config) + if (allowed) return null + logger.warn('User rate limit exceeded', { bucket: bucketName, userId }) + return buildRateLimitResponse(resetAt) +} + +/** + * Apply a per-IP token bucket to an unauthenticated route. The `unknown` IP + * fallback shares one global bucket per route so it cannot be amplified by + * `X-Forwarded-For: unknown` spoofing. + */ +export async function enforceIpRateLimit( + bucketName: string, + request: NextRequest, + config: TokenBucketConfig = DEFAULT_PUBLIC_IP_ROUTE_LIMIT +): Promise { + const ip = getClientIp(request) + const key = `route:${bucketName}:ip:${ip}` + const { allowed, resetAt } = await rateLimiter.checkRateLimitDirect(key, config) + if (allowed) return null + logger.warn('IP rate limit exceeded', { bucket: bucketName, ip }) + return buildRateLimitResponse(resetAt) +} + +/** + * Apply a per-user limit when a userId is present, else fall back to per-IP. + * Use for routes whose auth path may legitimately resolve without a userId + * (e.g. internal JWT calls with `requireWorkflowId: false`) so missing-userId + * traffic is still throttled per-IP rather than sharing one global bucket. + */ +export async function enforceUserOrIpRateLimit( + bucketName: string, + userId: string | undefined, + request: NextRequest, + config: TokenBucketConfig = DEFAULT_USER_ROUTE_LIMIT +): Promise { + if (userId) return enforceUserRateLimit(bucketName, userId, config) + return enforceIpRateLimit(bucketName, request, config) +} From 104949bdc2b92b2d5e507c2b6c14b600eed0b696 Mon Sep 17 00:00:00 2001 From: Waleed Date: Wed, 13 May 2026 19:52:43 -0700 Subject: [PATCH 2/5] fix(tables): eliminate checkbox flicker on rapid cell toggle (#4592) * fix(tables): eliminate checkbox flicker on rapid cell toggle * fix(tables): symmetric guarded onSettled across row write mutations * fix(tables): merge only mutated keys in onSuccess to preserve concurrent optimistic patches --- apps/sim/hooks/queries/tables.ts | 37 ++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/apps/sim/hooks/queries/tables.ts b/apps/sim/hooks/queries/tables.ts index bc8a30840af..f7b94e8b9f9 100644 --- a/apps/sim/hooks/queries/tables.ts +++ b/apps/sim/hooks/queries/tables.ts @@ -85,6 +85,7 @@ export const tableKeys = { rowsRoot: (tableId: string) => [...tableKeys.detail(tableId), 'rows'] as const, infiniteRows: (tableId: string, paramsKey: string) => [...tableKeys.rowsRoot(tableId), 'infinite', paramsKey] as const, + rowWrites: (tableId: string) => [...tableKeys.rowsRoot(tableId), 'write'] as const, } type TableRowsParams = Omit & @@ -543,14 +544,15 @@ export function useUpdateTableRow({ workspaceId, tableId }: RowMutationContext) const queryClient = useQueryClient() return useMutation({ + mutationKey: tableKeys.rowWrites(tableId), mutationFn: async ({ rowId, data }: UpdateTableRowParams) => { return requestJson(updateTableRowContract, { params: { tableId, rowId }, body: { workspaceId, data: data as RowData }, }) }, - onMutate: ({ rowId, data }) => { - void queryClient.cancelQueries({ queryKey: tableKeys.rowsRoot(tableId) }) + onMutate: async ({ rowId, data }) => { + await queryClient.cancelQueries({ queryKey: tableKeys.rowsRoot(tableId) }) const previousQueries = queryClient.getQueriesData>({ queryKey: tableKeys.rowsRoot(tableId), @@ -573,6 +575,24 @@ export function useUpdateTableRow({ workspaceId, tableId }: RowMutationContext) return { previousQueries } }, + onSuccess: (response, { rowId, data: mutatedData }) => { + const serverRow = response.data.row + const mutatedKeys = Object.keys(mutatedData) + patchCachedRows(queryClient, tableId, (row) => { + if (row.id !== rowId) return row + const merged: RowData = { ...row.data } + for (const key of mutatedKeys) { + merged[key] = (serverRow.data as RowData)[key] + } + return { + ...row, + data: merged, + position: serverRow.position, + createdAt: serverRow.createdAt, + updatedAt: serverRow.updatedAt, + } + }) + }, onError: (error, _vars, context) => { if (context?.previousQueries) { for (const [queryKey, data] of context.previousQueries) { @@ -583,7 +603,9 @@ export function useUpdateTableRow({ workspaceId, tableId }: RowMutationContext) toast.error(error.message, { duration: 5000 }) }, onSettled: () => { - invalidateRowData(queryClient, tableId) + if (queryClient.isMutating({ mutationKey: tableKeys.rowWrites(tableId) }) === 1) { + invalidateRowData(queryClient, tableId) + } }, }) } @@ -599,6 +621,7 @@ export function useBatchUpdateTableRows({ workspaceId, tableId }: RowMutationCon const queryClient = useQueryClient() return useMutation({ + mutationKey: tableKeys.rowWrites(tableId), mutationFn: async ({ updates }: BatchUpdateTableRowsParams) => { return requestJson(batchUpdateTableRowsContract, { params: { tableId }, @@ -608,8 +631,8 @@ export function useBatchUpdateTableRows({ workspaceId, tableId }: RowMutationCon }, }) }, - onMutate: ({ updates }) => { - void queryClient.cancelQueries({ queryKey: tableKeys.rowsRoot(tableId) }) + onMutate: async ({ updates }) => { + await queryClient.cancelQueries({ queryKey: tableKeys.rowsRoot(tableId) }) const previousQueries = queryClient.getQueriesData>({ queryKey: tableKeys.rowsRoot(tableId), @@ -644,7 +667,9 @@ export function useBatchUpdateTableRows({ workspaceId, tableId }: RowMutationCon toast.error(error.message, { duration: 5000 }) }, onSettled: () => { - invalidateRowData(queryClient, tableId) + if (queryClient.isMutating({ mutationKey: tableKeys.rowWrites(tableId) }) === 1) { + invalidateRowData(queryClient, tableId) + } }, }) } From b5dba82ac9a0a1604f7fc9b64499655c5738f77b Mon Sep 17 00:00:00 2001 From: Waleed Date: Wed, 13 May 2026 23:39:59 -0700 Subject: [PATCH 3/5] improvement(db): reduce connection saturation and egress hotspots (#4594) * improvement(db): reduce connection saturation and egress hotspots * fix(vfs): preserve native content type in copilot SQL projection * fix(vfs): guard jsonb_array_elements against non-array contentBlocks --- apps/realtime/src/database/operations.ts | 2 +- .../app/api/mcp/servers/[id]/refresh/route.ts | 6 ++--- apps/sim/app/api/mcp/tools/stored/route.ts | 8 +++--- apps/sim/lib/copilot/vfs/workspace-vfs.ts | 27 ++++++++++++++++--- packages/db/db.ts | 2 +- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/apps/realtime/src/database/operations.ts b/apps/realtime/src/database/operations.ts index 38a98b14bb3..4904daccb8f 100644 --- a/apps/realtime/src/database/operations.ts +++ b/apps/realtime/src/database/operations.ts @@ -30,7 +30,7 @@ const socketDb = drizzle( prepare: false, idle_timeout: 10, connect_timeout: 20, - max: 30, + max: 15, onnotice: () => {}, }), { schema } diff --git a/apps/sim/app/api/mcp/servers/[id]/refresh/route.ts b/apps/sim/app/api/mcp/servers/[id]/refresh/route.ts index 9e15224ead0..7bab3fade1f 100644 --- a/apps/sim/app/api/mcp/servers/[id]/refresh/route.ts +++ b/apps/sim/app/api/mcp/servers/[id]/refresh/route.ts @@ -2,7 +2,7 @@ import { db } from '@sim/db' import { mcpServers, workflow, workflowBlocks } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { and, eq, isNull } from 'drizzle-orm' +import { and, eq, inArray, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { mcpServerIdParamsSchema } from '@/lib/api/contracts/mcp' import { validationErrorResponse } from '@/lib/api/server' @@ -77,13 +77,11 @@ async function syncToolSchemasToWorkflows( subBlocks: workflowBlocks.subBlocks, }) .from(workflowBlocks) - .where(eq(workflowBlocks.type, 'agent')) + .where(and(eq(workflowBlocks.type, 'agent'), inArray(workflowBlocks.workflowId, workflowIds))) const updatedWorkflowIds = new Set() for (const block of agentBlocks) { - if (!workflowIds.includes(block.workflowId)) continue - const subBlocks = block.subBlocks as Record | null if (!subBlocks) continue diff --git a/apps/sim/app/api/mcp/tools/stored/route.ts b/apps/sim/app/api/mcp/tools/stored/route.ts index 59fa5f5102f..3606e05115d 100644 --- a/apps/sim/app/api/mcp/tools/stored/route.ts +++ b/apps/sim/app/api/mcp/tools/stored/route.ts @@ -2,7 +2,7 @@ import { db } from '@sim/db' import { workflow, workflowBlocks } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { eq } from 'drizzle-orm' +import { and, eq, inArray } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { withMcpAuth } from '@/lib/mcp/middleware' @@ -33,13 +33,13 @@ export const GET = withRouteHandler( const agentBlocks = await db .select({ workflowId: workflowBlocks.workflowId, subBlocks: workflowBlocks.subBlocks }) .from(workflowBlocks) - .where(eq(workflowBlocks.type, 'agent')) + .where( + and(eq(workflowBlocks.type, 'agent'), inArray(workflowBlocks.workflowId, workflowIds)) + ) const storedTools: StoredMcpTool[] = [] for (const block of agentBlocks) { - if (!workflowMap.has(block.workflowId)) continue - const subBlocks = block.subBlocks as Record | null if (!subBlocks) continue diff --git a/apps/sim/lib/copilot/vfs/workspace-vfs.ts b/apps/sim/lib/copilot/vfs/workspace-vfs.ts index 6e5cd70bb7d..3db1683bd19 100644 --- a/apps/sim/lib/copilot/vfs/workspace-vfs.ts +++ b/apps/sim/lib/copilot/vfs/workspace-vfs.ts @@ -17,7 +17,7 @@ import { } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { and, desc, eq, isNotNull, isNull, ne } from 'drizzle-orm' +import { and, desc, eq, isNotNull, isNull, ne, sql } from 'drizzle-orm' import { listApiKeys } from '@/lib/api-key/service' import { buildWorkspaceMd, type WorkspaceMdData } from '@/lib/copilot/chat/workspace-context' import { extractDocumentStyle } from '@/lib/copilot/vfs/document-style' @@ -1157,7 +1157,27 @@ export class WorkspaceVFS { .select({ id: copilotChats.id, title: copilotChats.title, - messages: copilotChats.messages, + messageCount: sql`COALESCE(jsonb_array_length(${copilotChats.messages}), 0)`, + messages: sql`COALESCE(( + SELECT jsonb_agg( + jsonb_build_object( + 'role', m->>'role', + 'content', m->'content', + 'contentBlocks', COALESCE(( + SELECT jsonb_agg(jsonb_build_object('type', 'text', 'content', b->'content')) + FROM jsonb_array_elements( + CASE WHEN jsonb_typeof(m->'contentBlocks') = 'array' + THEN m->'contentBlocks' + ELSE '[]'::jsonb + END + ) AS b + WHERE b->>'type' = 'text' + ), '[]'::jsonb) + ) + ) + FROM jsonb_array_elements(${copilotChats.messages}) AS m + WHERE m->>'role' IN ('user', 'assistant') + ), '[]'::jsonb)`, createdAt: copilotChats.createdAt, updatedAt: copilotChats.updatedAt, }) @@ -1177,13 +1197,14 @@ export class WorkspaceVFS { const safeName = sanitizeName(title) const prefix = `tasks/${safeName}/` const messages = Array.isArray(task.messages) ? task.messages : [] + const messageCount = Number(task.messageCount) || 0 this.files.set( `${prefix}session.md`, serializeTaskSession({ id: task.id, title, - messageCount: messages.length, + messageCount, createdAt: task.createdAt, updatedAt: task.updatedAt, }) diff --git a/packages/db/db.ts b/packages/db/db.ts index 6868bbaeeb7..9e5597fb57b 100644 --- a/packages/db/db.ts +++ b/packages/db/db.ts @@ -11,7 +11,7 @@ const postgresClient = postgres(connectionString, { prepare: false, idle_timeout: 20, connect_timeout: 30, - max: 30, + max: 15, onnotice: () => {}, }) From b1a944317865c0951e35eab70a85da06d2eda0c8 Mon Sep 17 00:00:00 2001 From: Vikhyath Mondreti Date: Wed, 13 May 2026 23:52:32 -0700 Subject: [PATCH 4/5] improvement(billing): move overage calculations out of txes (#4595) * improvement(billing): move calc subscription overage out of tx * fix double billing risk * address comments * address comments * share timeout const --- apps/sim/lib/billing/constants.ts | 5 + .../lib/billing/organizations/membership.ts | 66 ++- .../sim/lib/billing/threshold-billing.test.ts | 528 ++++++++++++++++++ apps/sim/lib/billing/threshold-billing.ts | 378 +++++++++---- .../sim/lib/billing/webhooks/invoices.test.ts | 38 ++ apps/sim/lib/billing/webhooks/invoices.ts | 148 +++-- 6 files changed, 979 insertions(+), 184 deletions(-) create mode 100644 apps/sim/lib/billing/threshold-billing.test.ts diff --git a/apps/sim/lib/billing/constants.ts b/apps/sim/lib/billing/constants.ts index d9a3c391540..ce20115db8b 100644 --- a/apps/sim/lib/billing/constants.ts +++ b/apps/sim/lib/billing/constants.ts @@ -34,6 +34,11 @@ export const SEARCH_TOOL_COST = 0.01 */ export const DEFAULT_OVERAGE_THRESHOLD = 100 +/** + * Maximum time to wait on billing coordination row locks before retrying later. + */ +export const BILLING_LOCK_TIMEOUT_MS = 5_000 + /** * Available credit tiers. Each tier maps a credit amount to the underlying dollar cost. * 1 credit = $0.005, so credits = dollars * 200. diff --git a/apps/sim/lib/billing/organizations/membership.ts b/apps/sim/lib/billing/organizations/membership.ts index 93add34dff3..9551396e246 100644 --- a/apps/sim/lib/billing/organizations/membership.ts +++ b/apps/sim/lib/billing/organizations/membership.ts @@ -926,34 +926,6 @@ export async function removeUserFromOrganization( ) } - let capturedUsage = 0 - if (!skipBillingLogic) { - const [departingUserStats] = await tx - .select({ currentPeriodCost: userStats.currentPeriodCost }) - .from(userStats) - .where(eq(userStats.userId, userId)) - .limit(1) - - if (departingUserStats?.currentPeriodCost) { - const usage = toNumber(toDecimal(departingUserStats.currentPeriodCost)) - if (usage > 0) { - await tx - .update(organization) - .set({ - departedMemberUsage: sql`${organization.departedMemberUsage} + ${usage}`, - }) - .where(eq(organization.id, organizationId)) - - await tx - .update(userStats) - .set({ currentPeriodCost: '0' }) - .where(eq(userStats.userId, userId)) - - capturedUsage = usage - } - } - } - const [targetUser] = await tx .select({ email: user.email }) .from(user) @@ -979,7 +951,44 @@ export async function removeUserFromOrganization( .from(workspace) .where(eq(workspace.organizationId, organizationId)) + const captureDepartedUsage = async () => { + if (skipBillingLogic) return 0 + + await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, organizationId)) + .for('update') + .limit(1) + + const [departingUserStats] = await tx + .select({ currentPeriodCost: userStats.currentPeriodCost }) + .from(userStats) + .where(eq(userStats.userId, userId)) + .for('update') + .limit(1) + + const usage = toNumber(toDecimal(departingUserStats?.currentPeriodCost)) + if (usage <= 0) return 0 + + await tx + .update(organization) + .set({ + departedMemberUsage: sql`${organization.departedMemberUsage} + ${usage}`, + }) + .where(eq(organization.id, organizationId)) + + await tx + .update(userStats) + .set({ currentPeriodCost: '0' }) + .where(eq(userStats.userId, userId)) + + return usage + } + if (orgWorkspaces.length === 0) { + const capturedUsage = await captureDepartedUsage() + return { workspaceIdsToRevoke: [] as string[], usageCaptured: capturedUsage, @@ -1022,6 +1031,7 @@ export async function removeUserFromOrganization( workspaceIds, userId, }) + const capturedUsage = await captureDepartedUsage() return { workspaceIdsToRevoke: deletedPerms.map((row) => row.entityId), diff --git a/apps/sim/lib/billing/threshold-billing.test.ts b/apps/sim/lib/billing/threshold-billing.test.ts new file mode 100644 index 00000000000..042dabaebf9 --- /dev/null +++ b/apps/sim/lib/billing/threshold-billing.test.ts @@ -0,0 +1,528 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCalculateSubscriptionOverage, + mockComputeOrgOverageAmount, + mockDbSelect, + mockDbTransaction, + mockEnqueueOutboxEvent, + mockGetEffectiveBillingStatus, + mockGetHighestPrioritySubscription, + mockGetOrganizationSubscriptionUsable, + mockHasUsableSubscriptionAccess, + mockIsEnterprise, + mockIsFree, + mockIsOrgScopedSubscription, + mockIsOrganizationBillingBlocked, + mockTxExecute, + mockTxSelect, + mockTxStatsLimit, + mockTxUpdate, +} = vi.hoisted(() => ({ + mockCalculateSubscriptionOverage: vi.fn(), + mockComputeOrgOverageAmount: vi.fn(), + mockDbSelect: vi.fn(), + mockDbTransaction: vi.fn(), + mockEnqueueOutboxEvent: vi.fn(), + mockGetEffectiveBillingStatus: vi.fn(), + mockGetHighestPrioritySubscription: vi.fn(), + mockGetOrganizationSubscriptionUsable: vi.fn(), + mockHasUsableSubscriptionAccess: vi.fn(), + mockIsEnterprise: vi.fn(), + mockIsFree: vi.fn(), + mockIsOrgScopedSubscription: vi.fn(), + mockIsOrganizationBillingBlocked: vi.fn(), + mockTxExecute: vi.fn(), + mockTxSelect: vi.fn(), + mockTxStatsLimit: vi.fn(), + mockTxUpdate: vi.fn(), +})) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + transaction: mockDbTransaction, + }, +})) + +vi.mock('@sim/db/schema', () => ({ + member: { + organizationId: 'member.organizationId', + role: 'member.role', + userId: 'member.userId', + }, + organization: { + creditBalance: 'organization.creditBalance', + departedMemberUsage: 'organization.departedMemberUsage', + id: 'organization.id', + }, + subscription: { + id: 'subscription.id', + stripeCustomerId: 'subscription.stripeCustomerId', + }, + userStats: { + billedOverageThisPeriod: 'userStats.billedOverageThisPeriod', + creditBalance: 'userStats.creditBalance', + currentPeriodCost: 'userStats.currentPeriodCost', + lastPeriodCost: 'userStats.lastPeriodCost', + proPeriodCostSnapshot: 'userStats.proPeriodCostSnapshot', + proPeriodCostSnapshotAt: 'userStats.proPeriodCostSnapshotAt', + userId: 'userStats.userId', + }, +})) + +vi.mock('@/lib/billing/core/access', () => ({ + getEffectiveBillingStatus: mockGetEffectiveBillingStatus, + isOrganizationBillingBlocked: mockIsOrganizationBillingBlocked, +})) + +vi.mock('@/lib/billing/core/billing', () => ({ + calculateSubscriptionOverage: mockCalculateSubscriptionOverage, + computeOrgOverageAmount: mockComputeOrgOverageAmount, +})) + +vi.mock('@/lib/billing/core/subscription', () => ({ + getHighestPrioritySubscription: mockGetHighestPrioritySubscription, + getOrganizationSubscriptionUsable: mockGetOrganizationSubscriptionUsable, +})) + +vi.mock('@/lib/billing/plan-helpers', () => ({ + isEnterprise: mockIsEnterprise, + isFree: mockIsFree, +})) + +vi.mock('@/lib/billing/subscriptions/utils', () => ({ + hasUsableSubscriptionAccess: mockHasUsableSubscriptionAccess, + isOrgScopedSubscription: mockIsOrgScopedSubscription, +})) + +vi.mock('@/lib/billing/webhooks/outbox-handlers', () => ({ + OUTBOX_EVENT_TYPES: { + STRIPE_THRESHOLD_OVERAGE_INVOICE: 'stripe.threshold-overage-invoice', + }, +})) + +vi.mock('@/lib/core/config/env', () => ({ + env: {}, + envNumber: vi.fn((_value: string | undefined, fallback: number) => fallback), +})) + +vi.mock('@/lib/core/outbox/service', () => ({ + enqueueOutboxEvent: mockEnqueueOutboxEvent, +})) + +import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing' + +interface MockTx { + execute: typeof mockTxExecute + select: typeof mockTxSelect + update: typeof mockTxUpdate +} + +const userSubscription = { + id: 'sub-db-1', + plan: 'pro', + referenceId: 'user-1', + seats: 1, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_stripe_1', + status: 'active', +} + +function buildSelectChain(rows: T[]) { + const chain = { + from: vi.fn(() => chain), + leftJoin: vi.fn(() => chain), + innerJoin: vi.fn(() => chain), + where: vi.fn(() => result), + } + const result = { + limit: vi.fn(async () => rows), + then: (resolve: (value: T[]) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(rows).then(resolve, reject), + } + + return { + from: chain.from, + } +} + +function buildPersonalSelectChain(customerId = 'cus_1') { + return buildSelectChain([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + stripeCustomerId: customerId, + }, + ]) +} + +function buildPersonalSnapshotSelectChain({ + currentPeriodCost = '0', + proPeriodCostSnapshot = '0', + proPeriodCostSnapshotAt = null, + lastPeriodCost = '0', +}: { + currentPeriodCost?: string + proPeriodCostSnapshot?: string + proPeriodCostSnapshotAt?: Date | null + lastPeriodCost?: string +}) { + return buildSelectChain([ + { + currentPeriodCost, + proPeriodCostSnapshot, + proPeriodCostSnapshotAt, + lastPeriodCost, + }, + ]) +} + +function buildStatsSelectChain() { + const result = { + for: vi.fn(() => result), + limit: mockTxStatsLimit, + then: (resolve: (value: unknown[]) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(mockTxStatsLimit()).then(resolve, reject), + } + + return { + from: vi.fn(() => ({ + leftJoin: vi.fn(() => ({ + innerJoin: vi.fn(() => ({ + where: vi.fn(() => result), + })), + })), + where: vi.fn(() => result), + })), + } +} + +function buildUpdateChain() { + return { + set: vi.fn(() => ({ + where: vi.fn(async () => []), + })), + } +} + +describe('checkAndBillOverageThreshold', () => { + beforeEach(() => { + vi.clearAllMocks() + + mockGetHighestPrioritySubscription.mockResolvedValue(userSubscription) + mockGetEffectiveBillingStatus.mockResolvedValue({ billingBlocked: false }) + mockHasUsableSubscriptionAccess.mockReturnValue(true) + mockIsFree.mockReturnValue(false) + mockIsEnterprise.mockReturnValue(false) + mockIsOrgScopedSubscription.mockReturnValue(false) + mockDbSelect.mockImplementation(() => buildPersonalSelectChain()) + mockTxSelect.mockImplementation(() => buildStatsSelectChain()) + mockTxUpdate.mockImplementation(() => buildUpdateChain()) + mockTxExecute.mockResolvedValue(undefined) + mockDbTransaction.mockImplementation(async (callback: (tx: MockTx) => Promise) => + callback({ execute: mockTxExecute, select: mockTxSelect, update: mockTxUpdate }) + ) + }) + + it('does not lock user_stats when calculated overage is below threshold', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(99) + + await checkAndBillOverageThreshold('user-1') + + expect(mockCalculateSubscriptionOverage).toHaveBeenCalledWith({ + id: userSubscription.id, + plan: userSubscription.plan, + referenceId: userSubscription.referenceId, + seats: userSubscription.seats, + periodStart: userSubscription.periodStart, + periodEnd: userSubscription.periodEnd, + }) + expect(mockDbTransaction).not.toHaveBeenCalled() + expect(mockDbSelect).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('calculates overage before opening the short user_stats transaction', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + billedOverageThisPeriod: '0', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockCalculateSubscriptionOverage).toHaveBeenCalled() + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockCalculateSubscriptionOverage.mock.invocationCallOrder[0]).toBeLessThan( + mockDbTransaction.mock.invocationCallOrder[0] + ) + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).toHaveBeenCalledTimes(1) + }) + + it('rechecks billed overage while locked before enqueueing an invoice', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + billedOverageThisPeriod: '200', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockTxUpdate).not.toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('skips personal threshold billing when locked usage inputs changed', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockDbSelect + .mockImplementationOnce(() => buildPersonalSnapshotSelectChain({ currentPeriodCost: '250' })) + .mockImplementationOnce(() => buildPersonalSelectChain()) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '250', + billedOverageThisPeriod: '0', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('computes organization overage before opening the locked transaction', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockComputeOrgOverageAmount).toHaveBeenCalledWith({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + organizationId: userSubscription.referenceId, + pooledCurrentPeriodCost: 350, + departedMemberUsage: 25, + memberIds: ['owner-1'], + }) + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockComputeOrgOverageAmount.mock.invocationCallOrder[0]).toBeLessThan( + mockDbTransaction.mock.invocationCallOrder[0] + ) + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).toHaveBeenCalledTimes(1) + }) + + it('skips stale organization overage when locked usage inputs changed', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '75' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '75', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) + + it('rechecks organization billed overage on the locked owner tracker', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '200' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) + + it('skips stale organization overage when owner identity changed', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + { + userId: 'member-1', + role: 'member', + currentPeriodCost: '25', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'member-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'member', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + { + userId: 'member-1', + role: 'owner', + currentPeriodCost: '25', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/lib/billing/threshold-billing.ts b/apps/sim/lib/billing/threshold-billing.ts index 1219481b7d2..86156b6adc1 100644 --- a/apps/sim/lib/billing/threshold-billing.ts +++ b/apps/sim/lib/billing/threshold-billing.ts @@ -1,8 +1,8 @@ import { db } from '@sim/db' import { member, organization, subscription, userStats } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { eq, inArray, sql } from 'drizzle-orm' -import { DEFAULT_OVERAGE_THRESHOLD } from '@/lib/billing/constants' +import { and, eq, sql } from 'drizzle-orm' +import { BILLING_LOCK_TIMEOUT_MS, DEFAULT_OVERAGE_THRESHOLD } from '@/lib/billing/constants' import { getEffectiveBillingStatus, isOrganizationBillingBlocked } from '@/lib/billing/core/access' import { calculateSubscriptionOverage, computeOrgOverageAmount } from '@/lib/billing/core/billing' import { @@ -22,6 +22,22 @@ import { enqueueOutboxEvent } from '@/lib/core/outbox/service' const logger = createLogger('ThresholdBilling') const OVERAGE_THRESHOLD = envNumber(env.OVERAGE_THRESHOLD_DOLLARS, DEFAULT_OVERAGE_THRESHOLD) +const USAGE_TOTAL_EPSILON = 0.000001 + +interface PersonalUsageSnapshot { + currentPeriodCost: number + proPeriodCostSnapshot: number + proPeriodCostSnapshotAt: Date | null + lastPeriodCost: number +} + +interface OrganizationUsageSnapshot { + memberIds: string[] + ownerId: string + memberSignature: string + pooledCurrentPeriodCost: number + departedMemberUsage: number +} export async function checkAndBillOverageThreshold(userId: string): Promise { try { @@ -53,7 +69,57 @@ export async function checkAndBillOverageThreshold(userId: string): Promise { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + const statsRecords = await tx .select() .from(userStats) @@ -67,15 +133,16 @@ export async function checkAndBillOverageThreshold(userId: string): Promise ({ userId: m.userId, role: m.role })), + memberCount: memberUsageRows.length, + members: memberUsageRows.map((m) => ({ userId: m.userId, role: m.role })), }) - if (members.length === 0) { + if (memberUsageRows.length === 0) { logger.warn('No members found for organization', { organizationId }) return } - const owner = members.find((m) => m.role === 'owner') - if (!owner) { + const usageSnapshot = buildOrganizationUsageSnapshot(memberUsageRows) + if (!usageSnapshot) { logger.error( 'Organization has no owner when running threshold billing — data integrity issue, skipping', { organizationId } @@ -260,17 +312,80 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): logger.debug('Found organization owner, starting transaction', { organizationId, - ownerId: owner.userId, + ownerId: usageSnapshot.ownerId, + }) + + const { + totalOverage: currentOverage, + baseSubscriptionAmount: basePrice, + effectiveUsage: effectiveTeamUsage, + } = await computeOrgOverageAmount({ + plan: orgSubscription.plan, + seats: orgSubscription.seats ?? null, + periodStart: orgSubscription.periodStart ?? null, + periodEnd: orgSubscription.periodEnd ?? null, + organizationId, + pooledCurrentPeriodCost: usageSnapshot.pooledCurrentPeriodCost, + departedMemberUsage: usageSnapshot.departedMemberUsage, + memberIds: usageSnapshot.memberIds, }) + if (currentOverage < threshold) { + logger.debug('Organization threshold billing check below threshold before locking', { + organizationId, + totalTeamUsage: usageSnapshot.pooledCurrentPeriodCost + usageSnapshot.departedMemberUsage, + effectiveTeamUsage, + basePrice, + currentOverage, + threshold, + }) + return + } + + // Validate Stripe identifiers BEFORE mutating credits/trackers. + const stripeSubscriptionId = orgSubscription.stripeSubscriptionId + if (!stripeSubscriptionId) { + logger.error('No Stripe subscription ID for organization', { organizationId }) + return + } + + const customerId = orgSubscription.stripeCustomerId + if (!customerId) { + logger.error('No Stripe customer ID for organization', { organizationId }) + return + } + + const periodEnd = orgSubscription.periodEnd + ? Math.floor(orgSubscription.periodEnd.getTime() / 1000) + : Math.floor(Date.now() / 1000) + const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7) + const totalOverageCents = Math.round(currentOverage * 100) + await db.transaction(async (tx) => { - // Lock both owner stats and organization rows + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + + const lockedOwnerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, organizationId), eq(member.role, 'owner'))) + .for('update') + .limit(1) + const lockedOwnerId = lockedOwnerRows[0]?.userId + if (!lockedOwnerId) { + logger.error('Organization owner not found after locking organization', { organizationId }) + return + } + const ownerStatsLock = await tx .select() .from(userStats) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) .for('update') .limit(1) + if (ownerStatsLock.length === 0) { + logger.error('Owner stats not found', { organizationId, ownerId: lockedOwnerId }) + return + } const orgLock = await tx .select() @@ -279,58 +394,46 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .for('update') .limit(1) - if (ownerStatsLock.length === 0) { - logger.error('Owner stats not found', { organizationId, ownerId: owner.userId }) + if (orgLock.length === 0) { + logger.error('Organization not found', { organizationId }) return } - if (orgLock.length === 0) { - logger.error('Organization not found', { organizationId }) + const lockedMemberUsageRows = await tx + .select({ + userId: member.userId, + role: member.role, + currentPeriodCost: userStats.currentPeriodCost, + departedMemberUsage: organization.departedMemberUsage, + }) + .from(member) + .leftJoin(userStats, eq(member.userId, userStats.userId)) + .innerJoin(organization, eq(organization.id, member.organizationId)) + .where(eq(member.organizationId, organizationId)) + + const lockedUsageSnapshot = buildOrganizationUsageSnapshot(lockedMemberUsageRows) + if ( + !lockedUsageSnapshot || + lockedOwnerId !== usageSnapshot.ownerId || + !organizationUsageSnapshotMatches(usageSnapshot, lockedUsageSnapshot) + ) { + logger.debug('Organization usage changed during threshold billing check; retry later', { + organizationId, + usageSnapshot, + lockedUsageSnapshot, + lockedOwnerId, + }) return } - let pooledCurrentPeriodCost = toNumber(toDecimal(ownerStatsLock[0].currentPeriodCost)) const totalBilledOverage = toNumber(toDecimal(ownerStatsLock[0].billedOverageThisPeriod)) const orgCreditBalance = toNumber(toDecimal(orgLock[0].creditBalance)) - const nonOwnerIds = members.filter((m) => m.userId !== owner.userId).map((m) => m.userId) - - if (nonOwnerIds.length > 0) { - const memberStatsRows = await tx - .select({ - userId: userStats.userId, - currentPeriodCost: userStats.currentPeriodCost, - }) - .from(userStats) - .where(inArray(userStats.userId, nonOwnerIds)) - - for (const stats of memberStatsRows) { - pooledCurrentPeriodCost += toNumber(toDecimal(stats.currentPeriodCost)) - } - } - - const departedMemberUsage = toNumber(toDecimal(orgLock[0].departedMemberUsage)) - - const { - totalOverage: currentOverage, - baseSubscriptionAmount: basePrice, - effectiveUsage: effectiveTeamUsage, - } = await computeOrgOverageAmount({ - plan: orgSubscription.plan, - seats: orgSubscription.seats ?? null, - periodStart: orgSubscription.periodStart ?? null, - periodEnd: orgSubscription.periodEnd ?? null, - organizationId, - pooledCurrentPeriodCost, - departedMemberUsage, - memberIds: members.map((m) => m.userId), - }) - const unbilledOverage = Math.max(0, currentOverage - totalBilledOverage) logger.debug('Organization threshold billing check', { organizationId, - totalTeamUsage: pooledCurrentPeriodCost + departedMemberUsage, + totalTeamUsage: usageSnapshot.pooledCurrentPeriodCost + usageSnapshot.departedMemberUsage, effectiveTeamUsage, basePrice, currentOverage, @@ -343,19 +446,6 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): return } - // Validate Stripe identifiers BEFORE mutating credits/trackers. - const stripeSubscriptionId = orgSubscription.stripeSubscriptionId - if (!stripeSubscriptionId) { - logger.error('No Stripe subscription ID for organization', { organizationId }) - return - } - - const customerId = orgSubscription.stripeCustomerId - if (!customerId) { - logger.error('No Stripe customer ID for organization', { organizationId }) - return - } - let amountToBill = unbilledOverage let creditsApplied = 0 @@ -384,7 +474,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .set({ billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${unbilledOverage}`, }) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) logger.info('Credits fully covered org threshold overage', { organizationId, @@ -394,12 +484,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): return } - const periodEnd = orgSubscription.periodEnd - ? Math.floor(orgSubscription.periodEnd.getTime() / 1000) - : Math.floor(Date.now() / 1000) - const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7) const amountCents = Math.round(amountToBill * 100) - const totalOverageCents = Math.round(currentOverage * 100) // Bump billed tracker and enqueue Stripe invoice atomically. // See user-path above for the full retry-invariant reasoning. @@ -408,7 +493,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .set({ billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${unbilledOverage}`, }) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) await enqueueOutboxEvent(tx, OUTBOX_EVENT_TYPES.STRIPE_THRESHOLD_OVERAGE_INVOICE, { customerId, @@ -430,7 +515,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): logger.info('Queued organization threshold overage invoice for Stripe', { organizationId, - ownerId: owner.userId, + ownerId: lockedOwnerId, creditsApplied, amountBilled: amountToBill, totalProcessed: unbilledOverage, @@ -444,3 +529,92 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): }) } } + +async function getPersonalUsageSnapshot(userId: string): Promise { + const [stats] = await db + .select({ + currentPeriodCost: userStats.currentPeriodCost, + proPeriodCostSnapshot: userStats.proPeriodCostSnapshot, + proPeriodCostSnapshotAt: userStats.proPeriodCostSnapshotAt, + lastPeriodCost: userStats.lastPeriodCost, + }) + .from(userStats) + .where(eq(userStats.userId, userId)) + .limit(1) + + return stats ? personalUsageSnapshotFromStats(stats) : null +} + +function personalUsageSnapshotFromStats(stats: { + currentPeriodCost: string | number | null + proPeriodCostSnapshot: string | number | null + proPeriodCostSnapshotAt: Date | null + lastPeriodCost: string | number | null +}): PersonalUsageSnapshot { + return { + currentPeriodCost: toNumber(toDecimal(stats.currentPeriodCost)), + proPeriodCostSnapshot: toNumber(toDecimal(stats.proPeriodCostSnapshot)), + proPeriodCostSnapshotAt: stats.proPeriodCostSnapshotAt, + lastPeriodCost: toNumber(toDecimal(stats.lastPeriodCost)), + } +} + +function personalUsageSnapshotMatches( + expected: PersonalUsageSnapshot, + actual: PersonalUsageSnapshot +): boolean { + return ( + Math.abs(expected.currentPeriodCost - actual.currentPeriodCost) <= USAGE_TOTAL_EPSILON && + Math.abs(expected.proPeriodCostSnapshot - actual.proPeriodCostSnapshot) <= + USAGE_TOTAL_EPSILON && + Math.abs(expected.lastPeriodCost - actual.lastPeriodCost) <= USAGE_TOTAL_EPSILON && + nullableDateTime(expected.proPeriodCostSnapshotAt) === + nullableDateTime(actual.proPeriodCostSnapshotAt) + ) +} + +function buildOrganizationUsageSnapshot( + rows: { + userId: string + role: string + currentPeriodCost: string | number | null + departedMemberUsage: string | number | null + }[] +): OrganizationUsageSnapshot | null { + const owner = rows.find((row) => row.role === 'owner') + if (!owner) return null + + const sortedRows = [...rows].sort((a, b) => a.userId.localeCompare(b.userId)) + let pooledCurrentPeriodCost = 0 + for (const row of sortedRows) { + pooledCurrentPeriodCost += toNumber(toDecimal(row.currentPeriodCost)) + } + + return { + memberIds: sortedRows.map((row) => row.userId), + ownerId: owner.userId, + memberSignature: sortedRows + .map( + (row) => + `${row.userId}:${row.role}:${toNumber(toDecimal(row.currentPeriodCost)).toFixed(6)}` + ) + .join('|'), + pooledCurrentPeriodCost, + departedMemberUsage: toNumber(toDecimal(owner.departedMemberUsage)), + } +} + +function organizationUsageSnapshotMatches( + expected: OrganizationUsageSnapshot, + actual: OrganizationUsageSnapshot +): boolean { + return ( + expected.ownerId === actual.ownerId && + expected.memberSignature === actual.memberSignature && + Math.abs(expected.departedMemberUsage - actual.departedMemberUsage) <= USAGE_TOTAL_EPSILON + ) +} + +function nullableDateTime(value: Date | null): number | null { + return value?.getTime() ?? null +} diff --git a/apps/sim/lib/billing/webhooks/invoices.test.ts b/apps/sim/lib/billing/webhooks/invoices.test.ts index 9d601c33a8c..eabf87666b5 100644 --- a/apps/sim/lib/billing/webhooks/invoices.test.ts +++ b/apps/sim/lib/billing/webhooks/invoices.test.ts @@ -103,6 +103,7 @@ vi.mock('@react-email/render', () => ({ import { handleInvoicePaymentFailed, handleInvoicePaymentSucceeded, + resetUsageForSubscription, } from '@/lib/billing/webhooks/invoices' interface SelectResponse { @@ -127,6 +128,7 @@ function installSelectResponseQueue() { throw new Error('No queued db.select response') } const builder = { + for: vi.fn(() => builder), limit: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), orderBy: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), returning: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), @@ -223,4 +225,40 @@ describe('invoice billing recovery', () => { expect(mockUnblockOrgMembers).toHaveBeenCalledWith('org-1', 'payment_failed') expect(mockBlockOrgMembers).not.toHaveBeenCalled() }) + + it('coordinates org usage reset with owner tracker and organization locks', async () => { + queueSelectResponse({ limitResult: [{ userId: 'owner-1' }] }) + queueSelectResponse({ limitResult: [{ userId: 'owner-1' }] }) + queueSelectResponse({ limitResult: [{ id: 'org-1' }] }) + queueSelectResponse({ whereResult: [{ userId: 'owner-1' }, { userId: 'member-1' }] }) + queueSelectResponse({ + whereResult: [ + { userId: 'owner-1', current: '125', currentCopilot: '10' }, + { userId: 'member-1', current: '75', currentCopilot: '5' }, + ], + }) + queueSelectResponse({ whereResult: [] }) + queueSelectResponse({ whereResult: [] }) + + await resetUsageForSubscription({ plan: 'team', referenceId: 'org-1' }) + + expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.update).toHaveBeenCalledTimes(2) + expect(Object.keys(dbChainMockFns.select.mock.calls[0][0] ?? {})).toEqual(['userId']) + expect(Object.keys(dbChainMockFns.select.mock.calls[1][0] ?? {})).toEqual(['userId']) + expect(Object.keys(dbChainMockFns.select.mock.calls[2][0] ?? {})).toEqual(['id']) + + const statsReset = dbChainMockFns.set.mock.calls[0][0] as Record + expect(statsReset.currentPeriodCost).not.toBe('0') + expect(statsReset.currentPeriodCopilotCost).not.toBe('0') + expect(statsReset.lastPeriodCost).toMatchObject({ + toSQL: expect.any(Function), + }) + expect((statsReset.lastPeriodCost as { toSQL: () => { sql: string } }).toSQL().sql).toContain( + 'CASE' + ) + expect( + (statsReset.currentPeriodCost as { toSQL: () => { sql: string } }).toSQL().sql + ).toContain('GREATEST') + }) }) diff --git a/apps/sim/lib/billing/webhooks/invoices.ts b/apps/sim/lib/billing/webhooks/invoices.ts index bed1a7834e4..f3f6bd40576 100644 --- a/apps/sim/lib/billing/webhooks/invoices.ts +++ b/apps/sim/lib/billing/webhooks/invoices.ts @@ -11,6 +11,7 @@ import { createLogger } from '@sim/logger' import { and, eq, inArray, isNull, ne, or, sql } from 'drizzle-orm' import type Stripe from 'stripe' import { getEmailSubject, PaymentFailedEmail, renderCreditPurchaseEmail } from '@/components/emails' +import { BILLING_LOCK_TIMEOUT_MS } from '@/lib/billing/constants' import { calculateSubscriptionOverage, isSubscriptionOrgScoped } from '@/lib/billing/core/billing' import { addCredits, getCreditBalanceForEntity } from '@/lib/billing/credits/balance' import { setUsageLimitForCredits } from '@/lib/billing/credits/purchase' @@ -388,40 +389,86 @@ export async function getBilledOverageForSubscription(sub: { export async function resetUsageForSubscription(sub: { plan: string | null; referenceId: string }) { if (await isSubscriptionOrgScoped(sub)) { - const membersRows = await db - .select({ userId: member.userId }) - .from(member) - .where(eq(member.organizationId, sub.referenceId)) + await db.transaction(async (tx) => { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) - for (const m of membersRows) { - const currentStats = await db - .select({ - current: userStats.currentPeriodCost, - currentCopilot: userStats.currentPeriodCopilotCost, - }) - .from(userStats) - .where(eq(userStats.userId, m.userId)) + const ownerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, sub.referenceId), eq(member.role, 'owner'))) + .for('update') .limit(1) - if (currentStats.length > 0) { - const current = currentStats[0].current || '0' - const currentCopilot = currentStats[0].currentCopilot || '0' - await db + + const ownerId = ownerRows[0]?.userId + if (ownerId) { + await tx + .select({ userId: userStats.userId }) + .from(userStats) + .where(eq(userStats.userId, ownerId)) + .for('update') + .limit(1) + } + + await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, sub.referenceId)) + .for('update') + .limit(1) + + const membersRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(eq(member.organizationId, sub.referenceId)) + + const memberIds = membersRows.map((row) => row.userId) + if (memberIds.length > 0) { + const memberStatsRows = await tx + .select({ + userId: userStats.userId, + current: userStats.currentPeriodCost, + currentCopilot: userStats.currentPeriodCopilotCost, + }) + .from(userStats) + .where(inArray(userStats.userId, memberIds)) + + const statsUserIds = memberStatsRows.map((row) => row.userId) + if (statsUserIds.length === 0) { + await tx + .update(organization) + .set({ departedMemberUsage: '0' }) + .where(eq(organization.id, sub.referenceId)) + return + } + + const currentCostByUser = sql.join( + memberStatsRows.map((row) => sql`WHEN ${row.userId} THEN ${row.current ?? '0'}`), + sql` ` + ) + const currentCopilotCostByUser = sql.join( + memberStatsRows.map((row) => sql`WHEN ${row.userId} THEN ${row.currentCopilot ?? '0'}`), + sql` ` + ) + const capturedCurrentCost = sql`CASE ${userStats.userId} ${currentCostByUser} ELSE '0' END` + const capturedCurrentCopilotCost = sql`CASE ${userStats.userId} ${currentCopilotCostByUser} ELSE '0' END` + + await tx .update(userStats) .set({ - lastPeriodCost: current, - lastPeriodCopilotCost: currentCopilot, - currentPeriodCost: sql`GREATEST(0, ${userStats.currentPeriodCost} - ${current}::decimal)`, - currentPeriodCopilotCost: sql`GREATEST(0, ${userStats.currentPeriodCopilotCost} - ${currentCopilot}::decimal)`, + lastPeriodCost: capturedCurrentCost, + lastPeriodCopilotCost: capturedCurrentCopilotCost, + currentPeriodCost: sql`GREATEST(0, ${userStats.currentPeriodCost} - (${capturedCurrentCost})::decimal)`, + currentPeriodCopilotCost: sql`GREATEST(0, ${userStats.currentPeriodCopilotCost} - (${capturedCurrentCopilotCost})::decimal)`, billedOverageThisPeriod: '0', }) - .where(eq(userStats.userId, m.userId)) + .where(inArray(userStats.userId, statsUserIds)) } - } - await db - .update(organization) - .set({ departedMemberUsage: '0' }) - .where(eq(organization.id, sub.referenceId)) + await tx + .update(organization) + .set({ departedMemberUsage: '0' }) + .where(eq(organization.id, sub.referenceId)) + }) } else { const currentStats = await db .select({ @@ -859,36 +906,29 @@ export async function handleInvoiceFinalized(event: Stripe.Event) { const entityType = (await isSubscriptionOrgScoped(sub)) ? 'organization' : 'user' const entityId = sub.referenceId - // Resolve the userStats row that holds the `billedOverageThisPeriod` - // tracker. Org subs: the owner's row. Personal: the user's own row. - // Throw if an org has no owner — returning early would cache a - // "successful" no-op, and the next cycle's tracker would still - // reflect this cycle's billed amount, breaking future overage math. - let trackerUserId: string - if (entityType === 'organization') { - const ownerRows = await db - .select({ userId: member.userId }) - .from(member) - .where(and(eq(member.organizationId, entityId), eq(member.role, 'owner'))) - .limit(1) - const ownerId = ownerRows[0]?.userId - if (!ownerId) { - throw new Error( - `Organization ${entityId} has no owner member; cannot process invoice finalization` - ) + // Phase 1 — atomic commit. Resolve org owners inside the transaction, + // then lock the tracker row so `billedOverageThisPeriod` is serialized + // against threshold billing, resets, owner transfers, and retries. + const phase1 = await db.transaction(async (tx) => { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + + let trackerUserId = entityId + if (entityType === 'organization') { + const ownerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, entityId), eq(member.role, 'owner'))) + .for('update') + .limit(1) + const ownerId = ownerRows[0]?.userId + if (!ownerId) { + throw new Error( + `Organization ${entityId} has no owner member; cannot process invoice finalization` + ) + } + trackerUserId = ownerId } - trackerUserId = ownerId - } else { - trackerUserId = entityId - } - // Phase 1 — atomic commit. Lock the tracker row first so we read - // `billedOverageThisPeriod` serialized against concurrent events; - // then read the credit balance, decrement it, and bump the - // tracker to `totalOverage`. On retry, the locked re-read sees - // `billed == totalOverage` → `remaining == 0` → credit removal - // skipped. That's the invariant preventing double-deduction. - const phase1 = await db.transaction(async (tx) => { const trackerRows = await tx .select({ billed: userStats.billedOverageThisPeriod }) .from(userStats) From c3ac54e0a967760a938e24912f7985f2e37e5e4a Mon Sep 17 00:00:00 2001 From: Waleed Date: Thu, 14 May 2026 00:03:30 -0700 Subject: [PATCH 5/5] fix(vfs): make copilot message ordering deterministic via WITH ORDINALITY (#4597) --- apps/sim/lib/copilot/vfs/workspace-vfs.ts | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/apps/sim/lib/copilot/vfs/workspace-vfs.ts b/apps/sim/lib/copilot/vfs/workspace-vfs.ts index 3db1683bd19..3e7b45b9cb8 100644 --- a/apps/sim/lib/copilot/vfs/workspace-vfs.ts +++ b/apps/sim/lib/copilot/vfs/workspace-vfs.ts @@ -1161,22 +1161,23 @@ export class WorkspaceVFS { messages: sql`COALESCE(( SELECT jsonb_agg( jsonb_build_object( - 'role', m->>'role', - 'content', m->'content', + 'role', m.value->>'role', + 'content', m.value->'content', 'contentBlocks', COALESCE(( - SELECT jsonb_agg(jsonb_build_object('type', 'text', 'content', b->'content')) + SELECT jsonb_agg(jsonb_build_object('type', 'text', 'content', b.value->'content') ORDER BY b.ord) FROM jsonb_array_elements( - CASE WHEN jsonb_typeof(m->'contentBlocks') = 'array' - THEN m->'contentBlocks' + CASE WHEN jsonb_typeof(m.value->'contentBlocks') = 'array' + THEN m.value->'contentBlocks' ELSE '[]'::jsonb END - ) AS b - WHERE b->>'type' = 'text' + ) WITH ORDINALITY AS b(value, ord) + WHERE b.value->>'type' = 'text' ), '[]'::jsonb) ) + ORDER BY m.ord ) - FROM jsonb_array_elements(${copilotChats.messages}) AS m - WHERE m->>'role' IN ('user', 'assistant') + FROM jsonb_array_elements(${copilotChats.messages}) WITH ORDINALITY AS m(value, ord) + WHERE m.value->>'role' IN ('user', 'assistant') ), '[]'::jsonb)`, createdAt: copilotChats.createdAt, updatedAt: copilotChats.updatedAt,