mirror of
https://github.com/mgeeky/decode-spam-headers.git
synced 2026-02-22 05:23:31 +01:00
143 lines
4.2 KiB
Python
143 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import math
|
|
import time
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from typing import Deque
|
|
|
|
from fastapi import status
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from app.security.captcha import (
|
|
BYPASS_TOKEN_HEADER,
|
|
create_captcha_challenge,
|
|
verify_bypass_token,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class CaptchaChallengePayload:
    """Immutable CAPTCHA challenge data returned to a throttled client.

    Values are copied from the object produced by ``create_captcha_challenge``.
    The dataclass is frozen, so instances cannot be mutated after construction.
    """

    # Challenge token string; serialised to clients under "challengeToken".
    challenge_token: str
    # Challenge image as a base64 string; serialised to clients under "imageBase64".
    image_base64: str
|
|
|
|
|
class SlidingWindowRateLimiter:
    """Per-key sliding-log rate limiter.

    Each key keeps a deque of the monotonic timestamps of its accepted calls.
    A call is accepted while fewer than ``limit`` accepted calls fall inside
    the last ``window_seconds``. Rejected calls are not recorded.

    Keys whose timestamps have all left the window are evicted by a sweep that
    runs at most once per window. This keeps the table from growing without
    bound as new keys (e.g. client addresses) keep arriving.
    """

    def __init__(self, limit: int, window_seconds: int) -> None:
        # Both settings are coerced to int and clamped to at least 1.
        self._limit = max(1, int(limit))
        self._window_seconds = max(1, int(window_seconds))
        # Serialises access to the shared table across concurrent coroutines.
        self._lock = asyncio.Lock()
        self._requests: dict[str, Deque[float]] = {}
        # Monotonic time at which the next stale-key sweep becomes due.
        self._next_sweep = time.monotonic() + self._window_seconds

    async def check(self, key: str) -> tuple[bool, int]:
        """Count an attempt for *key* and report whether it is allowed.

        Args:
            key: identifier to rate limit (the middleware passes a client address).

        Returns:
            ``(True, 0)`` when the attempt is accepted; its timestamp is recorded.
            ``(False, retry_after)`` when the limit is reached. ``retry_after`` is
            the whole number of seconds (at least 1) until the oldest recorded
            attempt leaves the window.
        """
        now = time.monotonic()
        async with self._lock:
            cutoff = now - self._window_seconds

            # Amortised cleanup: an O(keys) pass at most once per window.
            if now >= self._next_sweep:
                self._evict_stale(cutoff)
                self._next_sweep = now + self._window_seconds

            bucket = self._requests.get(key)
            if bucket is None:
                bucket = deque()
                self._requests[key] = bucket

            # Timestamps are appended in increasing order, so expired ones sit
            # at the left end.
            while bucket and bucket[0] <= cutoff:
                bucket.popleft()

            if len(bucket) >= self._limit:
                return False, self._compute_retry_after(bucket, now)

            bucket.append(now)
            return True, 0

    def _evict_stale(self, cutoff: float) -> None:
        """Drop every key with no timestamp newer than *cutoff* (caller holds the lock)."""
        stale_keys = [
            k for k, timestamps in self._requests.items()
            if not timestamps or timestamps[-1] <= cutoff
        ]
        for k in stale_keys:
            del self._requests[k]

    def _compute_retry_after(self, bucket: Deque[float], now: float) -> int:
        """Return whole seconds (at least 1) until the oldest entry in *bucket* expires.

        An empty bucket yields the full window length.
        """
        if not bucket:
            return self._window_seconds
        oldest = bucket[0]
        remaining = self._window_seconds - (now - oldest)
        return max(1, math.ceil(remaining))
|
|
|
|
|
|
class RateLimiterMiddleware:
    """ASGI middleware that throttles HTTP requests per client address.

    The following are forwarded to the wrapped app without being counted:
    non-HTTP scopes, ``OPTIONS`` requests, requests whose path is not in
    ``protected_paths`` (when that set is non-empty), and requests carrying a
    bypass token that ``verify_bypass_token`` accepts for the client address.
    All other requests are counted by ``limiter``. Over-limit requests receive
    a 429 JSON response with a fresh CAPTCHA challenge and a ``Retry-After``
    header.
    """

    def __init__(
        self,
        app,
        *,
        limiter: SlidingWindowRateLimiter,
        protected_paths: set[str] | None = None,
    ) -> None:
        """Wrap *app*.

        Args:
            app: downstream ASGI application.
            limiter: limiter keyed by client address.
            protected_paths: exact request paths to throttle. ``None`` or an
                empty set means every HTTP path is throttled.
        """
        self.app = app
        self.limiter = limiter
        # Normalise None (or any other falsy value) to an empty set.
        self.protected_paths = set() if not protected_paths else protected_paths

    async def __call__(self, scope, receive, send) -> None:
        """ASGI entry point."""
        if self._passes_through(scope):
            await self.app(scope, receive, send)
            return

        client_ip = _get_client_ip(scope)
        presented_token = _get_bypass_token(scope)
        if presented_token and verify_bypass_token(presented_token, client_ip):
            # An accepted bypass token skips the limiter; the request is not counted.
            await self.app(scope, receive, send)
            return

        allowed, retry_after = await self.limiter.check(client_ip)
        if not allowed:
            await self._send_throttled(scope, receive, send, retry_after)
            return
        await self.app(scope, receive, send)

    def _passes_through(self, scope) -> bool:
        """Return True for requests that are never rate limited."""
        if scope.get("type") != "http":
            return True
        requested_path = scope.get("path", "")
        if self.protected_paths and requested_path not in self.protected_paths:
            return True
        # OPTIONS requests (e.g. CORS preflight) are never throttled.
        return scope.get("method", "").upper() == "OPTIONS"

    @staticmethod
    async def _send_throttled(scope, receive, send, retry_after: int) -> None:
        """Reply with 429, a CAPTCHA challenge, and the Retry-After delay in seconds."""
        challenge = _create_captcha_challenge()
        payload = {
            "error": "Too many requests",
            "retryAfter": retry_after,
            "captchaChallenge": {
                "challengeToken": challenge.challenge_token,
                "imageBase64": challenge.image_base64,
            },
        }
        throttled = JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=payload,
            headers={"Retry-After": str(retry_after)},
        )
        await throttled(scope, receive, send)
