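"""Streaming analysis endpoint.

Accepts raw e-mail headers as JSON, runs the registered scanners in a worker
thread, and streams progress events plus the final report to the client as
Server-Sent Events. Input is capped at MAX_HEADERS_LENGTH (1 MiB) and
sanitized before analysis; a configurable deadline turns an overrun into a
partial, timed-out result rather than an error.
"""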
from __future__ import annotations

import json
from time import perf_counter
from typing import Any

import anyio
from anyio import BrokenResourceError, ClosedResourceError
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError

from app.core.config import get_settings
from app.engine.analyzer import HeaderAnalyzer
from app.engine.models import AnalysisResult, ReportMetadata
from app.engine.scanner_registry import ScannerRegistry
from app.schemas.analysis import AnalysisProgress, AnalysisRequest

MAX_HEADERS_LENGTH = 1_048_576

router = APIRouter(prefix="/api", tags=["analysis"])

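
# Strip NUL bytes, then drop every character that is neither header
# whitespace (newline, carriage return, tab) nor printable.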
def _sanitize_headers(value: str) -> str:
    if not value:
        return value
    cleaned = value.replace("\x00", "")
    cleaned = "".join(
        ch for ch in cleaned if ch in ("\n", "\r", "\t") or ch.isprintable()
    )
    return cleaned

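
# Serialize one Server-Sent Events frame: an "event:" line naming the event
# type, a "data:" line carrying the JSON payload, and a blank-line terminator.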
def _format_sse(event: str, data: object) -> str:
    payload = json.dumps(data)
    return f"event: {event}\ndata: {payload}\n\n"

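
# Build the report returned when analysis exceeds its deadline: no per-test
# results, timed_out=True, and the names of every test that had not finished.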
def _build_timeout_result(
    expected_tests: list[str],
    last_index: int | None,
    elapsed_ms: float,
) -> AnalysisResult:
    total_tests = len(expected_tests)
    if last_index is None:
        incomplete = list(expected_tests)
    else:
        incomplete = expected_tests[last_index:]

    metadata = ReportMetadata(
        total_tests=total_tests,
        passed_tests=0,
        failed_tests=0,
        skipped_tests=0,
        elapsed_ms=elapsed_ms,
        timed_out=True,
        incomplete_tests=incomplete,
    )
    return AnalysisResult(results=[], metadata=metadata)

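
# POST /api/analyse takes {"headers": "...", "config": {...}} and streams
# progress events followed by the final report as Server-Sent Events.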
@router.post("/analyse")
async def analyse(request: Request) -> StreamingResponse:
    try:
        payload = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Invalid request payload")

    headers = payload.get("headers")
    if not isinstance(headers, str):
        raise HTTPException(status_code=400, detail="Headers must be a string")

    if not headers.strip():
        return JSONResponse(
            status_code=400, content={"error": "Headers cannot be empty"}
        )

    if len(headers) > MAX_HEADERS_LENGTH:
        return JSONResponse(status_code=413, content={"error": "Headers exceed 1 MB"})

    headers = _sanitize_headers(headers)
    if not headers.strip():
        return JSONResponse(
            status_code=400, content={"error": "Headers cannot be empty"}
        )
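
    # Validate the sanitized input against the request schema so the engine
    # only ever sees a well-formed AnalysisRequest.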
    config = payload.get("config") or {}
    try:
        analysis_request = AnalysisRequest.model_validate(
            {"headers": headers, "config": config}
        )
    except ValidationError as exc:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid analysis request", "detail": exc.errors()},
        )

    settings = get_settings()
    analyzer = HeaderAnalyzer()
    registry = ScannerRegistry()
    if analysis_request.config.test_ids:
        scanners = registry.get_by_ids(analysis_request.config.test_ids)
    else:
        scanners = registry.get_all()
    expected_tests = [scanner.name for scanner in scanners]
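
    # A zero-capacity memory stream acts as a rendezvous point between the
    # analysis worker thread (producer) and the SSE generator (consumer).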
    start = perf_counter()
    progress_state: dict[str, Any] = {"last_index": None}
    send_stream, receive_stream = anyio.create_memory_object_stream[
        tuple[str, dict[str, Any]]
    ](max_buffer_size=0)
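
    # Called by the engine from the worker thread as tests run; events are
    # marshalled back onto the event loop via anyio.from_thread.run. A broken
    # or closed stream just means the client disconnected, so it is ignored.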
    def on_progress(current_index: int, total_tests: int, test_name: str) -> None:
        progress_state["last_index"] = current_index
        elapsed_ms = (perf_counter() - start) * 1000
        percentage = 0.0
        if total_tests > 0:
            percentage = min(100.0, (current_index + 1) / total_tests * 100.0)
        payload = AnalysisProgress(
            current_index=current_index,
            total_tests=total_tests,
            current_test=test_name,
            elapsed_ms=elapsed_ms,
            percentage=percentage,
        ).model_dump(by_alias=True)
        try:
            anyio.from_thread.run(send_stream.send, ("progress", payload))
        except (BrokenResourceError, ClosedResourceError, RuntimeError):
            pass
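
    # Run the blocking analyzer in a thread under the configured deadline;
    # on timeout, emit a partial result rather than failing the stream.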
    async def run_analysis() -> None:
        result: AnalysisResult | None = None
        timed_out = False
        async with send_stream:
            try:
                with anyio.fail_after(settings.analysis_timeout_seconds):
                    result = await anyio.to_thread.run_sync(
                        analyzer.analyze, analysis_request, on_progress
                    )
            except TimeoutError:
                timed_out = True

            elapsed_ms = (perf_counter() - start) * 1000
            if timed_out:
                final_result = _build_timeout_result(
                    expected_tests,
                    progress_state["last_index"],
                    elapsed_ms,
                )
            else:
                final_result = result or AnalysisResult(
                    results=[],
                    metadata=ReportMetadata(
                        total_tests=len(expected_tests),
                        passed_tests=0,
                        failed_tests=0,
                        skipped_tests=0,
                        elapsed_ms=elapsed_ms,
                        timed_out=False,
                        incomplete_tests=[],
                    ),
                )

            await send_stream.send(("result", final_result.model_dump(by_alias=True)))
            await send_stream.send(("done", {}))
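
    # Translate channel events into SSE frames; the "done" sentinel ends the
    # stream, after which the producer task is cancelled.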
    async def event_stream() -> Any:
        async with anyio.create_task_group() as task_group:
            task_group.start_soon(run_analysis)
            async with receive_stream:
                async for event_type, data in receive_stream:
                    if event_type == "done":
                        break
                    yield _format_sse(event_type, data)
            task_group.cancel_scope.cancel()
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )
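
# A quick way to exercise this endpoint by hand (a sketch: assumes the app is
# mounted at the root of a server on localhost:8000; curl's -N flag disables
# output buffering so SSE frames print as they arrive):
#
#   curl -N -X POST http://localhost:8000/api/analyse \
#        -H "Content-Type: application/json" \
#        -d '{"headers": "Received: from mx.example.com", "config": {}}'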