"""Project Settings."""
from __future__ import annotations
import base64
import binascii
import os
from pathlib import Path
from typing import Any, Final, Literal
from dotenv import load_dotenv
from litestar.contrib.jinja import JinjaTemplateEngine
from litestar.data_extractors import RequestExtractorField, ResponseExtractorField # noqa: TCH002
from litestar.openapi.spec import Server
from pydantic import ValidationError, field_validator
from pydantic.types import SecretBytes
from pydantic_settings import BaseSettings, SettingsConfigDict
import utils
from __metadata__ import __version__ as version
# Public names exported by this module.
__all__ = (
    "APISettings",
    "DatabaseSettings",
    "GitHubSettings",
    "LogSettings",
    "OpenAPISettings",
    "ProjectSettings",
    "ServerSettings",
    "TemplateSettings",
    "load_settings",
)

# Runs at import time, before any settings class below is defined or
# instantiated, so dotenv-provided values are visible to them.
load_dotenv()

DEFAULT_MODULE_NAME = "src"
# Filesystem location resolved from the default module name.
BASE_DIR: Final = utils.module_to_os_path(DEFAULT_MODULE_NAME)
# Static asset and template directories, both rooted under ``BASE_DIR``.
STATIC_DIR = Path(BASE_DIR / "server" / "domain" / "web" / "resources")
TEMPLATES_DIR = Path(BASE_DIR / "server" / "domain" / "web" / "templates")
[docs]
class ServerSettings(BaseSettings):
    """Settings for the server process, read from ``SERVER_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="SERVER_",
        extra="ignore",
    )

    APP_LOC: str = "src.app:create_app"
    """Import location of the application executable or application factory."""
    APP_LOC_IS_FACTORY: bool = True
    """Whether ``APP_LOC`` points to a factory rather than an executable."""
    HOST: str = "localhost"
    """Network host the server uses."""
    KEEPALIVE: int = 65
    """Number of seconds connections are held open."""
    PORT: int = 8000
    """Port the server uses."""
    RELOAD: bool | None = False
    """Enable hot reloading."""
    RELOAD_DIRS: list[str] = [f"{BASE_DIR}"]
    """Directories watched for changes when reloading.

    .. warning:: This only accepts a single directory for now, something is broken
    """
    HTTP_WORKERS: int | None = None
    """How many HTTP worker processes Uvicorn should spawn."""
[docs]
class ProjectSettings(BaseSettings):
    """Project-wide settings, read from un-prefixed environment variables."""

    model_config = SettingsConfigDict(case_sensitive=True, env_file=".env", extra="ignore")

    BUILD_NUMBER: str = ""
    """Identifier for CI build."""
    CHECK_DB_READY: bool = True
    """Check for database readiness on startup."""
    CHECK_REDIS_READY: bool = True
    """Check for redis readiness on startup."""
    DEBUG: bool = False
    """Run ``Litestar`` with ``debug=True``."""
    ENVIRONMENT: str = "prod"
    """``dev``, ``prod``, ``qa``, etc."""
    TEST_ENVIRONMENT_NAME: str = "test"
    """Value of ENVIRONMENT used to determine if running tests.

    This should be the value of ``ENVIRONMENT`` in ``tests.env``.
    """
    LOCAL_ENVIRONMENT_NAME: str = "local"
    """Value of ENVIRONMENT used to determine if running in local development
    mode.

    This should be the value of ``ENVIRONMENT`` in your local ``.env`` file.
    """
    NAME: str = "Byte Bot"
    """Application name."""
    SECRET_KEY: SecretBytes
    """Secret key used for signing cookies and other things."""
    JWT_ENCRYPTION_ALGORITHM: str = "HS256"
    """Algorithm used to sign JWTs."""
    BACKEND_CORS_ORIGINS: list[str] = ["*"]
    """List of origins allowed to access the API."""
    STATIC_URL: str = "/static/"
    """Default URL where static assets are located."""
    CSRF_COOKIE_NAME: str = "csrftoken"
    """Name of the CSRF cookie."""
    CSRF_COOKIE_SECURE: bool = False
    """Set the CSRF cookie to be secure."""
    STATIC_DIR: Path = STATIC_DIR
    """Path to static assets."""
    DEV_MODE: bool = False
    """Indicate if running in development mode."""

    @property
    def slug(self) -> str:
        """Return a slugified name.

        Returns:
            ``self.NAME``, all lowercase and hyphens instead of spaces.
        """
        return "-".join(s.lower() for s in self.NAME.split())

    @field_validator("BACKEND_CORS_ORIGINS")
    @classmethod
    def assemble_cors_origins(
        cls,
        value: str | list[str] | None,
    ) -> list[str]:
        """Normalize the configured CORS origins into a list of strings.

        Args:
            value: ``None``, a list of origins, a comma-separated string of
                origins, or a JSON-encoded list of origins (``'["a", "b"]'``).

        Returns:
            A list of origins. ``None`` yields an empty list; empty entries of
            a comma-separated string are dropped.

        Raises:
            ValueError: If ``value`` is of an unsupported type, or is a
                bracketed string that is not a JSON list of strings.
        """
        # Function-scope import keeps this block self-contained.
        import json

        if value is None:
            return []
        if isinstance(value, list):
            return value
        if isinstance(value, str):
            text = value.strip()
            if not text.startswith("["):
                # Plain "a, b, c" form; skip blanks such as those from "a,,b".
                return [host.strip() for host in text.split(",") if host.strip()]
            if text.endswith("]"):
                # A bracketed string is a serialized list. ``list(text)`` would
                # split it into single characters, so decode it instead.
                try:
                    origins = json.loads(text)
                except json.JSONDecodeError as exc:
                    msg = f"BACKEND_CORS_ORIGINS is not a valid JSON list: {value!r}"
                    raise ValueError(msg) from exc
                if isinstance(origins, list) and all(isinstance(origin, str) for origin in origins):
                    return origins
        raise ValueError(value)

    @field_validator("SECRET_KEY", mode="before")
    @classmethod
    def generate_secret_key(cls, value: str | bytes | SecretBytes | None) -> SecretBytes:
        """Coerce the secret key to ``SecretBytes``, generating one for ``None``.

        NOTE(review): ``SECRET_KEY`` declares no default, so the generation
        branch presumably only runs when ``None`` is passed explicitly —
        confirm whether an omitted key should also be generated.

        Args:
            value: A secret key as text, raw bytes or ``SecretBytes``, or
                ``None``.

        Returns:
            The secret key wrapped in ``SecretBytes``. For ``None`` this is the
            hex encoding of 32 random bytes.
        """
        if value is None:
            return SecretBytes(binascii.hexlify(os.urandom(32)))
        if isinstance(value, SecretBytes):
            return value
        if isinstance(value, bytes):
            # Bytes have no ``.encode()``; wrap them as they are.
            return SecretBytes(value)
        return SecretBytes(value.encode())
[docs]
class APISettings(BaseSettings):
    """Settings for the API and external APIs, read from ``API_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="API_",
        extra="ignore",
    )

    HEALTH_PATH: str = "/health"
    """Route the health check is served under."""
    OPENCOLLECTIVE_KEY: str | None = None
    """API key for OpenCollective."""
    OPENCOLLECTIVE_URL: str = "https://api.opencollective.com/graphql/v2"
    """URL of the OpenCollective API.

    .. note:: This is the GraphQL endpoint, the REST endpoint is no longer maintained.

    See also: `OpenCollective API Docs <https://graphql-docs-v2.opencollective.com/>`_
    """
    POLAR_KEY: str | None = None
    """API key for Polar."""
    POLAR_URL: str = "https://api.polar.sh"
    """URL of the Polar API.

    .. seealso:: `Polar API Docs <https://api.polar.sh/docs>`_ and
        the `Public API #834 Issue <https://github.com/polarsource/polar/issues/834>`_.
    """
[docs]
class LogSettings(BaseSettings):
    """Logging settings, read from ``LOG_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="LOG_",
        extra="ignore",
    )

    EXCLUDE_PATHS: str = r"\A(?!x)x"
    """Regex of paths to exclude from logging.

    The default pattern can never match; see
    https://stackoverflow.com/a/1845097/6560549
    """
    HTTP_EVENT: str = "HTTP"
    """Event name used for logs emitted from ``litestar`` handlers."""
    INCLUDE_COMPRESSED_BODY: bool = False
    """Whether the ``body`` of compressed responses is included in log output."""
    LEVEL: int = 20
    """Stdlib log level.

    Only logs at this level or higher are emitted.
    """
    OBFUSCATE_COOKIES: set[str] = {"session"}
    """Keys of request cookies to obfuscate."""
    OBFUSCATE_HEADERS: set[str] = {"Authorization", "X-API-KEY"}
    """Keys of request headers to obfuscate."""
    JOB_FIELDS: list[str] = [
        "function",
        "kwargs",
        "key",
        "scheduled",
        "attempts",
        "completed",
        "queued",
        "started",
        "result",
        "error",
    ]
    """Attributes of the SAQ :class:`Job <saq.job.Job>` that get logged."""
    REQUEST_FIELDS: list[RequestExtractorField] = [
        "path",
        "method",
        "headers",
        "cookies",
        "query",
        "path_params",
        "body",
    ]
    """Attributes of the :class:`Request <litestar.connection.request.Request>`
    that get logged."""
    RESPONSE_FIELDS: list[ResponseExtractorField] = [
        "status_code",
        "cookies",
        "headers",
        # "body",  # ! We don't want to log the response body.
    ]
    """Attributes of the :class:`Response <litestar.response.Response>` that
    get logged."""
    UVICORN_ACCESS_LEVEL: int = 30
    """Log level for uvicorn access logs."""
    UVICORN_ERROR_LEVEL: int = 20
    """Log level for uvicorn error logs."""
# noinspection PyUnresolvedReferences
[docs]
class OpenAPISettings(BaseSettings):
    """OpenAPI document settings, read from ``OPENAPI_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="OPENAPI_",
        extra="ignore",
    )

    CONTACT_NAME: str = "Admin"
    """Name of contact on document."""
    CONTACT_EMAIL: str = "hello@byte-bot.app"
    """Email for contact on document."""
    TITLE: str | None = "Byte Bot"
    """Document title."""
    VERSION: str = version
    """Document version."""
    PATH: str = "/api"
    """Path to access the root API documentation."""
    DESCRIPTION: str | None = (
        "The Byte Bot API supports the Byte Discord bot.\n"
        "You can find out more about this project in the\n"
        "[docs](https://docs.byte-bot.app/latest)."
    )
    """Document description."""
    SERVERS: list[dict[str, str]] = []
    """Servers to use for the OpenAPI documentation."""
    EXTERNAL_DOCS: dict[str, str] | None = {
        "description": "Byte Bot API Docs",
        "url": "https://docs.byte-bot.app/latest",
    }
    """External documentation for the API."""

    @field_validator("SERVERS", mode="after")
    @classmethod
    def assemble_openapi_servers(cls, value: list[Server]) -> list[Server]:  # noqa: ARG003
        """Build the OpenAPI server list from the ``ENVIRONMENT`` variable.

        The incoming ``value`` is ignored; the result depends only on
        ``ENVIRONMENT`` (``dev`` when unset).

        NOTE(review): ``SERVERS`` defaults to an empty list and the field does
        not request validation of defaults, so this presumably only runs when
        a value is supplied — confirm that is intended.

        Args:
            value: The value of the SERVERS setting (unused).

        Returns:
            ``[prod]`` for ``prod``, ``[test, prod]`` for ``test``, and
            ``[dev, test, prod]`` for any other environment.
        """
        prod = Server(url="https://byte-bot.app/", description="Production")
        test = Server(url="https://dev.byte-bot.app/", description="Test")
        dev = Server(url="http://0.0.0.0:8000", description="Development")

        by_environment = {
            "prod": [prod],
            "test": [test, prod],
        }
        environment = os.getenv("ENVIRONMENT", "dev")
        # Any environment other than prod/test gets the full development list.
        return by_environment.get(environment, [dev, test, prod])
