Files
mgeeky-decode-spam-headers/backend/app/middleware/rate_limiter.py
2026-02-18 04:06:43 +01:00

128 lines
3.8 KiB
Python

from __future__ import annotations
import asyncio
import math
import secrets
import time
from collections import deque
from dataclasses import dataclass
from typing import Deque
from fastapi import status
from fastapi.responses import JSONResponse
# Base64-encoded (no data-URI prefix) 1x1 PNG sent as the CAPTCHA image in
# 429 responses. It is a placeholder, not a rendered challenge; presumably to
# be replaced by a real image generator later -- confirm with owners.
CAPTCHA_PLACEHOLDER_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO6sZcQAAAAASUVORK5CYII="
@dataclass(frozen=True, eq=True)
class CaptchaChallengePayload:
    """Immutable CAPTCHA challenge handed to a rate-limited client.

    The dataclass is frozen, so fields cannot be reassigned after creation.
    """

    # Opaque token identifying this challenge instance.
    challenge_token: str
    # Base64-encoded image data shown to the user for this challenge.
    image_base64: str
class SlidingWindowRateLimiter:
    """In-memory, per-key sliding-window rate limiter.

    Each key may be accepted by :meth:`check` at most ``limit`` times within
    any ``window_seconds`` interval. Timestamps come from
    :func:`time.monotonic`, so wall-clock changes do not affect the window.

    Buckets belonging to keys that have had no accepted request for a full
    window are evicted by an amortized sweep (run at most once per window),
    so the key map does not grow without bound when many distinct keys are
    seen only once.
    """

    def __init__(self, limit: int, window_seconds: int) -> None:
        # Both values are coerced to int and clamped to a minimum of 1.
        self._limit = max(1, int(limit))
        self._window_seconds = max(1, int(window_seconds))
        self._lock = asyncio.Lock()
        # key -> monotonic timestamps of accepted requests, oldest first.
        self._requests: dict[str, Deque[float]] = {}
        # Monotonic time of the last full stale-bucket sweep.
        self._last_sweep: float = time.monotonic()

    async def check(self, key: str) -> tuple[bool, int]:
        """Accept and record a request for ``key`` if it is within the limit.

        Args:
            key: identifier being limited (e.g. a client address).

        Returns:
            ``(True, 0)`` when the request is accepted and recorded, or
            ``(False, retry_after)`` when rejected, where ``retry_after`` is
            the whole number of seconds (at least 1) until the oldest recorded
            request leaves the window. Rejected requests are not recorded.
        """
        now = time.monotonic()
        cutoff = now - self._window_seconds
        async with self._lock:
            # Periodically drop buckets for keys that went quiet; otherwise
            # every key ever seen would keep a deque alive forever.
            if now - self._last_sweep >= self._window_seconds:
                self._evict_stale_buckets(cutoff)
                self._last_sweep = now
            bucket = self._requests.get(key)
            if bucket is None:
                bucket = deque()
                self._requests[key] = bucket
            # Discard timestamps that have slid out of the window.
            while bucket and bucket[0] <= cutoff:
                bucket.popleft()
            if len(bucket) >= self._limit:
                return False, self._compute_retry_after(bucket, now)
            bucket.append(now)
            return True, 0

    def _evict_stale_buckets(self, cutoff: float) -> None:
        """Remove every bucket whose newest timestamp is at or before ``cutoff``.

        Such a bucket holds only expired entries, so dropping it is equivalent
        to pruning it to empty. The caller must hold ``self._lock``.
        """
        stale_keys = [
            key
            for key, bucket in self._requests.items()
            if not bucket or bucket[-1] <= cutoff
        ]
        for key in stale_keys:
            del self._requests[key]

    def _compute_retry_after(self, bucket: Deque[float], now: float) -> int:
        """Return whole seconds (>= 1) until the oldest entry in ``bucket`` expires.

        Falls back to the full window length when ``bucket`` is empty.
        """
        if not bucket:
            return self._window_seconds
        oldest = bucket[0]
        remaining = self._window_seconds - (now - oldest)
        return max(1, math.ceil(remaining))
class RateLimiterMiddleware:
    """ASGI middleware that rate-limits HTTP requests by client IP.

    A request bypasses the limiter and goes straight to the wrapped app when
    it is not an HTTP request, when ``protected_paths`` is non-empty and does
    not contain the request path, or when its method is ``OPTIONS``. Any
    other request is checked against ``limiter``. Once the limit is reached,
    the middleware replies 429 with a ``Retry-After`` header and a CAPTCHA
    challenge payload instead of calling the app.
    """

    def __init__(
        self,
        app,
        *,
        limiter: SlidingWindowRateLimiter,
        protected_paths: set[str] | None = None,
    ) -> None:
        self.app = app
        self.limiter = limiter
        # None or an empty set both mean "limit every path".
        self.protected_paths = protected_paths if protected_paths else set()

    def _is_exempt(self, scope) -> bool:
        """Return True when this request should skip rate limiting."""
        if scope.get("type") != "http":
            return True
        paths = self.protected_paths
        if paths and scope.get("path", "") not in paths:
            return True
        # CORS preflight requests are never counted.
        return scope.get("method", "").upper() == "OPTIONS"

    async def __call__(self, scope, receive, send) -> None:
        if self._is_exempt(scope):
            await self.app(scope, receive, send)
            return

        allowed, retry_after = await self.limiter.check(_get_client_ip(scope))
        if allowed:
            await self.app(scope, receive, send)
            return

        challenge = _create_captcha_challenge()
        body = {
            "error": "Too many requests",
            "retryAfter": retry_after,
            "captchaChallenge": {
                "challengeToken": challenge.challenge_token,
                "imageBase64": challenge.image_base64,
            },
        }
        rejection = JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=body,
            headers={"Retry-After": str(retry_after)},
        )
        await rejection(scope, receive, send)
def _create_captcha_challenge() -> CaptchaChallengePayload:
    """Build a new CAPTCHA challenge for a rate-limited client.

    The token comes from ``secrets.token_urlsafe(16)`` (16 random bytes,
    URL-safe base64). The image is the fixed placeholder constant.
    """
    token = secrets.token_urlsafe(16)
    return CaptchaChallengePayload(
        image_base64=CAPTCHA_PLACEHOLDER_IMAGE_BASE64,
        challenge_token=token,
    )
def _get_client_ip(scope) -> str:
headers = scope.get("headers") or []
forwarded_for = None
for key, value in headers:
if key.lower() == b"x-forwarded-for":
forwarded_for = value.decode("utf-8", errors="ignore")
break
if forwarded_for:
return forwarded_for.split(",")[0].strip() or "unknown"
client = scope.get("client")
if client and client[0]:
return str(client[0])
return "unknown"