From 6c4f58e2051d9d77b28998b230293d99178414a9 Mon Sep 17 00:00:00 2001 From: Luke Oliff Date: Sun, 8 Feb 2026 15:20:07 +0000 Subject: [PATCH] feat(proxy): add drop-in proxy middleware for web applications Adds a proxy module that keeps Deepgram API keys server-side while providing scoped JWT auth and REST/WebSocket forwarding. Includes adapters for FastAPI, Flask, and Django with 57 tests covering scopes, JWT, engine, and end-to-end FastAPI integration. --- TODO.md | 31 +++ src/deepgram/proxy/__init__.py | 6 + src/deepgram/proxy/adapters/__init__.py | 1 + src/deepgram/proxy/adapters/django.py | 131 +++++++++ src/deepgram/proxy/adapters/fastapi.py | 98 +++++++ src/deepgram/proxy/adapters/flask.py | 111 ++++++++ src/deepgram/proxy/engine.py | 355 ++++++++++++++++++++++++ src/deepgram/proxy/errors.py | 32 +++ src/deepgram/proxy/jwt.py | 86 ++++++ src/deepgram/proxy/scopes.py | 73 +++++ tests/custom/test_proxy_engine.py | 167 +++++++++++ tests/custom/test_proxy_fastapi.py | 87 ++++++ tests/custom/test_proxy_jwt.py | 92 ++++++ tests/custom/test_proxy_scopes.py | 78 ++++++ 14 files changed, 1348 insertions(+) create mode 100644 TODO.md create mode 100644 src/deepgram/proxy/__init__.py create mode 100644 src/deepgram/proxy/adapters/__init__.py create mode 100644 src/deepgram/proxy/adapters/django.py create mode 100644 src/deepgram/proxy/adapters/fastapi.py create mode 100644 src/deepgram/proxy/adapters/flask.py create mode 100644 src/deepgram/proxy/engine.py create mode 100644 src/deepgram/proxy/errors.py create mode 100644 src/deepgram/proxy/jwt.py create mode 100644 src/deepgram/proxy/scopes.py create mode 100644 tests/custom/test_proxy_engine.py create mode 100644 tests/custom/test_proxy_fastapi.py create mode 100644 tests/custom/test_proxy_jwt.py create mode 100644 tests/custom/test_proxy_scopes.py diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..295b786c --- /dev/null +++ b/TODO.md @@ -0,0 +1,31 @@ +# Proxy Feature — Required Changes to Existing 
Files + +These changes are needed in Fern-generated or config files that aren't modified +by the proxy implementation itself. + +## pyproject.toml + +Add PyJWT as an optional dependency and create the `proxy` extra: + +```toml +[tool.poetry.dependencies] +# ... existing deps ... +PyJWT = {version = ">=2.0.0", optional = true} + +[tool.poetry.extras] +proxy = ["PyJWT"] +``` + +This allows users to install with: +``` +pip install "deepgram-sdk[proxy]" +``` + +## Optional runtime dependencies (not in pyproject.toml) + +These are NOT added as project dependencies — users install them directly: + +- **websockets** — required for WebSocket proxying +- **fastapi** — for the FastAPI adapter +- **flask** / **flask-sock** — for the Flask adapter (flask-sock for WS) +- **django** / **channels** — for the Django adapter (channels for WS) diff --git a/src/deepgram/proxy/__init__.py b/src/deepgram/proxy/__init__.py new file mode 100644 index 00000000..da61f7c3 --- /dev/null +++ b/src/deepgram/proxy/__init__.py @@ -0,0 +1,6 @@ +"""Deepgram Proxy — drop-in proxy middleware for web applications.""" + +from .engine import DeepgramProxy +from .scopes import Scope + +__all__ = ["DeepgramProxy", "Scope"] diff --git a/src/deepgram/proxy/adapters/__init__.py b/src/deepgram/proxy/adapters/__init__.py new file mode 100644 index 00000000..a91a16a8 --- /dev/null +++ b/src/deepgram/proxy/adapters/__init__.py @@ -0,0 +1 @@ +"""Framework adapters for the Deepgram proxy.""" diff --git a/src/deepgram/proxy/adapters/django.py b/src/deepgram/proxy/adapters/django.py new file mode 100644 index 00000000..051cde14 --- /dev/null +++ b/src/deepgram/proxy/adapters/django.py @@ -0,0 +1,131 @@ +"""Django adapter for the Deepgram proxy.""" + +import asyncio +from typing import TYPE_CHECKING, Any, List + +from ..errors import ProxyError + +if TYPE_CHECKING: + from ..engine import DeepgramProxy + + +def deepgram_proxy_urls(proxy: "DeepgramProxy") -> List[Any]: + """Create Django URL patterns that proxy 
requests to Deepgram. + + REST views are CSRF-exempt. Optional WebSocket support requires + Django Channels (``pip install channels``). + + Usage:: + + from django.urls import path, include + from deepgram.proxy import DeepgramProxy + from deepgram.proxy.adapters.django import deepgram_proxy_urls + + proxy = DeepgramProxy(api_key="dg-xxx") + urlpatterns = [path("deepgram/", include(deepgram_proxy_urls(proxy)))] + + Returns: + List of URL patterns. If Django Channels is installed, the list also + has a ``websocket_consumer`` attribute containing the ASGI consumer class. + """ + from django.http import HttpRequest, HttpResponse + from django.urls import re_path + from django.views.decorators.csrf import csrf_exempt + + @csrf_exempt + def proxy_rest(request: HttpRequest, path: str) -> HttpResponse: + full_path = f"/{path}" + authorization = request.headers.get("Authorization") + + try: + scopes = proxy.authenticate(authorization) + proxy.authorize(full_path, scopes) + except ProxyError as exc: + return HttpResponse(exc.message, status=exc.status_code) + + headers = dict(request.headers) + body = request.body + query_string = request.META.get("QUERY_STRING", "") + + try: + status, resp_headers, resp_body = proxy.forward_rest_sync( + method=request.method, + path=full_path, + headers=headers, + query_string=query_string, + body=body, + ) + except ProxyError as exc: + return HttpResponse(exc.message, status=exc.status_code) + + response = HttpResponse(resp_body, status=status) + for k, v in resp_headers.items(): + if k.lower() not in ("content-length", "content-encoding"): + response[k] = v + return response + + patterns: Any = [ + re_path(r"^(?P<path>.+)$", proxy_rest, name="deepgram_proxy_rest"), + ] + + # Optional WebSocket support via Django Channels + try: + from channels.generic.websocket import AsyncWebsocketConsumer + + class DeepgramProxyConsumer(AsyncWebsocketConsumer): + """ASGI WebSocket consumer for Deepgram proxy.""" + + async def connect(self) -> None: + 
self._path = "/" + self.scope.get("path", "").lstrip("/") + # Remove the URL prefix to get the API path + query_string = self.scope.get("query_string", b"").decode("utf-8") + subprotocol = None + for header_name, header_value in self.scope.get("headers", []): + if header_name == b"sec-websocket-protocol": + subprotocol = header_value.decode("utf-8") + break + + self._subprotocol = subprotocol + self._query_string = query_string + self._message_queue: asyncio.Queue = asyncio.Queue() + + await self.accept(subprotocol=subprotocol) + + # Start the relay + asyncio.ensure_future(self._relay()) + + async def _relay(self) -> None: + async def client_receive(): + msg = await self._message_queue.get() + return msg + + async def client_send(msg): + if isinstance(msg, bytes): + await self.send(bytes_data=msg) + else: + await self.send(text_data=str(msg)) + + async def client_close(code: int, reason: str = ""): + await self.close(code=code) + + await proxy.forward_websocket( + path=self._path, + query_string=self._query_string, + client_receive=client_receive, + client_send=client_send, + client_close=client_close, + subprotocol=self._subprotocol, + ) + + async def receive(self, text_data: str = None, bytes_data: bytes = None) -> None: # type: ignore[assignment] + await self._message_queue.put(text_data or bytes_data) + + async def disconnect(self, close_code: int) -> None: + await self._message_queue.put(None) + + patterns.websocket_consumer = DeepgramProxyConsumer # type: ignore[attr-defined] + + except ImportError: + pass # Django Channels not installed + + return patterns diff --git a/src/deepgram/proxy/adapters/fastapi.py b/src/deepgram/proxy/adapters/fastapi.py new file mode 100644 index 00000000..9b5631e6 --- /dev/null +++ b/src/deepgram/proxy/adapters/fastapi.py @@ -0,0 +1,98 @@ +"""FastAPI adapter for the Deepgram proxy.""" + +from typing import TYPE_CHECKING + +from ..errors import ProxyError + +if TYPE_CHECKING: + from ..engine import DeepgramProxy + from fastapi 
import APIRouter + + +def create_deepgram_router(proxy: "DeepgramProxy") -> "APIRouter": + """Create a FastAPI APIRouter that proxies requests to Deepgram. + + Usage:: + + from fastapi import FastAPI + from deepgram.proxy import DeepgramProxy + from deepgram.proxy.adapters.fastapi import create_deepgram_router + + proxy = DeepgramProxy(api_key="dg-xxx") + app = FastAPI() + app.include_router(create_deepgram_router(proxy), prefix="/deepgram") + """ + from fastapi import APIRouter, Request, Response, WebSocket, WebSocketDisconnect + + router = APIRouter() + + @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) + async def proxy_rest(request: Request, path: str) -> Response: + full_path = f"/{path}" + authorization = request.headers.get("authorization") + + try: + scopes = proxy.authenticate(authorization) + proxy.authorize(full_path, scopes) + except ProxyError as exc: + return Response(content=exc.message, status_code=exc.status_code) + + headers = dict(request.headers) + body = await request.body() + query_string = str(request.query_params) + + try: + status, resp_headers, resp_body = await proxy.forward_rest_async( + method=request.method, + path=full_path, + headers=headers, + query_string=query_string, + body=body, + ) + except ProxyError as exc: + return Response(content=exc.message, status_code=exc.status_code) + + return Response(content=resp_body, status_code=status, headers=resp_headers) + + @router.websocket("/{path:path}") + async def proxy_websocket(ws: WebSocket, path: str) -> None: + full_path = f"/{path}" + + # Extract subprotocol from Sec-WebSocket-Protocol header + subprotocol = ws.headers.get("sec-websocket-protocol") + + await ws.accept(subprotocol=subprotocol) + + async def client_receive(): + try: + data = await ws.receive() + if data.get("type") == "websocket.disconnect": + return None + return data.get("text") or data.get("bytes") + except WebSocketDisconnect: + return None + + async def client_send(msg): + if 
isinstance(msg, bytes): + await ws.send_bytes(msg) + else: + await ws.send_text(str(msg)) + + async def client_close(code: int, reason: str = ""): + try: + await ws.close(code=code, reason=reason) + except Exception: + pass + + query_string = str(ws.query_params) + + await proxy.forward_websocket( + path=full_path, + query_string=query_string, + client_receive=client_receive, + client_send=client_send, + client_close=client_close, + subprotocol=subprotocol, + ) + + return router diff --git a/src/deepgram/proxy/adapters/flask.py b/src/deepgram/proxy/adapters/flask.py new file mode 100644 index 00000000..d08dbac6 --- /dev/null +++ b/src/deepgram/proxy/adapters/flask.py @@ -0,0 +1,111 @@ +"""Flask adapter for the Deepgram proxy.""" + +import asyncio +from typing import TYPE_CHECKING + +from ..errors import ProxyError + +if TYPE_CHECKING: + from ..engine import DeepgramProxy + from flask import Blueprint + + +def create_deepgram_blueprint(proxy: "DeepgramProxy") -> "Blueprint": + """Create a Flask Blueprint that proxies requests to Deepgram. + + REST requests use synchronous forwarding. WebSocket support requires + ``flask-sock`` (``pip install flask-sock``). 
+ + Usage:: + + from flask import Flask + from deepgram.proxy import DeepgramProxy + from deepgram.proxy.adapters.flask import create_deepgram_blueprint + + proxy = DeepgramProxy(api_key="dg-xxx") + app = Flask(__name__) + app.register_blueprint(create_deepgram_blueprint(proxy), url_prefix="/deepgram") + """ + from flask import Blueprint, Response, request + + bp = Blueprint("deepgram_proxy", __name__) + + @bp.route("/<path:path>", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) + def proxy_rest(path: str) -> Response: + full_path = f"/{path}" + authorization = request.headers.get("Authorization") + + try: + scopes = proxy.authenticate(authorization) + proxy.authorize(full_path, scopes) + except ProxyError as exc: + return Response(exc.message, status=exc.status_code) + + headers = dict(request.headers) + body = request.get_data() + query_string = request.query_string.decode("utf-8") + + try: + status, resp_headers, resp_body = proxy.forward_rest_sync( + method=request.method, + path=full_path, + headers=headers, + query_string=query_string, + body=body, + ) + except ProxyError as exc: + return Response(exc.message, status=exc.status_code) + + return Response(resp_body, status=status, headers=resp_headers) + + # Optional WebSocket support via flask-sock + try: + from flask_sock import Sock + + sock = Sock() + + @sock.route("/<path:path>", bp=bp) + def proxy_websocket(ws, path: str) -> None: # type: ignore[no-untyped-def] + full_path = f"/{path}" + + # flask-sock doesn't expose subprotocol headers easily; + # read from the underlying environ + subprotocol = ws.environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL") + query_string = ws.environ.get("QUERY_STRING", "") + + loop = asyncio.new_event_loop() + + async def client_receive(): + try: + data = await loop.run_in_executor(None, ws.receive) + return data + except Exception: + return None + + async def client_send(msg): + await loop.run_in_executor(None, ws.send, msg) + + async def client_close(code: int, reason: str = ""): + try: + await 
loop.run_in_executor(None, ws.close, code, reason) + except Exception: + pass + + try: + loop.run_until_complete( + proxy.forward_websocket( + path=full_path, + query_string=query_string, + client_receive=client_receive, + client_send=client_send, + client_close=client_close, + subprotocol=subprotocol, + ) + ) + finally: + loop.close() + + except ImportError: + pass # flask-sock not installed; WS support unavailable + + return bp diff --git a/src/deepgram/proxy/engine.py b/src/deepgram/proxy/engine.py new file mode 100644 index 00000000..119d5f91 --- /dev/null +++ b/src/deepgram/proxy/engine.py @@ -0,0 +1,355 @@ +"""Core proxy engine for forwarding requests to Deepgram.""" + +import asyncio +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import httpx +from .errors import AuthenticationError, AuthorizationError, UpstreamError +from .scopes import Scope, get_target_base_url, path_matches_any_scope + +# Headers that should not be forwarded to upstream +_HOP_BY_HOP = frozenset( + { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + "host", + "authorization", + } +) + +# Default max body size: 200 MB +_DEFAULT_MAX_BODY = 200 * 1024 * 1024 + + +class DeepgramProxy: + """Proxy that authenticates clients and forwards requests to Deepgram. + + Args: + api_key: Deepgram API key. Falls back to ``DEEPGRAM_API_KEY`` env var. + require_auth: If True (default), requests must carry a valid JWT. + production_url: Override the default ``https://api.deepgram.com`` base URL. + agent_url: Override the default ``https://agent.deepgram.com`` base URL. + timeout: HTTP timeout in seconds for upstream requests. + max_body_size: Maximum request body size in bytes. 
+ """ + + def __init__( + self, + api_key: Optional[str] = None, + *, + require_auth: bool = True, + production_url: Optional[str] = None, + agent_url: Optional[str] = None, + timeout: float = 60.0, + max_body_size: int = _DEFAULT_MAX_BODY, + ): + resolved_key = api_key or os.environ.get("DEEPGRAM_API_KEY", "") + if not resolved_key: + raise ValueError("api_key is required (or set DEEPGRAM_API_KEY env var)") + self.api_key: str = resolved_key + + self.require_auth = require_auth + self.production_url = production_url + self.agent_url = agent_url + self.timeout = timeout + self.max_body_size = max_body_size + + self._jwt_manager: Any = None + self._sync_client: Optional[httpx.Client] = None + self._async_client: Optional[httpx.AsyncClient] = None + + @property + def jwt_manager(self) -> Any: + """Lazily initialise the JWT manager (defers PyJWT import).""" + if self._jwt_manager is None: + from .jwt import JWTManager + + self._jwt_manager = JWTManager(self.api_key) + return self._jwt_manager + + # ------------------------------------------------------------------ + # Token helpers + # ------------------------------------------------------------------ + + def create_token(self, scopes: List[Scope], expires_in: int = 3600) -> str: + """Create a signed JWT for client-side use. + + Args: + scopes: Scopes the token grants (e.g. ``[Scope.LISTEN, Scope.SPEAK]``). + expires_in: Token lifetime in seconds. + + Returns: + Encoded JWT string. + """ + return self.jwt_manager.create_token(scopes, expires_in) + + # ------------------------------------------------------------------ + # Auth + # ------------------------------------------------------------------ + + def authenticate(self, authorization: Optional[str]) -> Optional[List[Scope]]: + """Validate a Bearer JWT from an Authorization header. + + Returns: + List of granted scopes, or None if auth is not required and no token + was provided. + + Raises: + AuthenticationError: If a token is required but missing/invalid. 
+ """ + from .jwt import JWTManager + + token = JWTManager.extract_token_from_header(authorization) + + if token is None: + if self.require_auth: + raise AuthenticationError("Missing Authorization header") + return None + + try: + payload = self.jwt_manager.validate_token(token) + except Exception as exc: + raise AuthenticationError(f"Invalid token: {exc}") from exc + + return [Scope(s) for s in payload.scopes if s in Scope._value2member_map_] + + def authorize(self, path: str, scopes: Optional[List[Scope]]) -> None: + """Check that scopes permit accessing *path*. + + When *scopes* is None (unauthenticated, auth not required), access is + allowed to all paths. + + Raises: + AuthorizationError: If the token's scopes don't cover the path. + """ + if scopes is None: + return + if not path_matches_any_scope(path, scopes): + raise AuthorizationError(f"Token scopes {[s.value for s in scopes]} do not permit access to {path}") + + # ------------------------------------------------------------------ + # REST forwarding + # ------------------------------------------------------------------ + + def _prepare_upstream( + self, + method: str, + path: str, + headers: Dict[str, str], + query_string: str = "", + body: bytes = b"", + ) -> Tuple[str, Dict[str, str], str]: + """Build the upstream URL and sanitised headers.""" + base = get_target_base_url(path, self.production_url, self.agent_url) + url = f"{base}{path}" + if query_string: + url = f"{url}?{query_string}" + + out_headers: Dict[str, str] = {} + for k, v in headers.items(): + if k.lower() not in _HOP_BY_HOP: + out_headers[k] = v + out_headers["Authorization"] = f"Token {self.api_key}" + + return url, out_headers, method.upper() + + def forward_rest_sync( + self, + method: str, + path: str, + headers: Dict[str, str], + query_string: str = "", + body: bytes = b"", + ) -> Tuple[int, Dict[str, str], bytes]: + """Synchronously forward an HTTP request to Deepgram. 
+ + Returns: + ``(status_code, response_headers, response_body)`` + """ + url, out_headers, method = self._prepare_upstream(method, path, headers, query_string, body) + + if self._sync_client is None: + self._sync_client = httpx.Client(timeout=self.timeout) + + try: + resp = self._sync_client.request(method, url, headers=out_headers, content=body) + except httpx.ConnectError as exc: + raise UpstreamError("Failed to connect to Deepgram", status_code=502, detail=str(exc)) from exc + except httpx.TimeoutException as exc: + raise UpstreamError("Upstream request timed out", status_code=504, detail=str(exc)) from exc + + resp_headers = dict(resp.headers) + # Remove hop-by-hop from response too + for h in ("transfer-encoding", "connection"): + resp_headers.pop(h, None) + + return resp.status_code, resp_headers, resp.content + + async def forward_rest_async( + self, + method: str, + path: str, + headers: Dict[str, str], + query_string: str = "", + body: bytes = b"", + ) -> Tuple[int, Dict[str, str], bytes]: + """Asynchronously forward an HTTP request to Deepgram. 
+ + Returns: + ``(status_code, response_headers, response_body)`` + """ + url, out_headers, method = self._prepare_upstream(method, path, headers, query_string, body) + + if self._async_client is None: + self._async_client = httpx.AsyncClient(timeout=self.timeout) + + try: + resp = await self._async_client.request(method, url, headers=out_headers, content=body) + except httpx.ConnectError as exc: + raise UpstreamError("Failed to connect to Deepgram", status_code=502, detail=str(exc)) from exc + except httpx.TimeoutException as exc: + raise UpstreamError("Upstream request timed out", status_code=504, detail=str(exc)) from exc + + resp_headers = dict(resp.headers) + for h in ("transfer-encoding", "connection"): + resp_headers.pop(h, None) + + return resp.status_code, resp_headers, resp.content + + # ------------------------------------------------------------------ + # WebSocket forwarding + # ------------------------------------------------------------------ + + async def forward_websocket( + self, + path: str, + query_string: str, + client_receive: Callable, + client_send: Callable, + client_close: Callable, + subprotocol: Optional[str] = None, + ) -> None: + """Bidirectional WebSocket relay between client and Deepgram. + + Auth is handled via the ``subprotocol`` value: + - ``"proxy,<jwt>"`` — validate JWT, connect upstream with ``token,<api_key>`` + - ``"token,<key>"`` / ``"bearer,<token>"`` — passthrough to Deepgram + - None + require_auth → close 4003 + + Args: + path: The API path (e.g. ``/v1/listen``). + query_string: URL query string to forward. + client_receive: Async callable that returns the next message from the client. + client_send: Async callable that sends a message to the client. + client_close: Async callable that closes the client connection, accepts (code, reason). + subprotocol: The ``Sec-WebSocket-Protocol`` value from the handshake. + """ + try: + import websockets + except ImportError: + raise ImportError( + "The 'websockets' package is required for WebSocket proxying. 
" + "Install it with: pip install websockets" + ) + + upstream_subprotocol: Optional[str] = None + scopes: Optional[List[Scope]] = None + + if subprotocol: + if subprotocol.startswith("proxy,"): + token = subprotocol[len("proxy,") :] + try: + payload = self.jwt_manager.validate_token(token) + except Exception: + await client_close(4003, "Invalid token") + return + + scopes = [Scope(s) for s in payload.scopes if s in Scope._value2member_map_] + try: + self.authorize(path, scopes) + except AuthorizationError: + await client_close(4003, "Insufficient scope") + return + + upstream_subprotocol = f"token,{self.api_key}" + + elif subprotocol.startswith(("token,", "bearer,")): + # Direct passthrough — forward the client's subprotocol as-is + upstream_subprotocol = subprotocol + else: + if self.require_auth: + await client_close(4003, "Unrecognised subprotocol") + return + else: + if self.require_auth: + await client_close(4003, "Authentication required") + return + + base = get_target_base_url(path, self.production_url, self.agent_url) + ws_base = base.replace("https://", "wss://").replace("http://", "ws://") + upstream_url = f"{ws_base}{path}" + if query_string: + upstream_url = f"{upstream_url}?{query_string}" + + extra_headers = {"Authorization": f"Token {self.api_key}"} + + connect_kwargs: dict = { + "additional_headers": extra_headers, + } + if upstream_subprotocol: + connect_kwargs["subprotocols"] = [upstream_subprotocol] + + try: + async with websockets.connect(upstream_url, **connect_kwargs) as upstream: + + async def client_to_upstream() -> None: + try: + while True: + msg = await client_receive() + if msg is None: + break + await upstream.send(msg) + except Exception: + pass + + async def upstream_to_client() -> None: + try: + async for msg in upstream: + await client_send(msg) + except Exception: + pass + + tasks = [ + asyncio.create_task(client_to_upstream()), + asyncio.create_task(upstream_to_client()), + ] + # Wait until either side disconnects, then cancel 
the other + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for t in pending: + t.cancel() + + except Exception as exc: + await client_close(1011, f"Proxy error: {exc}") + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def close(self) -> None: + """Close the synchronous HTTP client.""" + if self._sync_client: + self._sync_client.close() + self._sync_client = None + + async def aclose(self) -> None: + """Close the asynchronous HTTP client.""" + if self._async_client: + await self._async_client.aclose() + self._async_client = None diff --git a/src/deepgram/proxy/errors.py b/src/deepgram/proxy/errors.py new file mode 100644 index 00000000..f8f8f535 --- /dev/null +++ b/src/deepgram/proxy/errors.py @@ -0,0 +1,32 @@ +"""Exception classes for the Deepgram proxy.""" + + +class ProxyError(Exception): + """Base exception for proxy errors.""" + + def __init__(self, message: str, status_code: int = 500, detail: str = ""): + self.message = message + self.status_code = status_code + self.detail = detail + super().__init__(message) + + +class AuthenticationError(ProxyError): + """Raised when JWT is missing, invalid, or expired.""" + + def __init__(self, message: str = "Authentication required", detail: str = ""): + super().__init__(message=message, status_code=401, detail=detail) + + +class AuthorizationError(ProxyError): + """Raised when the token's scopes don't permit the requested path.""" + + def __init__(self, message: str = "Insufficient scope", detail: str = ""): + super().__init__(message=message, status_code=403, detail=detail) + + +class UpstreamError(ProxyError): + """Raised when Deepgram returns an error or is unreachable.""" + + def __init__(self, message: str = "Upstream error", status_code: int = 502, detail: str = ""): + super().__init__(message=message, status_code=status_code, detail=detail) diff --git 
a/src/deepgram/proxy/jwt.py b/src/deepgram/proxy/jwt.py new file mode 100644 index 00000000..7eb47c9a --- /dev/null +++ b/src/deepgram/proxy/jwt.py @@ -0,0 +1,86 @@ +"""JWT creation and validation for the Deepgram proxy. + +Uses HMAC-SHA256 with the Deepgram API key as the signing secret. +Requires PyJWT (``pip install PyJWT``). +""" + +import time +import uuid +from typing import List, NamedTuple, Optional + +try: + import jwt as pyjwt +except ImportError: + raise ImportError( + "PyJWT is required for proxy JWT support. " + "Install it with: pip install 'deepgram-sdk[proxy]' or pip install PyJWT" + ) + +from .scopes import Scope + + +class TokenPayload(NamedTuple): + """Decoded JWT payload.""" + + scopes: List[str] + exp: int + iat: int + jti: str + + +class JWTManager: + """Creates and validates HMAC-SHA256 JWTs signed with the Deepgram API key.""" + + def __init__(self, api_key: str): + self._secret = api_key + + def create_token(self, scopes: List[Scope], expires_in: int = 3600) -> str: + """Create a signed JWT with the given scopes. + + Args: + scopes: List of Scope values the token is permitted to use. + expires_in: Token lifetime in seconds (default 3600). + + Returns: + Encoded JWT string. + """ + now = int(time.time()) + payload = { + "scopes": [s.value if isinstance(s, Scope) else s for s in scopes], + "iat": now, + "exp": now + expires_in, + "jti": str(uuid.uuid4()), + } + return pyjwt.encode(payload, self._secret, algorithm="HS256") + + def validate_token(self, token: str) -> TokenPayload: + """Validate a JWT and return its payload. + + Raises: + jwt.ExpiredSignatureError: If the token has expired. + jwt.InvalidTokenError: If the token is malformed or signature is invalid. 
+ """ + data = pyjwt.decode(token, self._secret, algorithms=["HS256"]) + return TokenPayload( + scopes=data.get("scopes", []), + exp=data["exp"], + iat=data["iat"], + jti=data["jti"], + ) + + @staticmethod + def extract_token_from_header(authorization: Optional[str]) -> Optional[str]: + """Extract a bearer token from an Authorization header value. + + Args: + authorization: The full header value, e.g. ``"Bearer <token>"``. + + Returns: + The token string, or None if the header is missing/malformed. + """ + if not authorization: + return None + parts = authorization.split(None, 1) + if len(parts) == 2 and parts[0].lower() == "bearer": + return parts[1] + return None diff --git a/src/deepgram/proxy/scopes.py b/src/deepgram/proxy/scopes.py new file mode 100644 index 00000000..27bc34ad --- /dev/null +++ b/src/deepgram/proxy/scopes.py @@ -0,0 +1,73 @@ +"""Scope definitions and path matching for the Deepgram proxy.""" + +import re +from enum import Enum +from typing import List, Optional + + +class Scope(str, Enum): + """Scopes that can be granted to proxy JWT tokens.""" + + LISTEN = "listen" + SPEAK = "speak" + READ = "read" + AGENT = "agent" + MANAGE = "manage" + SELF_HOSTED = "self_hosted" + + +# Maps each scope to regex patterns that match permitted API paths. +# Patterns use v\d+ to be version-agnostic. 
+SCOPE_PATH_PATTERNS: dict = {  # Scope -> compiled path patterns; "(?:[/?]|$)" ends each on a segment boundary
+    Scope.LISTEN: [
+        re.compile(r"^/v\d+/listen(?:[/?]|$)"),
+    ],
+    Scope.SPEAK: [
+        re.compile(r"^/v\d+/speak(?:[/?]|$)"),
+    ],
+    Scope.READ: [
+        re.compile(r"^/v\d+/read(?:[/?]|$)"),
+    ],
+    Scope.AGENT: [
+        re.compile(r"^/v\d+/agent(?:[/?]|$)"),
+    ],
+    Scope.MANAGE: [
+        re.compile(r"^/v\d+/projects(?:[/?]|$)"),
+        re.compile(r"^/v\d+/keys(?:[/?]|$)"),
+        re.compile(r"^/v\d+/members(?:[/?]|$)"),
+        re.compile(r"^/v\d+/scopes(?:[/?]|$)"),
+        re.compile(r"^/v\d+/invitations(?:[/?]|$)"),
+        re.compile(r"^/v\d+/usage(?:[/?]|$)"),
+        re.compile(r"^/v\d+/billing(?:[/?]|$)"),
+        re.compile(r"^/v\d+/balances(?:[/?]|$)"),
+        re.compile(r"^/v\d+/models(?:[/?]|$)"),
+    ],
+    Scope.SELF_HOSTED: [
+        re.compile(r"^/v\d+/onprem(?:[/?]|$)"),
+        re.compile(r"^/v\d+/selfhosted(?:[/?]|$)"),
+    ],
+}
+
+# Paths routed to agent.deepgram.com instead of api.deepgram.com (whole "agent" segment only, so "/v1/agentX" is not)
+_AGENT_PATH_PATTERN = re.compile(r"^/v\d+/agent(?:[/?]|$)")
+
+
+def path_matches_scope(path: str, scope: Scope) -> bool:
+    """Return True if *path* matches at least one pattern registered for *scope* in SCOPE_PATH_PATTERNS."""
+    patterns = SCOPE_PATH_PATTERNS.get(scope, [])  # a scope with no entry has no patterns -> False
+    return any(p.search(path) for p in patterns)
+
+
+def path_matches_any_scope(path: str, scopes: List[Scope]) -> bool:
+    """Return True if *path* is permitted by at least one of *scopes*; an empty list permits nothing."""
+    return any(path_matches_scope(path, s) for s in scopes)
+
+
+def get_target_base_url(path: str, production_url: Optional[str] = None, agent_url: Optional[str] = None) -> str:
+    """Return the upstream Deepgram base URL for a given path.
+
+    Agent paths use agent_url (default https://agent.deepgram.com); all others use production_url (default https://api.deepgram.com).
+    """
+    if _AGENT_PATH_PATTERN.search(path):
+        return agent_url or "https://agent.deepgram.com"
+    return production_url or "https://api.deepgram.com"
diff --git a/tests/custom/test_proxy_engine.py b/tests/custom/test_proxy_engine.py
new file mode 100644
index 00000000..0031595b
--- /dev/null
+++ b/tests/custom/test_proxy_engine.py
@@ -0,0 +1,167 @@
+"""Tests for the core DeepgramProxy engine."""
+
+import httpx
+import pytest
+
+from deepgram.proxy import DeepgramProxy, Scope
+from deepgram.proxy.errors import AuthenticationError, AuthorizationError, UpstreamError
+
+API_KEY = "test-api-key-for-engine"
+
+
+@pytest.fixture
+def proxy():
+    p = DeepgramProxy(api_key=API_KEY, require_auth=True)
+    yield p
+    p.close()
+
+
+@pytest.fixture
+def proxy_no_auth():
+    p = DeepgramProxy(api_key=API_KEY, require_auth=False)
+    yield p
+    p.close()
+
+
+class TestInit:
+    def test_requires_api_key(self, monkeypatch):
+        """An empty api_key with DEEPGRAM_API_KEY unset raises ValueError.
+
+        monkeypatch removes the variable (tolerating its absence) and puts
+        back whatever value was there before, even an empty one, on teardown.
+        """
+        monkeypatch.delenv("DEEPGRAM_API_KEY", raising=False)
+        with pytest.raises(ValueError, match="api_key is required"):
+            DeepgramProxy(api_key="")
+
+    def test_env_var_fallback(self, monkeypatch):
+        # setenv is undone on teardown, so a DEEPGRAM_API_KEY set by the host survives this test.
+        monkeypatch.setenv("DEEPGRAM_API_KEY", "env-key")
+        p = DeepgramProxy()
+        try:
+            assert p.api_key == "env-key"
+        finally:
+            p.close()  # same cleanup the proxy fixtures above perform
+
+
+class TestCreateToken:
+    def test_returns_string(self, proxy):
+        token = proxy.create_token([Scope.LISTEN])
+        assert isinstance(token, str)
+
+
+class TestAuthenticate:
+    def test_missing_header_required(self, proxy):
+        with pytest.raises(AuthenticationError, match="Missing"):
+            proxy.authenticate(None)
+
+    def test_missing_header_not_required(self, proxy_no_auth):
+        result = proxy_no_auth.authenticate(None)
+        assert result is None
+
+    def test_valid_token(self, proxy):
+        token = proxy.create_token([Scope.LISTEN, Scope.SPEAK])
+        scopes = proxy.authenticate(f"Bearer {token}")
+        assert Scope.LISTEN in scopes
+        assert Scope.SPEAK in scopes
+
+    def test_invalid_token(self, proxy):
+        with pytest.raises(AuthenticationError, match="Invalid token"):
+            proxy.authenticate("Bearer bad.token.here")
+
+
+class TestAuthorize:
+    def test_permits_matching_scope(self, proxy):
+        # Should not raise
+        proxy.authorize("/v1/listen", [Scope.LISTEN])
+
+    def test_rejects_wrong_scope(self, proxy):
+        with pytest.raises(AuthorizationError, match="do not permit"):
+            proxy.authorize("/v1/listen", [Scope.SPEAK])
+
+    def test_none_scopes_allows_all(self, proxy):
+        # None means auth was not required and no token was provided
+        proxy.authorize("/v1/listen", None)
+
+
+class TestForwardRestSync:
+    def test_strips_auth_header_and_injects_api_key(self, proxy):
+        """Verify the proxy replaces Authorization header with its own."""
+        captured = {}
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            captured["auth"] = request.headers.get("authorization")
+            captured["url"] = str(request.url)
+            return httpx.Response(200, content=b'{"ok": true}')
+
+        proxy._sync_client = httpx.Client(transport=httpx.MockTransport(handler))  # upstream is mocked in-process
+
+        status, headers, body = proxy.forward_rest_sync(
+            method="POST",
+            path="/v1/listen",
+            headers={"Authorization": "Bearer client-jwt", "Content-Type": "application/json"},
+            query_string="model=nova-3",
+            body=b'{"url": "https://example.com/audio.wav"}',
+        )
+
+        assert status == 200
+        assert captured["auth"] == f"Token {API_KEY}"
+        assert "api.deepgram.com" in captured["url"]
+        assert "model=nova-3" in captured["url"]
+
+    def test_upstream_error_passthrough(self, proxy):
+        """Upstream 4xx/5xx are returned as-is."""
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            return httpx.Response(400, content=b"Bad Request")
+
+        proxy._sync_client = httpx.Client(transport=httpx.MockTransport(handler))
+
+        status, headers, body = proxy.forward_rest_sync(
+            method="POST", path="/v1/listen", headers={}, body=b"",
+        )
+        assert status == 400
+        assert body == b"Bad Request"
+
+    def test_connect_error_raises_upstream_error(self, proxy):
+        def handler(request: httpx.Request) -> httpx.Response:
+            raise httpx.ConnectError("Connection refused")
+
+        proxy._sync_client = httpx.Client(transport=httpx.MockTransport(handler))
+
+        with pytest.raises(UpstreamError, match="Failed to connect"):
+            proxy.forward_rest_sync(method="GET", path="/v1/listen", headers={})
+
+    def test_agent_path_routes_to_agent_host(self, proxy):
+        captured = {}
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            captured["url"] = str(request.url)
+            return httpx.Response(200, content=b"ok")
+
+        proxy._sync_client = httpx.Client(transport=httpx.MockTransport(handler))
+        proxy.forward_rest_sync(method="POST", path="/v1/agent", headers={})
+
+        assert "agent.deepgram.com" in captured["url"]
+
+
+class TestForwardRestAsync:
+    @pytest.mark.asyncio
+    async def test_async_forward(self, proxy):
+        captured = {}
+
+        async def handler(request: httpx.Request) -> httpx.Response:
+            captured["auth"] = request.headers.get("authorization")
+            return httpx.Response(200, content=b'{"result": "ok"}')
+
+        proxy._async_client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
+
+        status, headers, body = await proxy.forward_rest_async(
+            method="POST",
+            path="/v1/speak",
+            headers={"Authorization": "Bearer jwt"},
+            body=b"Hello world",
+        )
+
+        assert status == 200
+        assert captured["auth"] == f"Token {API_KEY}"
diff --git a/tests/custom/test_proxy_fastapi.py b/tests/custom/test_proxy_fastapi.py
new file mode 100644
index 00000000..071ef7c4
--- /dev/null
+++ b/tests/custom/test_proxy_fastapi.py
@@ -0,0 +1,87 @@
+"""End-to-end tests for the FastAPI proxy adapter."""
+
+import httpx
+import pytest
+
+from deepgram.proxy import DeepgramProxy, Scope
+
+API_KEY = "test-api-key-for-fastapi"
+
+
+@pytest.fixture
+def app():
+    """Create a FastAPI app with the proxy router and mocked upstream."""
+    from fastapi import FastAPI
+
+    from deepgram.proxy.adapters.fastapi import create_deepgram_router
+
+    proxy = DeepgramProxy(api_key=API_KEY, require_auth=True)
+    router = create_deepgram_router(proxy)
+
+    application = FastAPI()
+    application.include_router(router, prefix="/deepgram")
+
+    # Mock the async HTTP client on the proxy
+    def mock_handler(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(
+            200,
+            content=b'{"results": "ok"}',
+            headers={"content-type": "application/json"},
+        )
+
+    proxy._async_client = httpx.AsyncClient(transport=httpx.MockTransport(mock_handler))
+
+    application.state.proxy = proxy
+    return application
+
+
+@pytest.fixture
+def client(app):
+    from starlette.testclient import TestClient
+    return TestClient(app)
+
+
+class TestRESTProxy:
+    def _make_token(self):
+        proxy = DeepgramProxy(api_key=API_KEY)  # same API_KEY as the app fixture's proxy
+        return proxy.create_token([Scope.LISTEN, Scope.SPEAK])
+
+    def test_authenticated_request(self, client):
+        token = self._make_token()
+        resp = client.post(
+            "/deepgram/v1/listen",
+            headers={"Authorization": f"Bearer {token}"},
+            content=b"audio data",
+        )
+        assert resp.status_code == 200
+        assert resp.json() == {"results": "ok"}
+
+    def test_missing_auth(self, client):
+        resp = client.post("/deepgram/v1/listen", content=b"audio data")
+        assert resp.status_code == 401
+
+    def test_invalid_token(self, client):
+        resp = client.post(
+            "/deepgram/v1/listen",
+            headers={"Authorization": "Bearer invalid.jwt.token"},
+            content=b"audio data",
+        )
+        assert resp.status_code == 401
+
+    def test_scope_mismatch(self, client):
+        """Token scoped to LISTEN can't access /v1/agent."""
+        proxy = DeepgramProxy(api_key=API_KEY)
+        token = proxy.create_token([Scope.LISTEN])  # no AGENT scope
+        resp = client.post(
+            "/deepgram/v1/agent",
+            headers={"Authorization": f"Bearer {token}"},
+        )
+        assert resp.status_code == 403
+
+    def test_get_request(self, client):
+        token = self._make_token()
+        resp = client.get(
+            "/deepgram/v1/listen",
+            headers={"Authorization": f"Bearer {token}"},
+        )
+        assert resp.status_code == 200
diff --git
a/tests/custom/test_proxy_jwt.py b/tests/custom/test_proxy_jwt.py
new file mode 100644
index 00000000..2440433f
--- /dev/null
+++ b/tests/custom/test_proxy_jwt.py
@@ -0,0 +1,92 @@
+"""Tests for proxy JWT creation and validation (tokens are cross-checked by decoding them with PyJWT)."""
+
+
+import jwt as pyjwt
+import pytest
+
+from deepgram.proxy.jwt import JWTManager, TokenPayload
+from deepgram.proxy.scopes import Scope
+
+API_KEY = "test-api-key-for-jwt"
+
+
+@pytest.fixture
+def manager():
+    return JWTManager(API_KEY)  # the tests below decode its tokens with this same key
+
+
+class TestCreateToken:  # claims are inspected by decoding the token directly with PyJWT (HS256)
+    def test_returns_string(self, manager):
+        token = manager.create_token([Scope.LISTEN])
+        assert isinstance(token, str)
+
+    def test_contains_scopes(self, manager):
+        token = manager.create_token([Scope.LISTEN, Scope.SPEAK])
+        payload = pyjwt.decode(token, API_KEY, algorithms=["HS256"])
+        assert payload["scopes"] == ["listen", "speak"]  # expected as lowercase strings, in request order
+
+    def test_contains_exp(self, manager):
+        token = manager.create_token([Scope.LISTEN], expires_in=600)
+        payload = pyjwt.decode(token, API_KEY, algorithms=["HS256"])
+        assert payload["exp"] - payload["iat"] == 600  # lifetime (exp - iat) must equal expires_in
+
+    def test_contains_jti(self, manager):
+        token = manager.create_token([Scope.LISTEN])
+        payload = pyjwt.decode(token, API_KEY, algorithms=["HS256"])
+        assert "jti" in payload
+        assert len(payload["jti"]) > 0
+
+    def test_unique_jti(self, manager):
+        t1 = manager.create_token([Scope.LISTEN])
+        t2 = manager.create_token([Scope.LISTEN])
+        p1 = pyjwt.decode(t1, API_KEY, algorithms=["HS256"])
+        p2 = pyjwt.decode(t2, API_KEY, algorithms=["HS256"])
+        assert p1["jti"] != p2["jti"]  # two tokens with identical scopes must still differ in jti
+
+
+class TestValidateToken:  # failures are expected to surface as PyJWT exceptions
+    def test_valid_token(self, manager):
+        token = manager.create_token([Scope.LISTEN])
+        payload = manager.validate_token(token)
+        assert isinstance(payload, TokenPayload)
+        assert payload.scopes == ["listen"]
+
+    def test_expired_token(self, manager):
+        token = manager.create_token([Scope.LISTEN], expires_in=-1)  # negative lifetime: expired on creation
+        with pytest.raises(pyjwt.ExpiredSignatureError):
+            manager.validate_token(token)
+
+    def test_bad_signature(self, manager):
+        other = JWTManager("wrong-key")  # signs with a different key than `manager` was built with
+        token = other.create_token([Scope.LISTEN])
+        with pytest.raises(pyjwt.InvalidSignatureError):
+            manager.validate_token(token)
+
+    def test_malformed_token(self, manager):
+        with pytest.raises(pyjwt.DecodeError):
+            manager.validate_token("not.a.valid.jwt")
+
+    def test_multiple_scopes(self, manager):
+        token = manager.create_token([Scope.LISTEN, Scope.SPEAK, Scope.AGENT])
+        payload = manager.validate_token(token)
+        assert payload.scopes == ["listen", "speak", "agent"]
+
+
+class TestExtractTokenFromHeader:  # called on the class itself; no JWTManager instance is needed
+    def test_bearer_token(self):
+        assert JWTManager.extract_token_from_header("Bearer abc123") == "abc123"
+
+    def test_bearer_lowercase(self):
+        assert JWTManager.extract_token_from_header("bearer abc123") == "abc123"  # lowercase scheme is accepted
+
+    def test_no_header(self):
+        assert JWTManager.extract_token_from_header(None) is None
+
+    def test_empty_header(self):
+        assert JWTManager.extract_token_from_header("") is None
+
+    def test_wrong_scheme(self):
+        assert JWTManager.extract_token_from_header("Token abc123") is None  # "Token" scheme is rejected
+
+    def test_no_space(self):
+        assert JWTManager.extract_token_from_header("Bearerabc123") is None  # scheme and token must be separated
diff --git a/tests/custom/test_proxy_scopes.py b/tests/custom/test_proxy_scopes.py
new file mode 100644
index 00000000..adde91d7
--- /dev/null
+++ b/tests/custom/test_proxy_scopes.py
@@ -0,0 +1,78 @@
+"""Tests for proxy scope definitions, path matching, and upstream base-URL selection."""
+
+
+from deepgram.proxy.scopes import (
+    Scope,
+    get_target_base_url,
+    path_matches_any_scope,
+    path_matches_scope,
+)
+
+
+class TestPathMatchesScope:
+    def test_listen_v1(self):
+        assert path_matches_scope("/v1/listen", Scope.LISTEN)
+
+    def test_listen_v2(self):
+        assert path_matches_scope("/v2/listen", Scope.LISTEN)  # version segment is not pinned to v1
+
+    def test_listen_with_subpath(self):
+        assert path_matches_scope("/v1/listen/stream", Scope.LISTEN)  # sub-paths under /v1/listen are covered too
+
+    def test_speak(self):
+        assert path_matches_scope("/v1/speak", Scope.SPEAK)
+
+    def test_read(self):
+        assert path_matches_scope("/v1/read", Scope.READ)
+
+    def test_agent(self):
+        assert path_matches_scope("/v1/agent", Scope.AGENT)
+
+    def test_manage_projects(self):
+        assert path_matches_scope("/v1/projects", Scope.MANAGE)
+
+    def test_manage_usage(self):
+        assert path_matches_scope("/v1/usage", Scope.MANAGE)
+
+    def test_self_hosted(self):
+        assert path_matches_scope("/v1/onprem", Scope.SELF_HOSTED)
+
+    def test_no_match_wrong_scope(self):
+        assert not path_matches_scope("/v1/listen", Scope.SPEAK)
+
+    def test_no_match_unrecognised_path(self):
+        assert not path_matches_scope("/v1/unknown", Scope.LISTEN)
+
+
+class TestPathMatchesAnyScope:
+    def test_matches_first(self):
+        assert path_matches_any_scope("/v1/listen", [Scope.LISTEN, Scope.SPEAK])
+
+    def test_matches_second(self):
+        assert path_matches_any_scope("/v1/speak", [Scope.LISTEN, Scope.SPEAK])
+
+    def test_no_match(self):
+        assert not path_matches_any_scope("/v1/agent", [Scope.LISTEN, Scope.SPEAK])
+
+    def test_empty_scopes(self):
+        assert not path_matches_any_scope("/v1/listen", [])  # an empty scope list permits nothing
+
+
+class TestGetTargetBaseUrl:
+    def test_default_api(self):
+        assert get_target_base_url("/v1/listen") == "https://api.deepgram.com"
+
+    def test_agent_path(self):
+        assert get_target_base_url("/v1/agent") == "https://agent.deepgram.com"
+
+    def test_agent_subpath(self):
+        assert get_target_base_url("/v1/agent/sessions") == "https://agent.deepgram.com"
+
+    def test_custom_production_url(self):
+        assert get_target_base_url("/v1/listen", production_url="https://custom.api.com") == "https://custom.api.com"
+
+    def test_custom_agent_url(self):
+        assert get_target_base_url("/v1/agent", agent_url="https://custom.agent.com") == "https://custom.agent.com"
+
+    def test_speak_goes_to_api(self):
+        assert get_target_base_url("/v1/speak") == "https://api.deepgram.com"