[docs]
class TemplateSettings(BaseSettings):
    """Templating settings, read from ``TEMPLATE_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="TEMPLATE_",
        extra="ignore",
    )

    ENGINE: type[JinjaTemplateEngine] = JinjaTemplateEngine
    """Template engine class to use. (Jinja2 or Mako)"""
[docs]
class DatabaseSettings(BaseSettings):
    """Database settings, read from ``DB_``-prefixed variables (case-insensitive)."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        env_prefix="DB_",
        case_sensitive=False,
        extra="ignore",
    )

    ECHO: bool = False
    """Turn on SQLAlchemy engine logs."""
    ECHO_POOL: bool | Literal["debug"] = False
    """Turn on SQLAlchemy connection pool logs."""
    POOL_DISABLE: bool = True
    """Disable SQLAlchemy pooling.

    See :class:`NullPool <sqlalchemy.pool.NullPool>`.
    """
    POOL_MAX_OVERFLOW: int = 10
    """See :class:`QueuePool <sqlalchemy.pool.QueuePool>`.

    .. warning:: This is arguably pretty high,
        and shouldn't be raised past 10.
    """
    POOL_SIZE: int = 5
    """See :class:`QueuePool <sqlalchemy.pool.QueuePool>`."""
    POOL_TIMEOUT: int = 30
    """See :class:`QueuePool <sqlalchemy.pool.QueuePool>`."""
    POOL_RECYCLE: int = 300
    """See :class:`QueuePool <sqlalchemy.pool.QueuePool>`."""
    POOL_PRE_PING: bool = False
    """See :class:`QueuePool <sqlalchemy.pool.QueuePool>`."""
    CONNECT_ARGS: dict[str, Any] = {}
    """Arguments passed through to the database driver on connect."""
    URL: str = "postgresql+asyncpg://byte:bot@localhost:5432/byte"
    """URL used to connect to the database."""
    ENGINE: str | None = None
    """Engine of the database."""
    USER: str = "byte"
    """User of the database."""
    PASSWORD: str = "bot"
    """Password of the database user."""
    HOST: str = "localhost"
    """Host of the database."""
    PORT: int = 5432
    """Port of the database."""
    NAME: str = "byte"
    """Name of the database."""
    MIGRATION_CONFIG: str = f"{BASE_DIR}/server/lib/db/alembic.ini"
    """Location of the Alembic config file."""
    MIGRATION_PATH: str = f"{BASE_DIR}/server/lib/db/migrations"
    """Location of the Alembic migration files."""
    MIGRATION_DDL_VERSION_TABLE: str = "ddl_version"
    """Table name used for tracking the DDL version."""