|
|
|
|
|
|
def _create_captcha_challenge() -> CaptchaChallengePayload:
    """Create a new CAPTCHA challenge and copy it into a ``CaptchaChallengePayload``.

    Only ``challenge_token`` and ``image_base64`` are read from the object
    returned by ``create_captcha_challenge``.
    """
    generated = create_captcha_challenge()
    token = generated.challenge_token
    image = generated.image_base64
    return CaptchaChallengePayload(challenge_token=token, image_base64=image)
|
|
|
|
|
|
def _get_client_ip(scope) -> str:
|
|
headers = scope.get("headers") or []
|
|
forwarded_for = None
|
|
for key, value in headers:
|
|
if key.lower() == b"x-forwarded-for":
|
|
forwarded_for = value.decode("utf-8", errors="ignore")
|
|
break
|
|
if forwarded_for:
|
|
return forwarded_for.split(",")[0].strip() or "unknown"
|
|
|
|
client = scope.get("client")
|
|
if client and client[0]:
|
|
return str(client[0])
|
|
return "unknown"
|
|
|
|
|
|
def _get_bypass_token(scope) -> str | None:
    """Return the value of the CAPTCHA bypass header from an ASGI scope.

    HTTP field names are case-insensitive. Both the incoming header name and
    ``BYPASS_TOKEN_HEADER`` are therefore lower-cased before comparison; the
    previous code lower-cased only the incoming name, so a constant containing
    uppercase letters could never match.

    Returns:
        The decoded value of the first matching header (undecodable UTF-8
        bytes are dropped), or ``None`` if the header is absent.
    """
    wanted_name = BYPASS_TOKEN_HEADER.lower().encode("utf-8")
    for header_name, raw_value in scope.get("headers") or []:
        if header_name.lower() == wanted_name:
            return raw_value.decode("utf-8", errors="ignore")
    return None
|