Source code for src.server.lib.log.controller

"""Logging config for the application.

Ensures that the app and uvicorn loggers all log through the queue listener.

Adds a filter for health check route logs.
"""
from __future__ import annotations

import logging
import re
from inspect import isawaitable
from typing import TYPE_CHECKING

import structlog
from litestar.constants import SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.data_extractors import ConnectionDataExtractor, ResponseDataExtractor
from litestar.enums import ScopeType
from litestar.status_codes import (
    HTTP_200_OK,
    HTTP_300_MULTIPLE_CHOICES,
    HTTP_500_INTERNAL_SERVER_ERROR,
)
from litestar.utils.scope import get_litestar_scope_state

from server.lib import settings

__all__ = ["BeforeSendHandler", "drop_health_logs", "middleware_factory"]

if TYPE_CHECKING:
    from typing import Any, Literal

    from litestar.connection import Request
    from litestar.types.asgi_types import ASGIApp, Message, Receive, Scope, Send
    from structlog.types import EventDict, WrappedLogger

LOGGER = structlog.get_logger()

HTTP_RESPONSE_START: Literal["http.response.start"] = "http.response.start"
HTTP_RESPONSE_BODY: Literal["http.response.body"] = "http.response.body"
REQUEST_BODY_FIELD: Literal["body"] = "body"