[docs]
class GitHubSettings(BaseSettings):
    """GitHub App settings, read from ``GITHUB_``-prefixed variables."""

    model_config = SettingsConfigDict(
        case_sensitive=True,
        env_file=".env",
        env_prefix="GITHUB_",
        extra="ignore",
    )

    NAME: str = "byte-bot-app"
    """GitHub App name."""
    APP_ID: int = 480575
    """GitHub App ID."""
    APP_PRIVATE_KEY: str = ""
    """GitHub App private key."""
    APP_CLIENT_ID: str = "Iv1.c3a5214c6642dedd"
    """GitHub App client ID."""
    APP_CLIENT_SECRET: str = ""
    """GitHub App client secret."""
    REDIRECT_URL: str = "http://127.0.0.1:3000/github/session"
    """GitHub App redirect URL."""
    PERSONAL_ACCESS_TOKEN: str | None = None
    """GitHub personal access token."""

    @field_validator("APP_PRIVATE_KEY", mode="before")
    @classmethod
    def validate_and_load_private_key(cls, value: str) -> str:
        """Validate and load the GitHub App private key.

        The value is first treated as a base64-encoded key. If decoding fails
        and ``ENVIRONMENT`` is ``dev`` (the default when unset), the value is
        treated as a file path relative to the parent of ``BASE_DIR`` and the
        key is read from that file.

        NOTE: the decoded key's PEM header/footer is not validated.

        Args:
            value: The value of the APP_PRIVATE_KEY setting.

        Returns:
            The decoded key, or the content of the key file in ``dev``.

        Raises:
            ValueError: If the value cannot be decoded outside ``dev``, or if
                the fallback key file does not exist.
        """
        try:
            return base64.b64decode(value).decode("utf-8")
        except (binascii.Error, UnicodeDecodeError) as e:
            # ``b64decode`` discards non-alphabet characters by default, so a
            # file path can decode to arbitrary bytes and only fail at the
            # UTF-8 step; both failures mean "not a base64 key".
            if os.getenv("ENVIRONMENT", "dev") != "dev":
                msg = "The GitHub private key must be a valid base64 encoded string"
                raise ValueError(msg) from e
            key_path = Path(BASE_DIR).parent / value
            if key_path.is_file():
                return key_path.read_text()
            msg = f"Private key file not found at {key_path}"
            raise ValueError(msg) from e
# noinspection PyShadowingNames
[docs]
def load_settings() -> (
    tuple[
        ProjectSettings,
        APISettings,
        OpenAPISettings,
        TemplateSettings,
        ServerSettings,
        LogSettings,
        DatabaseSettings,
        GitHubSettings,
    ]
):
    """Instantiate every settings class.

    Returns:
        A tuple of settings instances in the order ``(project, api, openapi,
        template, server, log, database, github)``.

    Raises:
        ValidationError: If any settings class fails validation. The error is
            printed and then re-raised unchanged.
    """
    try:
        # Pass the bind host and reload directory explicitly instead of
        # relying on the class defaults.
        server: ServerSettings = ServerSettings.model_validate(
            {"HOST": "0.0.0.0", "RELOAD_DIRS": [str(BASE_DIR)]},  # noqa: S104
        )
        project: ProjectSettings = ProjectSettings.model_validate({})
        api: APISettings = APISettings.model_validate({})
        openapi: OpenAPISettings = OpenAPISettings.model_validate({})
        template: TemplateSettings = TemplateSettings.model_validate({})
        log: LogSettings = LogSettings.model_validate({})
        database: DatabaseSettings = DatabaseSettings.model_validate({})
        github: GitHubSettings = GitHubSettings.model_validate({})
    except ValidationError as error:
        print(f"Could not load settings. Error: {error!r}")  # noqa: T201
        # Bare ``raise`` keeps the original traceback; ``raise error from
        # error`` would make the exception its own ``__cause__``.
        raise
    return (
        project,
        api,
        openapi,
        template,
        server,
        log,
        database,
        github,
    )
# Module-level settings instances, created once at import time. Note that the
# database settings are bound to the name ``db``.
project, api, openapi, template, server, log, db, github = load_settings()