from __future__ import annotations

import asyncio
import math
import time
from collections import deque
from dataclasses import dataclass
from typing import Deque

from fastapi import status
from fastapi.responses import JSONResponse

from app.security.captcha import (
    BYPASS_TOKEN_HEADER,
    create_captcha_challenge,
    verify_bypass_token,
)


@dataclass(frozen=True)
class CaptchaChallengePayload:
    """Immutable pair of values sent to a rate-limited client."""

    # Token identifying the challenge, echoed in the 429 response body.
    challenge_token: str
    # Challenge image as produced by ``create_captcha_challenge``.
    image_base64: str


class SlidingWindowRateLimiter:
    """In-memory sliding-window limiter keyed by an arbitrary string.

    Each key owns a deque of ``time.monotonic()`` timestamps, one per
    accepted request. A request is accepted while fewer than ``limit``
    timestamps fall inside the last ``window_seconds`` seconds.
    """

    def __init__(self, limit: int, window_seconds: int) -> None:
        """Store the limits, clamping both to a minimum of 1."""
        self._limit = max(1, int(limit))
        self._window_seconds = max(1, int(window_seconds))
        self._lock = asyncio.Lock()
        self._requests: dict[str, Deque[float]] = {}
        # Time of the last idle-bucket sweep; see ``_evict_idle_buckets``.
        self._last_sweep = time.monotonic()

    async def check(self, key: str) -> tuple[bool, int]:
        """Record a request for *key* and report whether it is allowed.

        Returns:
            ``(True, 0)`` when the request is accepted, otherwise
            ``(False, retry_after)`` where ``retry_after`` is the number of
            whole seconds (at least 1) until the oldest recorded request
            leaves the window.
        """
        now = time.monotonic()
        async with self._lock:
            cutoff = now - self._window_seconds
            self._evict_idle_buckets(now, cutoff)

            bucket = self._requests.get(key)
            if bucket is None:
                bucket = deque()
                self._requests[key] = bucket

            # Drop timestamps that have slid out of the window.
            while bucket and bucket[0] <= cutoff:
                bucket.popleft()

            if len(bucket) >= self._limit:
                return False, self._compute_retry_after(bucket, now)

            bucket.append(now)
            return True, 0

    def _evict_idle_buckets(self, now: float, cutoff: float) -> None:
        """Delete buckets whose newest timestamp is outside the window.

        Without this, every key ever seen keeps an entry in
        ``self._requests`` forever. The sweep runs at most once per window
        so its cost is amortised. It does not change any allow/deny
        decision: a bucket whose newest entry is at or before ``cutoff``
        would be emptied by the pruning loop in ``check`` anyway.

        Must be called with ``self._lock`` held.
        """
        if now - self._last_sweep < self._window_seconds:
            return
        self._last_sweep = now
        idle_keys = [
            key
            for key, bucket in self._requests.items()
            if not bucket or bucket[-1] <= cutoff
        ]
        for key in idle_keys:
            del self._requests[key]

    def _compute_retry_after(self, bucket: Deque[float], now: float) -> int:
        """Return whole seconds until the oldest entry in *bucket* expires.

        An empty bucket yields the full window length; the result is never
        less than 1.
        """
        if not bucket:
            return self._window_seconds
        oldest = bucket[0]
        remaining = self._window_seconds - (now - oldest)
        return max(1, math.ceil(remaining))


class RateLimiterMiddleware:
    """ASGI middleware that rate-limits HTTP requests per client IP.

    Requests over the limit receive a 429 JSON response that carries a
    CAPTCHA challenge and a ``Retry-After`` header.
    """

    def __init__(
        self,
        app,
        *,
        limiter: SlidingWindowRateLimiter,
        protected_paths: set[str] | None = None,
    ) -> None:
        """Wrap *app*.

        Args:
            app: The downstream ASGI application.
            limiter: Limiter consulted for every protected request.
            protected_paths: Exact paths to protect. ``None`` or an empty
                set means every path is protected.
        """
        self.app = app
        self.limiter = limiter
        self.protected_paths = protected_paths or set()

    async def __call__(self, scope, receive, send) -> None:
        """Pass the request through, or answer 429 when over the limit."""
        # Only HTTP requests are limited; other scope types pass through.
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        if self.protected_paths and path not in self.protected_paths:
            await self.app(scope, receive, send)
            return

        # OPTIONS requests are forwarded without consuming quota.
        if scope.get("method", "").upper() == "OPTIONS":
            await self.app(scope, receive, send)
            return

        client_ip = _get_client_ip(scope)
        bypass_token = _get_bypass_token(scope)
        if bypass_token and verify_bypass_token(bypass_token, client_ip):
            await self.app(scope, receive, send)
            return

        allowed, retry_after = await self.limiter.check(client_ip)
        if allowed:
            await self.app(scope, receive, send)
            return

        challenge = _create_captcha_challenge()
        response = JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "error": "Too many requests",
                "retryAfter": retry_after,
                "captchaChallenge": {
                    "challengeToken": challenge.challenge_token,
                    "imageBase64": challenge.image_base64,
                },
            },
            headers={"Retry-After": str(retry_after)},
        )
        await response(scope, receive, send)


def _create_captcha_challenge() -> CaptchaChallengePayload:
    """Create a challenge and copy its two fields into a local payload."""
    challenge = create_captcha_challenge()
    return CaptchaChallengePayload(
        challenge_token=challenge.challenge_token,
        image_base64=challenge.image_base64,
    )


def _get_header(scope, name: bytes) -> str | None:
    """Return the first header named *name* (lowercase bytes), decoded.

    Undecodable bytes are dropped. Returns ``None`` when the header is
    absent.
    """
    for key, value in scope.get("headers") or []:
        if key.lower() == name:
            return value.decode("utf-8", errors="ignore")
    return None


def _get_client_ip(scope) -> str:
    """Derive the rate-limit key for the client behind *scope*.

    Prefers the first entry of ``X-Forwarded-For``, then the ASGI
    ``client`` address, and falls back to ``"unknown"``.
    """
    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    # proxy overwrites it; a caller can rotate it to obtain a fresh
    # rate-limit key per request. Confirm the deployment sanitises it.
    forwarded_for = _get_header(scope, b"x-forwarded-for")
    if forwarded_for:
        return forwarded_for.split(",")[0].strip() or "unknown"

    client = scope.get("client")
    if client and client[0]:
        return str(client[0])
    return "unknown"


def _get_bypass_token(scope) -> str | None:
    """Return the bypass-token header value, or ``None`` when absent."""
    # Header names are compared in lowercase, so lowercase the configured
    # name as well; a mixed-case constant would otherwise never match.
    return _get_header(scope, BYPASS_TOKEN_HEADER.lower().encode("utf-8"))