[docs] def drop_health_logs(_: WrappedLogger, __: str, event_dict: EventDict) -> EventDict: """Prevent logging of successful health checks. Args: _: Wrapped logger object. __: Name of the wrapped method, e.g., "info", "warning", etc. event_dict: Current context with current event, e.g, `{"a": 42, "event": "foo"}`. Returns: `event_dict` for further processing if it does not represent a successful health check. """ is_http_log = event_dict["event"] == settings.log.HTTP_EVENT is_health_log = event_dict.get("request", {}).get("path") == settings.api.HEALTH_PATH is_success_status = HTTP_200_OK <= event_dict.get("response", {}).get("status_code", 0) < HTTP_300_MULTIPLE_CHOICES if is_http_log and is_health_log and is_success_status: raise structlog.DropEvent return event_dict
[docs] def middleware_factory(app: ASGIApp) -> ASGIApp: """Middleware to ensure that every request has a clean structlog context. Args: app: The previous ASGI app in the call chain. Returns: A new ASGI app that cleans the structlog contextvars. """ async def middleware(scope: Scope, receive: Receive, send: Send) -> None: """Clean up structlog contextvars. Args: scope: ASGI connection scope. receive: ASGI receive handler. send: ASGI send handler. """ structlog.contextvars.clear_contextvars() await app(scope, receive, send) return middleware
[docs] class BeforeSendHandler: """Extraction of request and response data from connection scope.""" __slots__ = ( "do_log_request", "do_log_response", "exclude_paths", "include_compressed_body", "logger", "request_extractor", "response_extractor", )
[docs] def __init__(self) -> None: """Configure the handler.""" self.exclude_paths = re.compile(settings.log.EXCLUDE_PATHS) self.do_log_request = bool(settings.log.REQUEST_FIELDS) self.do_log_response = bool(settings.log.RESPONSE_FIELDS) self.include_compressed_body = settings.log.INCLUDE_COMPRESSED_BODY self.request_extractor = ConnectionDataExtractor( extract_body="body" in settings.log.REQUEST_FIELDS, extract_client="client" in settings.log.REQUEST_FIELDS, extract_content_type="content_type" in settings.log.REQUEST_FIELDS, extract_cookies="cookies" in settings.log.REQUEST_FIELDS, extract_headers="headers" in settings.log.REQUEST_FIELDS, extract_method="method" in settings.log.REQUEST_FIELDS, extract_path="path" in settings.log.REQUEST_FIELDS, extract_path_params="path_params" in settings.log.REQUEST_FIELDS, extract_query="query" in settings.log.REQUEST_FIELDS, extract_scheme="scheme" in settings.log.REQUEST_FIELDS, obfuscate_cookies=settings.log.OBFUSCATE_COOKIES, obfuscate_headers=settings.log.OBFUSCATE_HEADERS, parse_body=False, parse_query=False, ) self.response_extractor = ResponseDataExtractor( extract_body="body" in settings.log.RESPONSE_FIELDS, extract_headers="headers" in settings.log.RESPONSE_FIELDS, extract_status_code="status_code" in settings.log.RESPONSE_FIELDS, obfuscate_cookies=settings.log.OBFUSCATE_COOKIES, obfuscate_headers=settings.log.OBFUSCATE_HEADERS, )
[docs] async def __call__(self, message: Message, scope: Scope) -> None: """Receives ASGI response messages and scope, and logs per configuration. Args: message: ASGI response event. scope: ASGI connection scope. """ if scope["type"] == ScopeType.HTTP and self.exclude_paths.findall(scope["path"]): return if message["type"] == HTTP_RESPONSE_START: scope["state"]["log_level"] = ( logging.ERROR if message["status"] >= HTTP_500_INTERNAL_SERVER_ERROR else logging.INFO ) scope["state"][HTTP_RESPONSE_START] = message # ignore intermediate content of streaming responses for now. elif message["type"] == HTTP_RESPONSE_BODY and message["more_body"] is False: scope["state"][HTTP_RESPONSE_BODY] = message try: if self.do_log_request: await self.log_request(scope) if self.do_log_response: await self.log_response(scope) await LOGGER.alog(scope["state"]["log_level"], settings.log.HTTP_EVENT) # RuntimeError: Expected ASGI message 'http.response.body', but got 'http.response.start'. except Exception as exc: # noqa: BLE001 # pylint: disable=broad-except # just in-case something in the context causes the error structlog.contextvars.clear_contextvars() await LOGGER.aerror("Error in logging before-send handler!", exc_info=exc)
[docs] async def log_request(self, scope: Scope) -> None: """Handle extracting the request data and logging the message. Args: scope: The ASGI connection scope. Returns: None """ extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope)) structlog.contextvars.bind_contextvars(request=extracted_data)
[docs] async def log_response(self, scope: Scope) -> None: """Handle extracting the response data and logging the message. Args: scope: The ASGI connection scope. Returns: None """ extracted_data = self.extract_response_data(scope=scope) structlog.contextvars.bind_contextvars(response=extracted_data)
[docs] async def extract_request_data(self, request: Request) -> dict[str, Any]: """Create a dictionary of values for the log. Args: request: A [Request][litestar.connection.request.Request] instance. Returns: An OrderedDict. """ data: dict[str, Any] = {} extracted_data = self.request_extractor(connection=request) missing = object() for key in settings.log.REQUEST_FIELDS: value = extracted_data.get(key, missing) if value is missing: # pragma: no cover continue if isawaitable(value): # Prevent Litestar from raising a RuntimeError # when trying to read an empty request body. try: value = await value except RuntimeError: if key != REQUEST_BODY_FIELD: raise # pragma: no cover value = None data[key] = value return data
[docs] def extract_response_data(self, scope: Scope) -> dict[str, Any]: """Extract data from the response. Args: scope: The ASGI connection scope. Returns: An OrderedDict. """ data: dict[str, Any] = {} extracted_data = self.response_extractor( messages=(scope["state"][HTTP_RESPONSE_START], scope["state"][HTTP_RESPONSE_BODY]), ) missing = object() response_body_compressed = get_litestar_scope_state(scope, SCOPE_STATE_RESPONSE_COMPRESSED) for key in settings.log.RESPONSE_FIELDS: value = extracted_data.get(key, missing) if key == "body" and response_body_compressed and not self.include_compressed_body: continue if value is missing: # pragma: no cover continue data[key] = value return data