MAESTRO: add analysis SSE router
backend/app/routers/analysis.py — 190 lines (new file)
@@ -0,0 +1,190 @@
from __future__ import annotations

import json
from time import perf_counter
from typing import Any

import anyio
from anyio import BrokenResourceError, ClosedResourceError
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError

from app.core.config import get_settings
from app.engine.analyzer import HeaderAnalyzer
from app.engine.models import AnalysisResult, ReportMetadata
from app.engine.scanner_registry import ScannerRegistry
from app.schemas.analysis import AnalysisProgress, AnalysisRequest

# Reject header blobs larger than 1 MB.
MAX_HEADERS_LENGTH = 1_048_576

router = APIRouter(prefix="/api", tags=["analysis"])


def _sanitize_headers(value: str) -> str:
    # Drop NUL bytes and non-printable characters while preserving the
    # newlines and tabs that carry the header structure.
    if not value:
        return value
    cleaned = value.replace("\x00", "")
    cleaned = "".join(
        ch for ch in cleaned if ch in ("\n", "\r", "\t") or ch.isprintable()
    )
    return cleaned


def _format_sse(event: str, data: object) -> str:
    # Frame one server-sent event: an "event:" line, a JSON "data:" line,
    # and the blank line that terminates the event.
    payload = json.dumps(data)
    return f"event: {event}\ndata: {payload}\n\n"


def _build_timeout_result(
    expected_tests: list[str],
    last_index: int | None,
    elapsed_ms: float,
) -> AnalysisResult:
    # Every test from the last reported index onward is marked incomplete;
    # if no progress was ever reported, all expected tests are.
    total_tests = len(expected_tests)
    if last_index is None:
        incomplete = list(expected_tests)
    else:
        incomplete = expected_tests[last_index:]

    metadata = ReportMetadata(
        total_tests=total_tests,
        passed_tests=0,
        failed_tests=0,
        skipped_tests=0,
        elapsed_ms=elapsed_ms,
        timed_out=True,
        incomplete_tests=incomplete,
    )
    return AnalysisResult(results=[], metadata=metadata)


@router.post("/analyse")
async def analyse(request: Request) -> StreamingResponse:
    try:
        payload = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Invalid request payload")

    headers = payload.get("headers")
    if not isinstance(headers, str):
        raise HTTPException(status_code=400, detail="Headers must be a string")

    if not headers.strip():
        return JSONResponse(
            status_code=400, content={"error": "Headers cannot be empty"}
        )

    if len(headers) > MAX_HEADERS_LENGTH:
        return JSONResponse(
            status_code=413, content={"error": "Headers exceed 1 MB"}
        )

    # Sanitizing may strip everything, so re-check for an empty payload.
    headers = _sanitize_headers(headers)
    if not headers.strip():
        return JSONResponse(
            status_code=400, content={"error": "Headers cannot be empty"}
        )

    config = payload.get("config") or {}
    try:
        analysis_request = AnalysisRequest.model_validate(
            {"headers": headers, "config": config}
        )
    except ValidationError as exc:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid analysis request", "detail": exc.errors()},
        )

    settings = get_settings()
    analyzer = HeaderAnalyzer()
    registry = ScannerRegistry()
    if analysis_request.config.test_ids:
        scanners = registry.get_by_ids(analysis_request.config.test_ids)
    else:
        scanners = registry.get_all()
    expected_tests = [scanner.name for scanner in scanners]

    start = perf_counter()
    progress_state: dict[str, Any] = {"last_index": None}
    # Unbuffered stream: the producer blocks until the SSE consumer has
    # taken each event, so progress updates cannot outrun the client.
    send_stream, receive_stream = anyio.create_memory_object_stream[
        tuple[str, dict[str, Any]]
    ](max_buffer_size=0)

    def on_progress(current_index: int, total_tests: int, test_name: str) -> None:
        # Runs on the analyzer's worker thread; bridge each update back
        # into the event loop with anyio.from_thread.run.
        progress_state["last_index"] = current_index
        elapsed_ms = (perf_counter() - start) * 1000
        percentage = 0.0
        if total_tests > 0:
            percentage = min(100.0, (current_index + 1) / total_tests * 100.0)
        payload = AnalysisProgress(
            current_index=current_index,
            total_tests=total_tests,
            current_test=test_name,
            elapsed_ms=elapsed_ms,
            percentage=percentage,
        ).model_dump(by_alias=True)
        try:
            anyio.from_thread.run(send_stream.send, ("progress", payload))
        except (BrokenResourceError, ClosedResourceError, RuntimeError):
            # The client disconnected or the stream is gone; drop the update.
            pass

    async def run_analysis() -> None:
        result: AnalysisResult | None = None
        timed_out = False
        async with send_stream:
            try:
                with anyio.fail_after(settings.analysis_timeout_seconds):
                    result = await anyio.to_thread.run_sync(
                        analyzer.analyze, analysis_request, on_progress
                    )
            except TimeoutError:
                timed_out = True

            elapsed_ms = (perf_counter() - start) * 1000
            if timed_out:
                final_result = _build_timeout_result(
                    expected_tests,
                    progress_state["last_index"],
                    elapsed_ms,
                )
            else:
                final_result = result or AnalysisResult(
                    results=[],
                    metadata=ReportMetadata(
                        total_tests=len(expected_tests),
                        passed_tests=0,
                        failed_tests=0,
                        skipped_tests=0,
                        elapsed_ms=elapsed_ms,
                        timed_out=False,
                        incomplete_tests=[],
                    ),
                )

            await send_stream.send(
                ("result", final_result.model_dump(by_alias=True))
            )
            await send_stream.send(("done", {}))

    async def event_stream() -> Any:
        # Consumer side: run the analysis in a background task and relay
        # its events as SSE frames until the sentinel "done" arrives.
        async with anyio.create_task_group() as task_group:
            task_group.start_soon(run_analysis)
            async with receive_stream:
                async for event_type, data in receive_stream:
                    if event_type == "done":
                        break
                    yield _format_sse(event_type, data)
            task_group.cancel_scope.cancel()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )
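
For reference, a minimal sketch of a client consuming this endpoint. The endpoint path, request body shape, and event names ("progress", "result", "done") come from the router above; the base URL and sample headers are placeholders, and httpx is an assumed client dependency, not something this commit introduces.

import json

import httpx  # assumed dependency for this sketch

RAW_HEADERS = "Received: from mail.example.com ...\nSubject: test\n"  # placeholder input


def consume_analysis(base_url: str = "http://localhost:8000") -> None:
    # POST the headers and read the SSE stream line by line. Per _format_sse
    # above, each event is an "event:" line followed by a "data:" line.
    with httpx.stream(
        "POST",
        f"{base_url}/api/analyse",
        json={"headers": RAW_HEADERS, "config": {}},
        timeout=None,  # the server enforces its own analysis timeout
    ) as response:
        event = None
        for line in response.iter_lines():
            if line.startswith("event: "):
                event = line[len("event: "):]
            elif line.startswith("data: "):
                data = json.loads(line[len("data: "):])
                if event == "progress":
                    print("progress:", data)
                elif event == "result":
                    print("final result:", data)


if __name__ == "__main__":
    consume_analysis()

Note the design choice behind max_buffer_size=0: each progress event is handed off synchronously from the analyzer's worker thread to the HTTP response, so a slow consumer throttles progress reporting instead of letting a backlog of events accumulate in memory.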