From 0d176b44f176a56cfc88b3735d3712968fc14bcd Mon Sep 17 00:00:00 2001 From: Naksen Date: Tue, 27 Jan 2026 18:45:41 +0300 Subject: [PATCH 01/23] add: rw database routing --- app/config.py | 37 +++++++++++-------- app/database.py | 85 +++++++++++++++++++++++++++++++++++++++++++ app/ioc.py | 22 ++++++----- app/multidirectory.py | 8 ++-- 4 files changed, 124 insertions(+), 28 deletions(-) create mode 100644 app/database.py diff --git a/app/config.py b/app/config.py index 423eb2bf8..1e406f993 100644 --- a/app/config.py +++ b/app/config.py @@ -22,7 +22,6 @@ computed_field, field_validator, ) -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine def _get_vendor_version() -> str: @@ -49,12 +48,20 @@ class Settings(BaseModel): TCP_PACKET_SIZE: int = 1024 COROUTINES_NUM_PER_CLIENT: int = 3 + POSTGRES_RW_MODE: Literal["single", "master_replica"] = "single" POSTGRES_SCHEMA: ClassVar[str] = "postgresql+psycopg" - POSTGRES_DB: str = "postgres" + POSTGRES_REPLICA_DB: str = "" + POSTGRES_REPLICA_HOST: str = "" + POSTGRES_REPLICA_USER: str = "" + POSTGRES_REPLICA_PASSWORD: str = "" + POSTGRES_REPLICA_CONNECT_TIMEOUT: int = 4 + + POSTGRES_DB: str = "postgres" POSTGRES_HOST: str = "postgres" POSTGRES_USER: str POSTGRES_PASSWORD: str + POSTGRES_CONNECT_TIMEOUT: int = 4 SESSION_STORAGE_URL: RedisDsn = RedisDsn("redis://dragonfly:6379/1") SESSION_KEY_LENGTH: int = 16 @@ -99,6 +106,18 @@ def POSTGRES_URI(self) -> PostgresDsn: # noqa f"{self.POSTGRES_DB}", ) + @computed_field # type: ignore + @cached_property + def REPLICA_POSTGRES_URI(self) -> PostgresDsn: # noqa + """Build replica postgres DSN.""" + return PostgresDsn( + f"{self.POSTGRES_SCHEMA}://" + f"{self.POSTGRES_REPLICA_USER}:" + f"{self.POSTGRES_REPLICA_PASSWORD}@" + f"{self.POSTGRES_REPLICA_HOST}/" + f"{self.POSTGRES_REPLICA_DB}", + ) + VENDOR_NAME: ClassVar[str] = "MultiFactor" VENDOR_VERSION: str = Field( default_factory=_get_vendor_version, @@ -220,20 +239,6 @@ def check_certs_exist(self) -> bool: """Check if certs exist.""" return os.path.exists(self.SSL_CERT) and os.path.exists(self.SSL_KEY) - @cached_property - def engine(self) -> AsyncEngine: - """Get engine.""" - return create_async_engine( - str(self.POSTGRES_URI), - pool_size=self.INSTANCE_DB_POOL_SIZE, - max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, - pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, - pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, - pool_pre_ping=False, - future=True, - echo=False, - ) - @classmethod def from_os(cls) -> "Settings": """Get cls from environ.""" diff --git a/app/database.py b/app/database.py new file mode 100644 index 000000000..4f23a243f --- /dev/null +++ b/app/database.py @@ -0,0 +1,85 @@ +"""Database configuration and routing session. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Any, Sequence + +from loguru import logger +from sqlalchemy import Delete, Insert, Update, exc as sa_exc +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import Session + +from config import Settings + +settings = Settings.from_os() + +engines = { + "master": create_async_engine( + str(settings.POSTGRES_URI), + pool_size=settings.INSTANCE_DB_POOL_SIZE, + max_overflow=settings.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=settings.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=settings.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="master", + connect_args={"connect_timeout": settings.POSTGRES_CONNECT_TIMEOUT}, + ), +} +if settings.POSTGRES_RW_MODE == "master_replica": + engines["replica"] = create_async_engine( + str(settings.REPLICA_POSTGRES_URI), + pool_size=settings.INSTANCE_DB_POOL_SIZE, + max_overflow=settings.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=settings.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=settings.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="replica", + connect_args={ + "connect_timeout": settings.POSTGRES_REPLICA_CONNECT_TIMEOUT, + }, + ) + + +class RoutingSession(Session): + _force_master: bool = False + + @property + def force_master(self) -> bool: + return self._force_master + + def set_force_master(self, value: bool) -> None: + self._force_master = value + + def get_bind(self, mapper=None, clause=None) -> Engine: # type: ignore # noqa: ARG002 + logger.critical("-- CALL RoutingSession.get_bind --") + + if isinstance(clause, Update | Insert | Delete): + logger.critical("MASTER") + return engines["master"].sync_engine + + if self._force_master or self._flushing: + logger.critical("MASTER") + return engines["master"].sync_engine + else: + logger.critical("REPLICA") + return engines["replica"].sync_engine + + def flush(self, objects: Sequence[Any] | None = None) -> None: + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + self._force_master = True diff --git a/app/ioc.py b/app/ioc.py index d6489f842..2a14c7112 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -8,12 +8,12 @@ import httpx import redis.asyncio as redis +from database import RoutingSession, engines from dishka import Provider, Scope, from_context, provide from fastapi import Request from loguru import logger from sqlalchemy.ext.asyncio import ( AsyncConnection, - AsyncEngine, AsyncSession, async_sessionmaker, ) @@ -162,18 +162,22 @@ class MainProvider(Provider): scope = Scope.APP settings = from_context(provides=Settings, scope=Scope.APP) - @provide(scope=Scope.APP) - def get_engine(self, settings: Settings) -> AsyncEngine: - """Get async engine.""" - return settings.engine - @provide(scope=Scope.APP) def get_session_factory( self, - engine: AsyncEngine, + settings: Settings, ) -> async_sessionmaker[AsyncSession]: """Create session factory.""" - return async_sessionmaker(engine, expire_on_commit=False) + if settings.POSTGRES_RW_MODE == "single": + return async_sessionmaker( + bind=engines["master"], + expire_on_commit=False, + ) + + return async_sessionmaker( + sync_session_class=RoutingSession, + expire_on_commit=False, + ) @provide(scope=Scope.REQUEST) async def create_session( @@ -895,8 +899,8 @@ def get_session_factory( @provide(scope=Scope.APP) async def get_conn_factory( self, - engine: AsyncEngine, ) -> AsyncIterator[AsyncConnection]: """Create session factory.""" + engine = engines["master"] async with engine.connect() as connection: yield connection diff --git a/app/multidirectory.py b/app/multidirectory.py index 22a19259d..e6c048c34 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -13,6 +13,7 @@ import uvicorn import uvloop from alembic.config import Config, command +from database import engines from dishka import Scope, make_async_container from dishka.integrations.fastapi import setup_dishka from fastapi import FastAPI @@ -118,7 +119,7 @@ def _create_shadow_app(settings: Settings) -> FastAPI: return app -def _add_app_sqlalchemy_debugger(app: FastAPI, settings: Settings) -> None: +def _add_app_sqlalchemy_debugger(app: FastAPI) -> None: try: import json from dataclasses import asdict @@ -138,7 +139,7 @@ def handle(self, statistics: AlchemyStatistics) -> None: app.add_middleware( SQLAlchemyMonitor, - engine=settings.engine, + engine=engines["master"], actions=[JsonPrintStatistics()], ) @@ -149,6 +150,7 @@ def create_prod_app( ) -> FastAPI: """Create production app with container.""" settings = settings or Settings.from_os() + app = factory(settings) container = make_async_container( MainProvider(), @@ -159,7 +161,7 @@ def create_prod_app( ) if settings.ENABLE_SQLALCHEMY_LOGGING: - _add_app_sqlalchemy_debugger(app, settings) + _add_app_sqlalchemy_debugger(app) setup_dishka(container, app) return app From 446d4ed04bdf2833aecb8e4009a2c097545d9f65 Mon Sep 17 00:00:00 2001 From: Naksen Date: Tue, 27 Jan 2026 18:46:49 +0300 Subject: [PATCH 02/23] add: define RESPONSE_TYPE for various LDAP request classes --- app/ldap_protocol/ldap_requests/abandon.py | 1 + app/ldap_protocol/ldap_requests/add.py | 1 + app/ldap_protocol/ldap_requests/base.py | 26 +++++++++++++++++--- app/ldap_protocol/ldap_requests/bind.py | 7 +++++- app/ldap_protocol/ldap_requests/delete.py | 1 + app/ldap_protocol/ldap_requests/extended.py | 1 + app/ldap_protocol/ldap_requests/modify.py | 1 + app/ldap_protocol/ldap_requests/modify_dn.py | 1 + app/ldap_protocol/ldap_requests/search.py | 1 + 9 files changed, 35 insertions(+), 5 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index b9569ca0e..4f975bc0e 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -17,6 +17,7 @@ class AbandonRequest(BaseRequest): """Abandon protocol.""" + RESPONSE_TYPE: ClassVar[type] = type(None) CONTEXT_TYPE: ClassVar[type] = LDAPAbandonRequestContext PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ABANDON message_id: int diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 3747f7b64..f3169ea13 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -64,6 +64,7 @@ class AddRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = AddResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 445ce3bae..b6e6ddde4 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -18,11 +18,13 @@ from dishka import AsyncContainer from loguru import logger from pydantic import BaseModel +from sqlalchemy.exc import OperationalError from config import Settings from entities import Directory from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession +from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_responses import BaseResponse, LDAPResult from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase @@ -63,6 +65,7 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + RESPONSE_TYPE: ClassVar[type] CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] @@ -118,9 +121,16 @@ async def handle_tcp( ctx = await container.get(self.CONTEXT_TYPE) # type: ignore responses = [] - async for response in self.handle(ctx=ctx): - responses.append(response) - yield response + try: + async for response in self.handle(ctx=ctx): + responses.append(response) + yield response + except OperationalError: + yield self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ) + return if self.PROTOCOL_OP != ProtocolRequests.SEARCH: ldap_session = await container.get(LDAPSession) @@ -172,7 +182,15 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(ctx=ctx)] + try: + responses = [response async for response in self.handle(ctx=ctx)] + except OperationalError: + responses = [ + self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ), + ] if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index ad764649e..7c1a1f150 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from pydantic import Field +from sqlalchemy.exc import OperationalError from entities import NetworkPolicy from enums import MFAFlags @@ -42,6 +43,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" + RESPONSE_TYPE: ClassVar[type] = BindResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext @@ -215,7 +217,10 @@ async def handle( ) await ctx.ldap_session.set_user(user) - await set_user_logon_attrs(user, ctx.session, ctx.settings.TIMEZONE) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, ctx.session, ctx.settings.TIMEZONE, + ) server_sasl_creds = None if isinstance(self.authentication_choice, SaslSPNEGOAuthentication): diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index 401a5b98f..70062918b 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -42,6 +42,7 @@ class DeleteRequest(BaseRequest): DelRequest ::= [APPLICATION 10] LDAPDN """ + RESPONSE_TYPE: ClassVar[type] = DeleteResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index 1f8cca946..a3e74ad28 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -307,6 +307,7 @@ class ExtendedRequest(BaseRequest): requestValue [1] OCTET STRING OPTIONAL } """ + RESPONSE_TYPE: ClassVar[type] = ExtendedResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 2abdc1d39..9f5b8789d 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -102,6 +102,7 @@ class ModifyRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = ModifyResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index cdf03ab7b..dc4421d49 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -67,6 +67,7 @@ class ModifyDNRequest(BaseRequest): >>> cn = main2, cn = Users, dc = multifactor, dc = dev """ + RESPONSE_TYPE: ClassVar[type] = ModifyDNResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index c6505322a..d5ff4e679 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -104,6 +104,7 @@ class SearchRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = SearchResultDone PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext From b277eff2828af5e6414e41d82255799a0e24afdd Mon Sep 17 00:00:00 2001 From: Naksen Date: Tue, 27 Jan 2026 18:47:00 +0300 Subject: [PATCH 03/23] add: handle OperationalError when setting user logon attributes --- app/ldap_protocol/session_storage/repository.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/app/ldap_protocol/session_storage/repository.py b/app/ldap_protocol/session_storage/repository.py index 84366faee..fb9e860c5 100644 --- a/app/ldap_protocol/session_storage/repository.py +++ b/app/ldap_protocol/session_storage/repository.py @@ -1,9 +1,11 @@ """Enterprise Session Repository.""" +import contextlib from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address from typing import ClassVar, Literal +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from abstract_service import AbstractService @@ -87,8 +89,11 @@ async def create_session_key( }, ttl=ttl, ) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, self.session, self.settings.TIMEZONE, + ) - await set_user_logon_attrs(user, self.session, self.settings.TIMEZONE) return key async def get_user_sessions( From 6aaa89fc76bfa504b2002a47f0d977e1b6533432 Mon Sep 17 00:00:00 2001 From: Naksen Date: Tue, 27 Jan 2026 18:47:23 +0300 Subject: [PATCH 04/23] add: implement master database check utility and apply it across various routers --- app/api/audit/router.py | 26 ++++++++-- app/api/auth/router_auth.py | 4 +- app/api/auth/router_mfa.py | 8 +-- app/api/ldap_schema/attribute_type_router.py | 6 ++- app/api/ldap_schema/entity_type_router.py | 11 ++++- app/api/ldap_schema/object_class_router.py | 11 ++++- app/api/main/dns_router.py | 7 ++- app/api/main/krb5_router.py | 18 ++++--- app/api/main/router.py | 49 +++++++++++++++---- app/api/network/router.py | 21 ++++++-- .../password_ban_word_router.py | 2 + .../password_policy/password_policy_router.py | 13 ++++- .../user_password_history_router.py | 3 +- app/api/shadow/router.py | 9 +++- app/api/utils.py | 35 +++++++++++++ 15 files changed, 186 insertions(+), 37 deletions(-) create mode 100644 app/api/utils.py diff --git a/app/api/audit/router.py b/app/api/audit/router.py index 4a328e2ef..bf4ef1912 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -15,6 +15,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.policies.audit.exception import ( AuditAlreadyExistsError, @@ -59,7 +60,11 @@ async def get_audit_policies( return await audit_adapter.get_policies() -@audit_router.put("/policy/{policy_id}", error_map=error_map) +@audit_router.put( + "/policy/{policy_id}", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def update_audit_policy( policy_id: int, policy_data: AuditPolicySchemaRequest, @@ -69,7 +74,11 @@ async def update_audit_policy( return await audit_adapter.update_policy(policy_id, policy_data) -@audit_router.get("/destinations", error_map=error_map) +@audit_router.get( + "/destinations", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def get_audit_destinations( audit_adapter: FromDishka[AuditPoliciesAdapter], ) -> list[AuditDestinationResponse]: @@ -81,6 +90,7 @@ async def get_audit_destinations( "/destination", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def create_audit_destination( destination_data: AuditDestinationSchemaRequest, @@ -90,7 +100,11 @@ async def create_audit_destination( return await audit_adapter.create_destination(destination_data) -@audit_router.delete("/destination/{destination_id}", error_map=error_map) +@audit_router.delete( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def delete_audit_destination( destination_id: int, audit_adapter: FromDishka[AuditPoliciesAdapter], @@ -99,7 +113,11 @@ async def delete_audit_destination( await audit_adapter.delete_destination(destination_id) -@audit_router.put("/destination/{destination_id}", error_map=error_map) +@audit_router.put( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def update_audit_destination( destination_id: int, destination_data: AuditDestinationSchemaRequest, diff --git a/app/api/auth/router_auth.py b/app/api/auth/router_auth.py index 56484a88f..75dbb019e 100644 --- a/app/api/auth/router_auth.py +++ b/app/api/auth/router_auth.py @@ -19,6 +19,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( MFAAPIError, @@ -186,7 +187,7 @@ async def logout( @auth_router.patch( "/user/password", status_code=200, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def password_reset( @@ -229,6 +230,7 @@ async def check_setup( status_code=status.HTTP_200_OK, responses={423: {"detail": "Locked"}}, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def first_setup( request: SetupRequest, diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index d7a90b3b2..c0915de21 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -24,6 +24,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( ForbiddenError, @@ -81,7 +82,7 @@ @mfa_router.post( "/setup", status_code=status.HTTP_201_CREATED, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def setup_mfa( @@ -100,7 +101,7 @@ async def setup_mfa( @mfa_router.delete( "/keys", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def remove_mfa( @@ -113,7 +114,7 @@ async def remove_mfa( @mfa_router.post( "/get", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def get_mfa( @@ -134,6 +135,7 @@ async def get_mfa( name="callback_mfa", include_in_schema=True, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def callback_mfa( access_token: Annotated[ diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index 5a2f1f368..503f3654e 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map, ldap_schema_router from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter @@ -16,6 +16,7 @@ AttributeTypeSchema, AttributeTypeUpdateSchema, ) +from api.utils import check_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -23,6 +24,7 @@ "/attribute_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def create_one_attribute_type( request_data: AttributeTypeSchema[None], @@ -59,6 +61,7 @@ async def get_list_attribute_types_with_pagination( @ldap_schema_router.patch( "/attribute_type/{attribute_type_name}", error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def modify_one_attribute_type( attribute_type_name: str, @@ -72,6 +75,7 @@ async def modify_one_attribute_type( @ldap_schema_router.post( "/attribute_types/delete", error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def delete_bulk_attribute_types( attribute_types_names: LimitedListType, diff --git a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index 31de91616..c4bf1d85a 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter @@ -17,6 +17,7 @@ EntityTypeSchema, EntityTypeUpdateSchema, ) +from api.utils import check_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/entity_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def create_one_entity_type( request_data: EntityTypeSchema[None], @@ -66,6 +68,7 @@ async def get_entity_type_attributes( @ldap_schema_router.patch( "/entity_type/{entity_type_name}", error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def modify_one_entity_type( entity_type_name: str, @@ -76,7 +79,11 @@ async def modify_one_entity_type( await adapter.update(name=entity_type_name, data=request_data) -@ldap_schema_router.post("/entity_type/delete", error_map=error_map) +@ldap_schema_router.post( + "/entity_type/delete", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def delete_bulk_entity_types( entity_type_names: LimitedListType, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index a351f3b33..c4bc8d44a 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter @@ -17,6 +17,7 @@ ObjectClassSchema, ObjectClassUpdateSchema, ) +from api.utils import check_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/object_class", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def create_one_object_class( request_data: ObjectClassSchema[None], @@ -57,6 +59,7 @@ async def get_list_object_classes_with_pagination( @ldap_schema_router.patch( "/object_class/{object_class_name}", error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def modify_one_object_class( object_class_name: str, @@ -67,7 +70,11 @@ async def modify_one_object_class( await adapter.update(object_class_name, request_data) -@ldap_schema_router.post("/object_class/delete", error_map=error_map) +@ldap_schema_router.post( + "/object_class/delete", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def delete_bulk_object_classes( object_classes_names: LimitedListType, adapter: FromDishka[ObjectClassFastAPIAdapter], diff --git a/app/api/main/dns_router.py b/app/api/main/dns_router.py index 509cb377a..099187337 100644 --- a/app/api/main/dns_router.py +++ b/app/api/main/dns_router.py @@ -29,6 +29,7 @@ DNSServiceZoneDeleteRequest, DNSServiceZoneUpdateRequest, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.dns import ( DNSForwardServerStatus, @@ -139,7 +140,11 @@ async def get_dns_status( return await adapter.get_dns_status() -@dns_router.post("/setup", error_map=error_map) +@dns_router.post( + "/setup", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def setup_dns( data: DNSServiceSetupRequest, adapter: FromDishka[DNSFastAPIAdapter], diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index 91f64a5b6..a52858eb3 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -24,6 +24,7 @@ ) from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.main.schema import KerberosSetupRequest +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import KerberosState @@ -82,7 +83,7 @@ "/setup/tree", response_class=Response, error_map=error_map, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], ) async def setup_krb_catalogue( mail: Annotated[EmailStr, Body()], @@ -106,7 +107,12 @@ async def setup_krb_catalogue( ) -@krb5_router.post("/setup", response_class=Response, error_map=error_map) +@krb5_router.post( + "/setup", + response_class=Response, + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def setup_kdc( data: KerberosSetupRequest, identity_adapter: FromDishka[AuthFastAPIAdapter], @@ -173,7 +179,7 @@ async def get_krb_status( @krb5_router.post( "/principal/add", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def add_principal( @@ -193,7 +199,7 @@ async def add_principal( @krb5_router.patch( "/principal/rename", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def rename_principal( @@ -217,7 +223,7 @@ async def rename_principal( @krb5_router.patch( "/principal/reset", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def reset_principal_pw( @@ -238,7 +244,7 @@ async def reset_principal_pw( @krb5_router.delete( "/principal/delete", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], error_map=error_map, ) async def delete_principal( diff --git a/app/api/main/router.py b/app/api/main/router.py index 59250708b..d3bdd1d2f 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -16,6 +16,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.custom_requests.rename import RenameRequest from ldap_protocol.identity.exceptions import UnauthorizedError @@ -69,19 +70,31 @@ async def search(request: SearchRequest, req: Request) -> SearchResponse: ) -@entry_router.post("/add", error_map=error_map) +@entry_router.post( + "/add", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def add(request: AddRequest, req: Request) -> LDAPResult: """LDAP ADD entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update", error_map=error_map) +@entry_router.patch( + "/update", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def modify(request: ModifyRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update_many", error_map=error_map) +@entry_router.patch( + "/update_many", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def modify_many( requests: list[ModifyRequest], req: Request, @@ -93,13 +106,21 @@ async def modify_many( return results -@entry_router.put("/update/dn", error_map=error_map) +@entry_router.put( + "/update/dn", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def modify_dn(request: ModifyDNRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry DN request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.post("/update_many/dn", error_map=error_map) +@entry_router.post( + "/update_many/dn", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def modify_dn_many( requests: list[ModifyDNRequest], req: Request, @@ -116,14 +137,21 @@ async def rename(request: RenameRequest, req: Request) -> LDAPResult: """LDAP rename entry request.""" return await request.handle_api(req.state.dishka_container) - -@entry_router.delete("/delete", error_map=error_map) +@entry_router.delete( + "/delete", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def delete(request: DeleteRequest, req: Request) -> LDAPResult: """LDAP DELETE entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.post("/delete_many", error_map=error_map) +@entry_router.post( + "/delete_many", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def delete_many( requests: list[DeleteRequest], req: Request, @@ -135,7 +163,10 @@ async def delete_many( return results -@entry_router.post("/set_primary_group") +@entry_router.post( + "/set_primary_group", + dependencies=[Depends(check_master_db)], +) async def set_primary_group( request: PrimaryGroupRequest, session: FromDishka[AsyncSession], diff --git a/app/api/network/router.py b/app/api/network/router.py index f380672f2..62aaed7b4 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.network.adapters.network import NetworkPolicyFastAPIAdapter +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.policies.network.exceptions import ( LastActivePolicyError, @@ -64,6 +65,7 @@ "", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def add_network_policy( policy: Policy, @@ -97,6 +99,7 @@ async def get_list_network_policies( response_class=RedirectResponse, status_code=status.HTTP_303_SEE_OTHER, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def delete_network_policy( policy_id: int, @@ -114,7 +117,11 @@ async def delete_network_policy( return await adapter.delete(request, policy_id) # type: ignore -@network_router.patch("/{policy_id}", error_map=error_map) +@network_router.patch( + "/{policy_id}", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def switch_network_policy( policy_id: int, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -133,7 +140,11 @@ async def switch_network_policy( return await adapter.switch_network_policy(policy_id) -@network_router.put("", error_map=error_map) +@network_router.put( + "", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def update_network_policy( request: PolicyUpdate, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -150,7 +161,11 @@ async def update_network_policy( return await adapter.update(request) -@network_router.post("/swap", error_map=error_map) +@network_router.post( + "/swap", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def swap_network_policy( swap: SwapRequest, adapter: FromDishka[NetworkPolicyFastAPIAdapter], diff --git a/app/api/password_policy/password_ban_word_router.py b/app/api/password_policy/password_ban_word_router.py index a0c06a04e..2ebae09d7 100644 --- a/app/api/password_policy/password_ban_word_router.py +++ b/app/api/password_policy/password_ban_word_router.py @@ -13,6 +13,7 @@ from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordBanWordsFastAPIAdapter from api.password_policy.error_utils import error_map +from api.utils import check_master_db password_ban_word_router = ErrorAwareRouter( prefix="/password_ban_word", @@ -26,6 +27,7 @@ "/upload_txt", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(check_master_db)], ) async def upload_ban_words_txt( file: UploadFile, diff --git a/app/api/password_policy/password_policy_router.py b/app/api/password_policy/password_policy_router.py index 812777ecd..0ea261956 100644 --- a/app/api/password_policy/password_policy_router.py +++ b/app/api/password_policy/password_policy_router.py @@ -13,6 +13,7 @@ from api.password_policy.adapter import PasswordPolicyFastAPIAdapter from api.password_policy.error_utils import error_map from api.password_policy.schemas import PasswordPolicySchema +from api.utils import check_master_db from ldap_protocol.utils.const import GRANT_DN_STRING from .schemas import PriorityT @@ -51,7 +52,11 @@ async def get_password_policy_by_dir_path_dn( return await adapter.get_password_policy_by_dir_path_dn(path_dn) -@password_policy_router.put("/{id_}", error_map=error_map) +@password_policy_router.put( + "/{id_}", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def update( id_: int, policy: PasswordPolicySchema[PriorityT], @@ -61,7 +66,11 @@ async def update( await adapter.update(id_, policy) -@password_policy_router.put("/reset/domain_policy", error_map=error_map) +@password_policy_router.put( + "/reset/domain_policy", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def reset_domain_policy_to_default_config( adapter: FromDishka[PasswordPolicyFastAPIAdapter], ) -> None: diff --git a/app/api/password_policy/user_password_history_router.py b/app/api/password_policy/user_password_history_router.py index 2285c3cdd..7478d38f4 100644 --- a/app/api/password_policy/user_password_history_router.py +++ b/app/api/password_policy/user_password_history_router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.password_policy.adapter import UserPasswordHistoryResetFastAPIAdapter +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.identity.exceptions import ( AuthorizationError, @@ -39,7 +40,7 @@ user_password_history_router = ErrorAwareRouter( prefix="/user/password_history", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(check_master_db)], tags=["User Password history"], route_class=DishkaErrorAwareRoute, ) diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index b1ebe86fb..63a059627 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -8,7 +8,7 @@ from typing import Annotated from dishka import FromDishka -from fastapi import Body, status +from fastapi import Body, Depends, status from fastapi_error_map.routing import ErrorAwareRouter from fastapi_error_map.rules import rule @@ -17,6 +17,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import check_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, @@ -67,7 +68,11 @@ async def proxy_request( return await adapter.proxy_request(principal, ip) -@shadow_router.post("/sync/password", error_map=error_map) +@shadow_router.post( + "/sync/password", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def change_password( principal: Annotated[str, Body(embed=True)], new_password: Annotated[str, Body(embed=True)], diff --git a/app/api/utils.py b/app/api/utils.py new file mode 100644 index 000000000..3e803425f --- /dev/null +++ b/app/api/utils.py @@ -0,0 +1,35 @@ +"""Utils with master database check. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dishka import FromDishka +from dishka.integrations.fastapi import inject +from fastapi import HTTPException, status +from loguru import logger +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings + + +@inject +async def check_master_db( + session: FromDishka[AsyncSession], + settings: FromDishka[Settings], +) -> None: + if settings.POSTGRES_RW_MODE == "single": + return + + try: + session.sync_session.set_force_master(True) # type: ignore + await session.execute(text("SELECT 1")) + except Exception as e: + logger.error(f"Master DB check failed: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Master DB is not available", + ) + else: + session.sync_session.set_force_master(False) # type: ignore From 82108723369b4ad4e619df0f963b049dfe608d19 Mon Sep 17 00:00:00 2001 From: Naksen Date: Tue, 27 Jan 2026 18:47:35 +0300 Subject: [PATCH 05/23] test: update get_engine method to use master database engine --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index efe46fd21..bcb573fd4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ import uvloop from alembic import command from alembic.config import Config as AlembicConfig +from database import engines from dishka import ( AsyncContainer, Provider, @@ -362,9 +363,9 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: ) @provide(scope=scope, provides=AsyncEngine) - def get_engine(self, settings: Settings) -> AsyncEngine: + def get_engine(self) -> AsyncEngine: """Get async engine.""" - return settings.engine + return engines["master"] @provide(scope=Scope.APP, provides=async_sessionmaker[AsyncSession]) def get_session_factory( From 80472eba3f3de4d94796f3c07361a7d7d73bf45a Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 16:26:10 +0300 Subject: [PATCH 06/23] fix: exclude ABANDON protocol from master DB availability checks --- app/ldap_protocol/ldap_requests/abandon.py | 1 - app/ldap_protocol/ldap_requests/base.py | 24 +++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index 4f975bc0e..b9569ca0e 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -17,7 +17,6 @@ class AbandonRequest(BaseRequest): """Abandon protocol.""" - RESPONSE_TYPE: ClassVar[type] = type(None) CONTEXT_TYPE: ClassVar[type] = LDAPAbandonRequestContext PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ABANDON message_id: int diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index b6e6ddde4..cee7bde07 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -126,10 +126,11 @@ async def handle_tcp( responses.append(response) yield response except OperationalError: - yield self.RESPONSE_TYPE( - result_code=LDAPCodes.UNAVAILABLE, - errorMessage="Master DB is not available", - ) + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + yield self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ) return if self.PROTOCOL_OP != ProtocolRequests.SEARCH: @@ -185,12 +186,15 @@ async def _handle_api( try: responses = [response async for response in self.handle(ctx=ctx)] except OperationalError: - responses = [ - self.RESPONSE_TYPE( - result_code=LDAPCodes.UNAVAILABLE, - errorMessage="Master DB is not available", - ), - ] + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + responses = [ + self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ), + ] + else: + responses = [] if settings.DEBUG: for response in responses: From a59190764981a456aaa89a37911a9cbb484d7707 Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 17:01:07 +0300 Subject: [PATCH 07/23] fix: remove unnecessary dependency on check_master_db for audit destinations and MFA callback --- app/api/audit/router.py | 6 +----- app/api/auth/router_mfa.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/app/api/audit/router.py b/app/api/audit/router.py index bf4ef1912..62088f87e 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -74,11 +74,7 @@ async def update_audit_policy( return await audit_adapter.update_policy(policy_id, policy_data) -@audit_router.get( - "/destinations", - error_map=error_map, - dependencies=[Depends(check_master_db)], -) +@audit_router.get("/destinations", error_map=error_map) async def get_audit_destinations( audit_adapter: FromDishka[AuditPoliciesAdapter], ) -> list[AuditDestinationResponse]: diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index c0915de21..350370c27 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -135,7 +135,6 @@ async def get_mfa( name="callback_mfa", include_in_schema=True, error_map=error_map, - dependencies=[Depends(check_master_db)], ) async def callback_mfa( access_token: Annotated[ From b53d1fd53caeee487350f91519e1a49154a5c9cc Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 17:02:33 +0300 Subject: [PATCH 08/23] fix: handle OperationalError specifically in master DB check --- app/api/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/api/utils.py b/app/api/utils.py index 3e803425f..3cd3232a5 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -9,6 +9,7 @@ from fastapi import HTTPException, status from loguru import logger from sqlalchemy import text +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from config import Settings @@ -25,11 +26,11 @@ async def check_master_db( try: session.sync_session.set_force_master(True) # type: ignore await session.execute(text("SELECT 1")) - except Exception as e: + except OperationalError as e: logger.error(f"Master DB check failed: {e}") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Master DB is not available", ) else: - session.sync_session.set_force_master(False) # type: ignore + session.sync_session.set_force_master(False) # type: ignore From 26a610181e1e22f9f54da3e87e50502845818b1f Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 17:03:34 +0300 Subject: [PATCH 09/23] refactor: format --- app/api/main/router.py | 7 ++++++- app/ldap_protocol/ldap_requests/bind.py | 4 +++- app/ldap_protocol/session_storage/repository.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/app/api/main/router.py b/app/api/main/router.py index d3bdd1d2f..210fa7900 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -132,7 +132,12 @@ async def modify_dn_many( return results -@entry_router.put("/rename", error_map=error_map) + +@entry_router.put( + "/rename", + error_map=error_map, + dependencies=[Depends(check_master_db)], +) async def rename(request: RenameRequest, req: Request) -> LDAPResult: """LDAP rename entry request.""" return await request.handle_api(req.state.dishka_container) diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index 7c1a1f150..7f4af0e5c 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -219,7 +219,9 @@ async def handle( await ctx.ldap_session.set_user(user) with contextlib.suppress(OperationalError): await set_user_logon_attrs( - user, ctx.session, ctx.settings.TIMEZONE, + user, + ctx.session, + ctx.settings.TIMEZONE, ) server_sasl_creds = None diff --git a/app/ldap_protocol/session_storage/repository.py b/app/ldap_protocol/session_storage/repository.py index fb9e860c5..2e73dbc2d 100644 --- a/app/ldap_protocol/session_storage/repository.py +++ b/app/ldap_protocol/session_storage/repository.py @@ -91,7 +91,9 @@ async def create_session_key( ) with contextlib.suppress(OperationalError): await set_user_logon_attrs( - user, self.session, self.settings.TIMEZONE, + user, + self.session, + self.settings.TIMEZONE, ) return key From c5eb4b087fa988f9c9e1f3f06a8d4b5c8691b1be Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 17:45:14 +0300 Subject: [PATCH 10/23] fix: update POSTGRES_RW_MODE to use 'replication' instead of 'master_replica' --- app/config.py | 2 +- app/database.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/config.py b/app/config.py index 1e406f993..235a3d4ad 100644 --- a/app/config.py +++ b/app/config.py @@ -48,7 +48,7 @@ class Settings(BaseModel): TCP_PACKET_SIZE: int = 1024 COROUTINES_NUM_PER_CLIENT: int = 3 - POSTGRES_RW_MODE: Literal["single", "master_replica"] = "single" + POSTGRES_RW_MODE: Literal["single", "replication"] = "single" POSTGRES_SCHEMA: ClassVar[str] = "postgresql+psycopg" POSTGRES_REPLICA_DB: str = "" diff --git a/app/database.py b/app/database.py index 4f23a243f..7d0ffaf96 100644 --- a/app/database.py +++ b/app/database.py @@ -30,7 +30,7 @@ connect_args={"connect_timeout": settings.POSTGRES_CONNECT_TIMEOUT}, ), } -if settings.POSTGRES_RW_MODE == "master_replica": +if settings.POSTGRES_RW_MODE == "replication": engines["replica"] = create_async_engine( str(settings.REPLICA_POSTGRES_URI), pool_size=settings.INSTANCE_DB_POOL_SIZE, From 032f2ab8a62abca07e9a3337e3e23163a4b0931b Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 18:32:11 +0300 Subject: [PATCH 11/23] refactor: implement async engine management with EngineRegistry and update session handling --- app/config.py | 37 ++++++++++++++++++++ app/database.py | 79 +++++++++++++++++++------------------------ app/ioc.py | 16 +++++++-- app/multidirectory.py | 3 +- tests/conftest.py | 5 ++- 5 files changed, 87 insertions(+), 53 deletions(-) diff --git a/app/config.py b/app/config.py index 235a3d4ad..d813dd16a 100644 --- a/app/config.py +++ b/app/config.py @@ -22,6 +22,7 @@ computed_field, field_validator, ) +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine def _get_vendor_version() -> str: @@ -118,6 +119,42 @@ def REPLICA_POSTGRES_URI(self) -> PostgresDsn: # noqa f"{self.POSTGRES_REPLICA_DB}", ) + @cached_property + def engine(self) -> AsyncEngine: + """Get engine.""" + return create_async_engine( + str(self.POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="master", + connect_args={"connect_timeout": self.POSTGRES_CONNECT_TIMEOUT}, + ) + + @cached_property + def replica_engine(self) -> AsyncEngine | None: + if self.POSTGRES_RW_MODE != "replication": + return None + + return create_async_engine( + str(self.REPLICA_POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="replica", + connect_args={ + "connect_timeout": self.POSTGRES_REPLICA_CONNECT_TIMEOUT, + }, + ) + VENDOR_NAME: ClassVar[str] = "MultiFactor" VENDOR_VERSION: str = Field( default_factory=_get_vendor_version, diff --git a/app/database.py b/app/database.py index 7d0ffaf96..98e02c857 100644 --- a/app/database.py +++ b/app/database.py @@ -6,70 +6,59 @@ from typing import Any, Sequence -from loguru import logger from sqlalchemy import Delete, Insert, Update, exc as sa_exc from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import Session -from config import Settings - -settings = Settings.from_os() - -engines = { - "master": create_async_engine( - str(settings.POSTGRES_URI), - pool_size=settings.INSTANCE_DB_POOL_SIZE, - max_overflow=settings.INSTANCE_DB_POOL_OVERFLOW, - pool_timeout=settings.INSTANCE_DB_POOL_TIMEOUT, - pool_recycle=settings.INSTANCE_DB_POOL_RECYCLE, - pool_pre_ping=False, - future=True, - echo=False, - logging_name="master", - connect_args={"connect_timeout": settings.POSTGRES_CONNECT_TIMEOUT}, - ), -} -if settings.POSTGRES_RW_MODE == "replication": - engines["replica"] = create_async_engine( - str(settings.REPLICA_POSTGRES_URI), - pool_size=settings.INSTANCE_DB_POOL_SIZE, - max_overflow=settings.INSTANCE_DB_POOL_OVERFLOW, - pool_timeout=settings.INSTANCE_DB_POOL_TIMEOUT, - pool_recycle=settings.INSTANCE_DB_POOL_RECYCLE, - pool_pre_ping=False, - future=True, - echo=False, - logging_name="replica", - connect_args={ - "connect_timeout": settings.POSTGRES_REPLICA_CONNECT_TIMEOUT, - }, - ) + +class EngineRegistry: + _master_engine: AsyncEngine + _replica_engine: AsyncEngine | None + + def __init__( + self, + master_engine: AsyncEngine, + replica_engine: AsyncEngine | None, + ) -> None: + self._master_engine = master_engine + self._replica_engine = replica_engine + + def get_master_engine(self) -> AsyncEngine: + return self._master_engine + + def get_replica_engine(self) -> AsyncEngine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine + + def get_sync_master_engine(self) -> Engine: + return self._master_engine.sync_engine + + def get_sync_replica_engine(self) -> Engine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine.sync_engine class RoutingSession(Session): _force_master: bool = False @property - def force_master(self) -> bool: - return self._force_master + def engine_registry(self) -> EngineRegistry: + return self.info["engine_registry"] def set_force_master(self, value: bool) -> None: self._force_master = value def get_bind(self, mapper=None, clause=None) -> Engine: # type: ignore # noqa: ARG002 - logger.critical("-- CALL RoutingSession.get_bind --") - if isinstance(clause, Update | Insert | Delete): - logger.critical("MASTER") - return engines["master"].sync_engine + return self.engine_registry.get_sync_master_engine() if self._force_master or self._flushing: - logger.critical("MASTER") - return engines["master"].sync_engine + return self.engine_registry.get_sync_master_engine() else: - logger.critical("REPLICA") - return engines["replica"].sync_engine + return self.engine_registry.get_sync_replica_engine() def flush(self, objects: Sequence[Any] | None = None) -> None: if self._flushing: diff --git a/app/ioc.py b/app/ioc.py index 2a14c7112..ca05159a4 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -8,7 +8,7 @@ import httpx import redis.asyncio as redis -from database import RoutingSession, engines +from database import EngineRegistry, RoutingSession from dishka import Provider, Scope, from_context, provide from fastapi import Request from loguru import logger @@ -162,21 +162,30 @@ class MainProvider(Provider): scope = Scope.APP settings = from_context(provides=Settings, scope=Scope.APP) + @provide(scope=Scope.APP) + def get_engine_registry(self, settings: Settings) -> EngineRegistry: + return EngineRegistry( + master_engine=settings.engine, + replica_engine=settings.replica_engine, + ) + @provide(scope=Scope.APP) def get_session_factory( self, settings: Settings, + engine_registry: EngineRegistry, ) -> async_sessionmaker[AsyncSession]: """Create session factory.""" if settings.POSTGRES_RW_MODE == "single": return async_sessionmaker( - bind=engines["master"], + bind=engine_registry.get_master_engine(), expire_on_commit=False, ) return async_sessionmaker( sync_session_class=RoutingSession, expire_on_commit=False, + info={"engine_registry": engine_registry}, ) @provide(scope=Scope.REQUEST) @@ -899,8 +908,9 @@ def get_session_factory( @provide(scope=Scope.APP) async def get_conn_factory( self, + engine_registry: EngineRegistry, ) -> AsyncIterator[AsyncConnection]: """Create session factory.""" - engine = engines["master"] + engine = engine_registry.get_master_engine() async with engine.connect() as connection: yield connection diff --git a/app/multidirectory.py b/app/multidirectory.py index e6c048c34..c1cdd4ee7 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -13,7 +13,6 @@ import uvicorn import uvloop from alembic.config import Config, command -from database import engines from dishka import Scope, make_async_container from dishka.integrations.fastapi import setup_dishka from fastapi import FastAPI @@ -139,7 +138,7 @@ def handle(self, statistics: AlchemyStatistics) -> None: app.add_middleware( SQLAlchemyMonitor, - engine=engines["master"], + engine=settings.engine, actions=[JsonPrintStatistics()], ) diff --git a/tests/conftest.py b/tests/conftest.py index bcb573fd4..efe46fd21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,6 @@ import uvloop from alembic import command from alembic.config import Config as AlembicConfig -from database import engines from dishka import ( AsyncContainer, Provider, @@ -363,9 +362,9 @@ def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: ) @provide(scope=scope, provides=AsyncEngine) - def get_engine(self) -> AsyncEngine: + def get_engine(self, settings: Settings) -> AsyncEngine: """Get async engine.""" - return engines["master"] + return settings.engine @provide(scope=Scope.APP, provides=async_sessionmaker[AsyncSession]) def get_session_factory( From db35ba2755cd8431275375b34eae980aff76c3e2 Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 18:41:01 +0300 Subject: [PATCH 12/23] refactor: database module name --- app/{database.py => db_routing.py} | 2 +- app/ioc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename app/{database.py => db_routing.py} (97%) diff --git a/app/database.py b/app/db_routing.py similarity index 97% rename from app/database.py rename to app/db_routing.py index 98e02c857..0f9580e63 100644 --- a/app/database.py +++ b/app/db_routing.py @@ -1,4 +1,4 @@ -"""Database configuration and routing session. +"""Engine registry and routing session. Copyright (c) 2026 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE diff --git a/app/ioc.py b/app/ioc.py index ca05159a4..b1e4bd31a 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -8,7 +8,7 @@ import httpx import redis.asyncio as redis -from database import EngineRegistry, RoutingSession +from db_routing import EngineRegistry, RoutingSession from dishka import Provider, Scope, from_context, provide from fastapi import Request from loguru import logger From fa5805266a8ef5b74a8cb9ef4b1c9efbe9adebbc Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 18:45:47 +0300 Subject: [PATCH 13/23] refactor: rename check_master_db to require_master_db and update dependencies across routers --- app/api/audit/router.py | 10 +++++----- app/api/auth/router_auth.py | 6 +++--- app/api/auth/router_mfa.py | 8 ++++---- app/api/ldap_schema/attribute_type_router.py | 8 ++++---- app/api/ldap_schema/entity_type_router.py | 8 ++++---- app/api/ldap_schema/object_class_router.py | 8 ++++---- app/api/main/dns_router.py | 4 ++-- app/api/main/krb5_router.py | 14 +++++++------- app/api/main/router.py | 18 +++++++++--------- app/api/network/router.py | 12 ++++++------ .../password_ban_word_router.py | 4 ++-- .../password_policy/password_policy_router.py | 6 +++--- .../user_password_history_router.py | 4 ++-- app/api/shadow/router.py | 4 ++-- app/api/utils.py | 2 +- 15 files changed, 58 insertions(+), 58 deletions(-) diff --git a/app/api/audit/router.py b/app/api/audit/router.py index 62088f87e..6209d0740 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -15,7 +15,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.audit.exception import ( AuditAlreadyExistsError, @@ -63,7 +63,7 @@ async def get_audit_policies( @audit_router.put( "/policy/{policy_id}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def update_audit_policy( policy_id: int, @@ -86,7 +86,7 @@ async def get_audit_destinations( "/destination", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def create_audit_destination( destination_data: AuditDestinationSchemaRequest, @@ -99,7 +99,7 @@ async def create_audit_destination( @audit_router.delete( "/destination/{destination_id}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_audit_destination( destination_id: int, @@ -112,7 +112,7 @@ async def delete_audit_destination( @audit_router.put( "/destination/{destination_id}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def update_audit_destination( destination_id: int, diff --git a/app/api/auth/router_auth.py b/app/api/auth/router_auth.py index 75dbb019e..a5c98911f 100644 --- a/app/api/auth/router_auth.py +++ b/app/api/auth/router_auth.py @@ -19,7 +19,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( MFAAPIError, @@ -187,7 +187,7 @@ async def logout( @auth_router.patch( "/user/password", status_code=200, - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def password_reset( @@ -230,7 +230,7 @@ async def check_setup( status_code=status.HTTP_200_OK, responses={423: {"detail": "Locked"}}, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def first_setup( request: SetupRequest, diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index 350370c27..8e275b242 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -24,7 +24,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( ForbiddenError, @@ -82,7 +82,7 @@ @mfa_router.post( "/setup", status_code=status.HTTP_201_CREATED, - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def setup_mfa( @@ -101,7 +101,7 @@ async def setup_mfa( @mfa_router.delete( "/keys", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def remove_mfa( @@ -114,7 +114,7 @@ async def remove_mfa( @mfa_router.post( "/get", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def get_mfa( diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index 503f3654e..a75a1826a 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -16,7 +16,7 @@ AttributeTypeSchema, AttributeTypeUpdateSchema, ) -from api.utils import check_master_db +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,7 +24,7 @@ "/attribute_type", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def create_one_attribute_type( request_data: AttributeTypeSchema[None], @@ -61,7 +61,7 @@ async def get_list_attribute_types_with_pagination( @ldap_schema_router.patch( "/attribute_type/{attribute_type_name}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_one_attribute_type( attribute_type_name: str, @@ -75,7 +75,7 @@ async def modify_one_attribute_type( @ldap_schema_router.post( "/attribute_types/delete", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_bulk_attribute_types( attribute_types_names: LimitedListType, diff --git a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index c4bf1d85a..129230b8e 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -17,7 +17,7 @@ EntityTypeSchema, EntityTypeUpdateSchema, ) -from api.utils import check_master_db +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -25,7 +25,7 @@ "/entity_type", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def create_one_entity_type( request_data: EntityTypeSchema[None], @@ -68,7 +68,7 @@ async def get_entity_type_attributes( @ldap_schema_router.patch( "/entity_type/{entity_type_name}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_one_entity_type( entity_type_name: str, @@ -82,7 +82,7 @@ async def modify_one_entity_type( @ldap_schema_router.post( "/entity_type/delete", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_bulk_entity_types( entity_type_names: LimitedListType, diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index c4bc8d44a..a6baced69 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -17,7 +17,7 @@ ObjectClassSchema, ObjectClassUpdateSchema, ) -from api.utils import check_master_db +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -25,7 +25,7 @@ "/object_class", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def create_one_object_class( request_data: ObjectClassSchema[None], @@ -59,7 +59,7 @@ async def get_list_object_classes_with_pagination( @ldap_schema_router.patch( "/object_class/{object_class_name}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_one_object_class( object_class_name: str, @@ -73,7 +73,7 @@ async def modify_one_object_class( @ldap_schema_router.post( "/object_class/delete", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_bulk_object_classes( object_classes_names: LimitedListType, diff --git a/app/api/main/dns_router.py b/app/api/main/dns_router.py index 099187337..bf3e83e40 100644 --- a/app/api/main/dns_router.py +++ b/app/api/main/dns_router.py @@ -29,7 +29,7 @@ DNSServiceZoneDeleteRequest, DNSServiceZoneUpdateRequest, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dns import ( DNSForwardServerStatus, @@ -143,7 +143,7 @@ async def get_dns_status( @dns_router.post( "/setup", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def setup_dns( data: DNSServiceSetupRequest, diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index a52858eb3..9ed36515c 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -24,7 +24,7 @@ ) from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.main.schema import KerberosSetupRequest -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import KerberosState @@ -83,7 +83,7 @@ "/setup/tree", response_class=Response, error_map=error_map, - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], ) async def setup_krb_catalogue( mail: Annotated[EmailStr, Body()], @@ -111,7 +111,7 @@ async def setup_krb_catalogue( "/setup", response_class=Response, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def setup_kdc( data: KerberosSetupRequest, @@ -179,7 +179,7 @@ async def get_krb_status( @krb5_router.post( "/principal/add", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def add_principal( @@ -199,7 +199,7 @@ async def add_principal( @krb5_router.patch( "/principal/rename", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def rename_principal( @@ -223,7 +223,7 @@ async def rename_principal( @krb5_router.patch( "/principal/reset", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def reset_principal_pw( @@ -244,7 +244,7 @@ async def reset_principal_pw( @krb5_router.delete( "/principal/delete", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def delete_principal( diff --git a/app/api/main/router.py b/app/api/main/router.py index 210fa7900..e8815fad0 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -16,7 +16,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.custom_requests.rename import RenameRequest from ldap_protocol.identity.exceptions import UnauthorizedError @@ -73,7 +73,7 @@ async def search(request: SearchRequest, req: Request) -> SearchResponse: @entry_router.post( "/add", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def add(request: AddRequest, req: Request) -> LDAPResult: """LDAP ADD entry request.""" @@ -83,7 +83,7 @@ async def add(request: AddRequest, req: Request) -> LDAPResult: @entry_router.patch( "/update", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify(request: ModifyRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry request.""" @@ -93,7 +93,7 @@ async def modify(request: ModifyRequest, req: Request) -> LDAPResult: @entry_router.patch( "/update_many", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_many( requests: list[ModifyRequest], @@ -109,7 +109,7 @@ async def modify_many( @entry_router.put( "/update/dn", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_dn(request: ModifyDNRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry DN request.""" @@ -119,7 +119,7 @@ async def modify_dn(request: ModifyDNRequest, req: Request) -> LDAPResult: @entry_router.post( "/update_many/dn", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def modify_dn_many( requests: list[ModifyDNRequest], @@ -145,7 +145,7 @@ async def rename(request: RenameRequest, req: Request) -> LDAPResult: @entry_router.delete( "/delete", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete(request: DeleteRequest, req: Request) -> LDAPResult: """LDAP DELETE entry request.""" @@ -155,7 +155,7 @@ async def delete(request: DeleteRequest, req: Request) -> LDAPResult: @entry_router.post( "/delete_many", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_many( requests: list[DeleteRequest], @@ -170,7 +170,7 @@ async def delete_many( @entry_router.post( "/set_primary_group", - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def set_primary_group( request: PrimaryGroupRequest, diff --git a/app/api/network/router.py b/app/api/network/router.py index 62aaed7b4..71c87cb5b 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -18,7 +18,7 @@ DomainErrorTranslator, ) from api.network.adapters.network import NetworkPolicyFastAPIAdapter -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.network.exceptions import ( LastActivePolicyError, @@ -65,7 +65,7 @@ "", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def add_network_policy( policy: Policy, @@ -99,7 +99,7 @@ async def get_list_network_policies( response_class=RedirectResponse, status_code=status.HTTP_303_SEE_OTHER, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def delete_network_policy( policy_id: int, @@ -120,7 +120,7 @@ async def delete_network_policy( @network_router.patch( "/{policy_id}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def switch_network_policy( policy_id: int, @@ -143,7 +143,7 @@ async def switch_network_policy( @network_router.put( "", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def update_network_policy( request: PolicyUpdate, @@ -164,7 +164,7 @@ async def update_network_policy( @network_router.post( "/swap", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def swap_network_policy( swap: SwapRequest, diff --git a/app/api/password_policy/password_ban_word_router.py b/app/api/password_policy/password_ban_word_router.py index 2ebae09d7..5185124dc 100644 --- a/app/api/password_policy/password_ban_word_router.py +++ b/app/api/password_policy/password_ban_word_router.py @@ -13,7 +13,7 @@ from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordBanWordsFastAPIAdapter from api.password_policy.error_utils import error_map -from api.utils import check_master_db +from api.utils import require_master_db password_ban_word_router = ErrorAwareRouter( prefix="/password_ban_word", @@ -27,7 +27,7 @@ "/upload_txt", status_code=status.HTTP_201_CREATED, error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def upload_ban_words_txt( file: UploadFile, diff --git a/app/api/password_policy/password_policy_router.py b/app/api/password_policy/password_policy_router.py index 0ea261956..36bd206c3 100644 --- a/app/api/password_policy/password_policy_router.py +++ b/app/api/password_policy/password_policy_router.py @@ -13,7 +13,7 @@ from api.password_policy.adapter import PasswordPolicyFastAPIAdapter from api.password_policy.error_utils import error_map from api.password_policy.schemas import PasswordPolicySchema -from api.utils import check_master_db +from api.utils import require_master_db from ldap_protocol.utils.const import GRANT_DN_STRING from .schemas import PriorityT @@ -55,7 +55,7 @@ async def get_password_policy_by_dir_path_dn( @password_policy_router.put( "/{id_}", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def update( id_: int, @@ -69,7 +69,7 @@ async def update( @password_policy_router.put( "/reset/domain_policy", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def reset_domain_policy_to_default_config( adapter: FromDishka[PasswordPolicyFastAPIAdapter], diff --git a/app/api/password_policy/user_password_history_router.py b/app/api/password_policy/user_password_history_router.py index 7478d38f4..9af233c12 100644 --- a/app/api/password_policy/user_password_history_router.py +++ b/app/api/password_policy/user_password_history_router.py @@ -18,7 +18,7 @@ DomainErrorTranslator, ) from api.password_policy.adapter import UserPasswordHistoryResetFastAPIAdapter -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.identity.exceptions import ( AuthorizationError, @@ -40,7 +40,7 @@ user_password_history_router = ErrorAwareRouter( prefix="/user/password_history", - dependencies=[Depends(verify_auth), Depends(check_master_db)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], tags=["User Password history"], route_class=DishkaErrorAwareRoute, ) diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index 63a059627..b708babb0 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -17,7 +17,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) -from api.utils import check_master_db +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, @@ -71,7 +71,7 @@ async def proxy_request( @shadow_router.post( "/sync/password", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def change_password( principal: Annotated[str, Body(embed=True)], diff --git a/app/api/utils.py b/app/api/utils.py index 3cd3232a5..aa3b2e289 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -16,7 +16,7 @@ @inject -async def check_master_db( +async def require_master_db( session: FromDishka[AsyncSession], settings: FromDishka[Settings], ) -> None: From d04426a48a2f52bd9f2907e0b3aaaa1f9e770c2e Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 18:54:11 +0300 Subject: [PATCH 14/23] fix: handle OperationalError by initializing responses to an empty list --- app/ldap_protocol/ldap_requests/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index cee7bde07..63667f034 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -186,15 +186,14 @@ async def _handle_api( try: responses = [response async for response in self.handle(ctx=ctx)] except OperationalError: + responses = [] if self.PROTOCOL_OP != ProtocolRequests.ABANDON: - responses = [ + responses.append( self.RESPONSE_TYPE( result_code=LDAPCodes.UNAVAILABLE, errorMessage="Master DB is not available", ), - ] - else: - responses = [] + ) if settings.DEBUG: for response in responses: From 9a9010c29195167edc49a4a6b6d5fe77893758db Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 19:04:26 +0300 Subject: [PATCH 15/23] refactor: update _add_app_sqlalchemy_debugger to accept settings parameter --- app/multidirectory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/multidirectory.py b/app/multidirectory.py index c1cdd4ee7..73f82b010 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -118,7 +118,7 @@ def _create_shadow_app(settings: Settings) -> FastAPI: return app -def _add_app_sqlalchemy_debugger(app: FastAPI) -> None: +def _add_app_sqlalchemy_debugger(app: FastAPI, settings: Settings) -> None: try: import json from dataclasses import asdict @@ -160,7 +160,7 @@ def create_prod_app( ) if settings.ENABLE_SQLALCHEMY_LOGGING: - _add_app_sqlalchemy_debugger(app) + _add_app_sqlalchemy_debugger(app, settings) setup_dishka(container, app) return app From f55376d3622f32e862710d179e438020792300b6 Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 19:05:05 +0300 Subject: [PATCH 16/23] refactor: format --- app/multidirectory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/multidirectory.py b/app/multidirectory.py index 73f82b010..22a19259d 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -149,7 +149,6 @@ def create_prod_app( ) -> FastAPI: """Create production app with container.""" settings = settings or Settings.from_os() - app = factory(settings) container = make_async_container( MainProvider(), From e2fb0b24679211c67a1d9ecdbe4b6fabc1324f0b Mon Sep 17 00:00:00 2001 From: Naksen Date: Wed, 28 Jan 2026 20:24:10 +0300 Subject: [PATCH 17/23] fix: add _force_master flag to execute ops --- app/db_routing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/db_routing.py b/app/db_routing.py index 0f9580e63..1e305b361 100644 --- a/app/db_routing.py +++ b/app/db_routing.py @@ -53,6 +53,7 @@ def set_force_master(self, value: bool) -> None: def get_bind(self, mapper=None, clause=None) -> Engine: # type: ignore # noqa: ARG002 if isinstance(clause, Update | Insert | Delete): + self._force_master = True return self.engine_registry.get_sync_master_engine() if self._force_master or self._flushing: From 6ea2c376851995250ce9a22bc308654a54f4fe9e Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 11:38:44 +0300 Subject: [PATCH 18/23] refactor: replace string literals with PostgresRWModeType enum for better type safety --- app/api/utils.py | 3 ++- app/config.py | 4 +++- app/db_routing.py | 19 +++++++++++++++++-- app/enums.py | 7 +++++++ app/ioc.py | 11 ++++------- 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/app/api/utils.py b/app/api/utils.py index aa3b2e289..e58ac73bf 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from config import Settings +from enums import PostgresRWModeType @inject @@ -20,7 +21,7 @@ async def require_master_db( session: FromDishka[AsyncSession], settings: FromDishka[Settings], ) -> None: - if settings.POSTGRES_RW_MODE == "single": + if settings.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: return try: diff --git a/app/config.py b/app/config.py index d813dd16a..1b22bf474 100644 --- a/app/config.py +++ b/app/config.py @@ -24,6 +24,8 @@ ) from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from enums import PostgresRWModeType + def _get_vendor_version() -> str: with open("/pyproject.toml", "rb") as f: @@ -49,7 +51,7 @@ class Settings(BaseModel): TCP_PACKET_SIZE: int = 1024 COROUTINES_NUM_PER_CLIENT: int = 3 - POSTGRES_RW_MODE: Literal["single", "replication"] = "single" + POSTGRES_RW_MODE: PostgresRWModeType = PostgresRWModeType.SINGLE POSTGRES_SCHEMA: ClassVar[str] = "postgresql+psycopg" POSTGRES_REPLICA_DB: str = "" diff --git a/app/db_routing.py b/app/db_routing.py index 1e305b361..f19ff6e1e 100644 --- a/app/db_routing.py +++ b/app/db_routing.py @@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import Session +from enums import PostgresRWModeType + class EngineRegistry: _master_engine: AsyncEngine @@ -46,12 +48,25 @@ class RoutingSession(Session): @property def engine_registry(self) -> EngineRegistry: - return self.info["engine_registry"] + engine_registry = self.info.get("engine_registry") + if engine_registry is None: + raise RuntimeError("Engine registry is not configured") + return engine_registry + + @property + def rw_mode(self) -> PostgresRWModeType: + rw_mode = self.info.get("rw_mode") + if rw_mode is None: + raise RuntimeError("RW mode is not configured") + return rw_mode def set_force_master(self, value: bool) -> None: self._force_master = value - def get_bind(self, mapper=None, clause=None) -> Engine: # type: ignore # noqa: ARG002 + def get_bind(self, mapper=None, *, clause=None, **kw) -> Engine: # type: ignore # noqa: ARG002 + if self.rw_mode == PostgresRWModeType.SINGLE: + return self.engine_registry.get_sync_master_engine() + if isinstance(clause, Update | Insert | Delete): self._force_master = True return self.engine_registry.get_sync_master_engine() diff --git a/app/enums.py b/app/enums.py index 1f6e8f798..b4ef3cde4 100644 --- a/app/enums.py +++ b/app/enums.py @@ -12,6 +12,13 @@ from typing import Iterable, Self +class PostgresRWModeType(StrEnum): + """Postgres read/write mode type.""" + + SINGLE = "single" + REPLICATION = "replication" + + class AceType(IntEnum): """ACE types.""" diff --git a/app/ioc.py b/app/ioc.py index b1e4bd31a..7bb60881c 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -176,16 +176,13 @@ def get_session_factory( engine_registry: EngineRegistry, ) -> async_sessionmaker[AsyncSession]: """Create session factory.""" - if settings.POSTGRES_RW_MODE == "single": - return async_sessionmaker( - bind=engine_registry.get_master_engine(), - expire_on_commit=False, - ) - return async_sessionmaker( sync_session_class=RoutingSession, expire_on_commit=False, - info={"engine_registry": engine_registry}, + info={ + "engine_registry": engine_registry, + "rw_mode": settings.POSTGRES_RW_MODE, + }, ) @provide(scope=Scope.REQUEST) From 676b0b6494d732677acd4397c067b3e82aa7feea Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 11:43:27 +0300 Subject: [PATCH 19/23] fix: update dependencies for rename endpoint to require_master_db --- app/api/main/router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/api/main/router.py b/app/api/main/router.py index e8815fad0..174a65afa 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -132,16 +132,16 @@ async def modify_dn_many( return results - @entry_router.put( "/rename", error_map=error_map, - dependencies=[Depends(check_master_db)], + dependencies=[Depends(require_master_db)], ) async def rename(request: RenameRequest, req: Request) -> LDAPResult: """LDAP rename entry request.""" return await request.handle_api(req.state.dishka_container) + @entry_router.delete( "/delete", error_map=error_map, From 514802f6172b779618c721b4aa3e0f241135970b Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 16:49:13 +0300 Subject: [PATCH 20/23] refactor: implement master database check and gateway for PostgreSQL routing --- app/api/utils.py | 21 ++------------ app/ioc.py | 18 ++++++++++++ app/ldap_protocol/master_check_use_case.py | 27 ++++++++++++++++++ app/repo/pg/master_gateway.py | 33 ++++++++++++++++++++++ 4 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 app/ldap_protocol/master_check_use_case.py create mode 100644 app/repo/pg/master_gateway.py diff --git a/app/api/utils.py b/app/api/utils.py index e58ac73bf..5f94d56f6 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -7,31 +7,16 @@ from dishka import FromDishka from dishka.integrations.fastapi import inject from fastapi import HTTPException, status -from loguru import logger -from sqlalchemy import text -from sqlalchemy.exc import OperationalError -from sqlalchemy.ext.asyncio import AsyncSession -from config import Settings -from enums import PostgresRWModeType +from ldap_protocol.master_check_use_case import MasterCheckUseCase @inject async def require_master_db( - session: FromDishka[AsyncSession], - settings: FromDishka[Settings], + master_check_use_case: FromDishka[MasterCheckUseCase], ) -> None: - if settings.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: - return - - try: - session.sync_session.set_force_master(True) # type: ignore - await session.execute(text("SELECT 1")) - except OperationalError as e: - logger.error(f"Master DB check failed: {e}") + if not await master_check_use_case.check_master(): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Master DB is not available", ) - else: - session.sync_session.set_force_master(False) # type: ignore diff --git a/app/ioc.py b/app/ioc.py index 7bb60881c..81909c9c3 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -88,6 +88,10 @@ from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.master_check_use_case import ( + MasterCheckUseCase, + MasterGatewayProtocol, +) from ldap_protocol.multifactor import ( Creds, LDAPMultiFactorAPI, @@ -148,6 +152,7 @@ from ldap_protocol.session_storage import RedisSessionStorage, SessionStorage from ldap_protocol.session_storage.repository import SessionRepository from password_utils import PasswordUtils +from repo.pg.master_gateway import PGMasterGateway SessionStorageClient = NewType("SessionStorageClient", redis.Redis) KadminHTTPClient = NewType("KadminHTTPClient", httpx.AsyncClient) @@ -581,6 +586,19 @@ def get_audit_monitor( session_key=session_key, ) + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) + async def get_master_gateway( + self, + session: AsyncSession, + settings: Settings, + ) -> PGMasterGateway: + return PGMasterGateway(session, settings) + + master_check_use_case = provide( + MasterCheckUseCase, + scope=Scope.REQUEST, + ) + identity_provider_gateway = provide( IdentityProviderGateway, scope=Scope.REQUEST, diff --git a/app/ldap_protocol/master_check_use_case.py b/app/ldap_protocol/master_check_use_case.py new file mode 100644 index 000000000..531790337 --- /dev/null +++ b/app/ldap_protocol/master_check_use_case.py @@ -0,0 +1,27 @@ +"""Check Master Use Case. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Protocol + +from abstract_service import AbstractService + + +class MasterGatewayProtocol(Protocol): + """Master DB Gateway Protocol.""" + + async def check_master(self) -> bool: ... + + +class MasterCheckUseCase(AbstractService): + """Check Master Use Case.""" + + _master_gateway: MasterGatewayProtocol + + def __init__(self, master_gateway: MasterGatewayProtocol) -> None: + self._master_gateway = master_gateway + + async def check_master(self) -> bool: + return await self._master_gateway.check_master() diff --git a/app/repo/pg/master_gateway.py b/app/repo/pg/master_gateway.py new file mode 100644 index 000000000..20476c8d3 --- /dev/null +++ b/app/repo/pg/master_gateway.py @@ -0,0 +1,33 @@ +"""Master DB Gateway. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from loguru import logger +from sqlalchemy import text +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from enums import PostgresRWModeType + + +class PGMasterGateway: + def __init__(self, session: AsyncSession, settings: Settings) -> None: + self._session = session + self._settings = settings + + async def check_master(self) -> bool: + if self._settings.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: + return True + + try: + self._session.sync_session.set_force_master(True) # type: ignore + await self._session.execute(text("SELECT 1")) + except OperationalError as e: + logger.error(f"Master DB check failed: {e}") + return False + else: + self._session.sync_session.set_force_master(False) # type: ignore + return True From 3f6e9dae31bc806eda7a592fc56dc00b724f85f2 Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 16:54:59 +0300 Subject: [PATCH 21/23] test: add MasterCheckUseCase and PGMasterGateway to conftest for enhanced testing --- tests/conftest.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index efe46fd21..eddfb7215 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -110,6 +110,10 @@ from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.master_check_use_case import ( + MasterCheckUseCase, + MasterGatewayProtocol, +) from ldap_protocol.multifactor import LDAPMultiFactorAPI, MultifactorAPI from ldap_protocol.permissions_checker import AuthorizationProvider from ldap_protocol.policies.audit.audit_use_case import AuditUseCase @@ -157,6 +161,7 @@ from ldap_protocol.session_storage.repository import SessionRepository from ldap_protocol.utils.queries import get_user from password_utils import PasswordUtils +from repo.pg.master_gateway import PGMasterGateway from tests.constants import TEST_DATA @@ -467,6 +472,19 @@ async def get_redis_for_sessions( with suppress(RuntimeError): await client.aclose() + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) + async def get_master_gateway( + self, + session: AsyncSession, + settings: Settings, + ) -> PGMasterGateway: + return PGMasterGateway(session, settings) + + master_check_use_case = provide( + MasterCheckUseCase, + scope=Scope.REQUEST, + ) + @provide(scope=Scope.APP) async def get_session_storage( self, From d6a86a460ba5c69d9e13c44c93f450aa6a021e31 Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 16:56:47 +0300 Subject: [PATCH 22/23] fix: update replica_engine condition to check for SINGLE mode in PostgreSQL routing --- app/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/config.py b/app/config.py index 1b22bf474..8637bcdd6 100644 --- a/app/config.py +++ b/app/config.py @@ -139,7 +139,7 @@ def engine(self) -> AsyncEngine: @cached_property def replica_engine(self) -> AsyncEngine | None: - if self.POSTGRES_RW_MODE != "replication": + if self.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: return None return create_async_engine( From a2f29d3fc32be6bc0c58022cea196382856926e4 Mon Sep 17 00:00:00 2001 From: Naksen Date: Fri, 6 Feb 2026 17:06:13 +0300 Subject: [PATCH 23/23] fix: enhance MasterCheckUseCase by adding PERMISSIONS attribute and organizing imports --- app/ldap_protocol/master_check_use_case.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/app/ldap_protocol/master_check_use_case.py b/app/ldap_protocol/master_check_use_case.py index 531790337..2c010e788 100644 --- a/app/ldap_protocol/master_check_use_case.py +++ b/app/ldap_protocol/master_check_use_case.py @@ -4,9 +4,10 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from typing import Protocol +from typing import ClassVar, Protocol from abstract_service import AbstractService +from enums import AuthorizationRules class MasterGatewayProtocol(Protocol): @@ -25,3 +26,5 @@ def __init__(self, master_gateway: MasterGatewayProtocol) -> None: async def check_master(self) -> bool: return await self._master_gateway.check_master() + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = {}