from __future__ import annotations
import re
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Literal, Pattern
from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
__all__ = ("CORSConfig",)
if TYPE_CHECKING:
from litestar.types import Method
@dataclass
class CORSConfig:
    """Configuration for CORS (Cross-Origin Resource Sharing).

    To enable CORS, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` constructor using the
    'cors_config' key.
    """

    allow_origins: list[str] = field(default_factory=lambda: ["*"])
    """List of origins that are allowed.

    '*' is a wildcard matching any sequence of characters and may be used in any component of the origin, e.g.
    'https://*.example.com'. All other characters are matched literally. An entry of exactly '*' allows every origin.
    Sets the 'Access-Control-Allow-Origin' header.
    """
    allow_methods: list[Literal["*"] | Method] = field(default_factory=lambda: ["*"])
    """List of allowed HTTP methods.

    Sets the 'Access-Control-Allow-Methods' header.
    """
    allow_headers: list[str] = field(default_factory=lambda: ["*"])
    """List of allowed headers.

    Sets the 'Access-Control-Allow-Headers' header.
    """
    allow_credentials: bool = field(default=False)
    """Boolean dictating whether or not to set the 'Access-Control-Allow-Credentials' header."""
    allow_origin_regex: str | None = field(default=None)
    """Regular expression to match origins against, in addition to ``allow_origins``.

    It is used verbatim (not escaped) and must match the entire origin.
    """
    expose_headers: list[str] = field(default_factory=list)
    """List of headers that are exposed via the 'Access-Control-Expose-Headers' header."""
    max_age: int = field(default=600)
    """Response caching TTL in seconds, defaults to 600.

    Sets the 'Access-Control-Max-Age' header.
    """

    def __post_init__(self) -> None:
        """Normalize ``allow_headers`` to lowercase, since HTTP header names are case-insensitive."""
        self.allow_headers = [v.lower() for v in self.allow_headers]

    @cached_property
    def allowed_origins_regex(self) -> Pattern[str]:
        """Get or create a compiled regex for allowed origins.

        Each entry of ``allow_origins`` is matched literally except for ``*``, which matches any run of characters.
        ``allow_origin_regex``, when set, is included as-is. The pattern is meant to be used with ``fullmatch``.

        Returns:
            A compiled regex of the allowed origins.
        """
        # Escape every origin so regex metacharacters (notably '.') match literally; otherwise an allowed
        # 'https://api.example.com' would also accept the attacker-registrable 'https://apiXexample.com'.
        # After escaping, '*' appears as '\*' and is turned back into a wildcard.
        patterns = [re.escape(origin).replace(r"\*", ".*") for origin in self.allow_origins]
        # Build a fresh list: appending to ``self.allow_origins`` would silently mutate the user's configuration.
        if self.allow_origin_regex:
            patterns.append(self.allow_origin_regex)
        if not patterns:
            # The empty pattern would fullmatch an empty origin; use a pattern that can never match instead.
            return re.compile(r"(?!)")
        return re.compile("|".join(f"(?:{pattern})" for pattern in patterns))

    @cached_property
    def is_allow_all_origins(self) -> bool:
        """Get a cached boolean flag dictating whether all origins are allowed.

        Returns:
            Boolean dictating whether all origins are allowed.
        """
        return "*" in self.allow_origins

    @cached_property
    def is_allow_all_methods(self) -> bool:
        """Get a cached boolean flag dictating whether all methods are allowed.

        Returns:
            Boolean dictating whether all methods are allowed.
        """
        return "*" in self.allow_methods

    @cached_property
    def is_allow_all_headers(self) -> bool:
        """Get a cached boolean flag dictating whether all headers are allowed.

        Returns:
            Boolean dictating whether all headers are allowed.
        """
        return "*" in self.allow_headers

    @cached_property
    def preflight_headers(self) -> dict[str, str]:
        """Get cached pre-flight headers.

        Returns:
            A dictionary of headers to set on the response object.
        """
        headers: dict[str, str] = {"Access-Control-Max-Age": str(self.max_age)}
        if self.is_allow_all_origins:
            headers["Access-Control-Allow-Origin"] = "*"
        else:
            # The response depends on the request's Origin, so caches must key on it.
            headers["Vary"] = "Origin"
        if self.allow_credentials:
            headers["Access-Control-Allow-Credentials"] = "true"
        if not self.is_allow_all_headers:
            headers["Access-Control-Allow-Headers"] = ", ".join(
                sorted(set(self.allow_headers) | DEFAULT_ALLOWED_CORS_HEADERS)  # pyright: ignore
            )
        if self.allow_methods:
            # '*' expands to the explicit list of standard methods.
            headers["Access-Control-Allow-Methods"] = ", ".join(
                sorted(
                    {"DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"}
                    if self.is_allow_all_methods
                    else set(self.allow_methods)
                )
            )
        return headers

    @cached_property
    def simple_headers(self) -> dict[str, str]:
        """Get cached simple headers.

        Returns:
            A dictionary of headers to set on the response object.
        """
        simple_headers = {}
        if self.is_allow_all_origins:
            simple_headers["Access-Control-Allow-Origin"] = "*"
        if self.allow_credentials:
            simple_headers["Access-Control-Allow-Credentials"] = "true"
        if self.expose_headers:
            simple_headers["Access-Control-Expose-Headers"] = ", ".join(sorted(set(self.expose_headers)))
        return simple_headers

    def is_origin_allowed(self, origin: str) -> bool:
        """Check whether a given origin is allowed.

        Args:
            origin: An origin header value.

        Returns:
            Boolean determining whether an origin is allowed.
        """
        return bool(self.is_allow_all_origins or self.allowed_origins_regex.fullmatch(origin))