Source code for oidc_provider_mock._storage

from collections import deque
from collections.abc import Collection, Iterable, Sequence
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import ClassVar, Literal, cast, override

import authlib.oauth2.rfc6749
import authlib.oidc.core
import flask
import joserfc.jwk
import werkzeug.local


class ClientAllowAny:
    """Special value for client fields that skips validation of the field."""

    def __repr__(self) -> str:
        return str(type(self).__name__)


type ClientAuthMethod = Literal["none", "client_secret_basic", "client_secret_post"]


@dataclass(kw_only=True, frozen=True)
class Client(authlib.oauth2.rfc6749.ClientMixin):
    id: str
    secret: str | ClientAllowAny
    redirect_uris: Sequence[str] | ClientAllowAny
    allowed_scopes: Sequence[str]
    token_endpoint_auth_method: ClientAuthMethod | ClientAllowAny

    """Wrap ``Client`` to implement authlib’s client protocol."""

    RESPONSE_TYPES_SUPPORTED: ClassVar[tuple[str, ...]] = ("code",)
    GRANT_TYPES_SUPPORTED: ClassVar[tuple[str, ...]] = (
        "authorization_code",
        "refresh_token",
    )
    SCOPES_SUPPORTED: ClassVar[tuple[str, ...]] = (
        "openid",
        "profile",
        "email",
        "address",
        "phone",
    )

    @override
    def get_client_id(self):
        return self.id

    @override
    def get_default_redirect_uri(self) -> str | None:  # pyright: ignore[reportIncompatibleMethodOverride]
        if isinstance(self.redirect_uris, ClientAllowAny):
            return None

        return self.redirect_uris[0]

    @override
    def get_allowed_scope(self, scope: Collection[str] | str | None) -> str:
        if scope is None:
            scopes = ()
        elif isinstance(scope, str):
            scopes = scope.split()
        else:
            scopes = scope
        return " ".join(s for s in scopes if s in self.allowed_scopes)

    @override
    def check_redirect_uri(self, redirect_uri: str) -> bool:
        if isinstance(self.redirect_uris, ClientAllowAny):
            return True

        return redirect_uri in self.redirect_uris

    @override
    def check_client_secret(self, client_secret: str) -> bool:
        if isinstance(self.secret, ClientAllowAny):
            return True

        return client_secret == self.secret

    @override
    def check_endpoint_auth_method(self, method: str, endpoint: object):
        if isinstance(self.token_endpoint_auth_method, ClientAllowAny):
            return True

        return method == self.token_endpoint_auth_method

    @override
    def check_grant_type(self, grant_type: str):
        return grant_type in self.GRANT_TYPES_SUPPORTED

    @override
    def check_response_type(self, response_type: str):
        return response_type in self.RESPONSE_TYPES_SUPPORTED


[docs] @dataclass(kw_only=True, frozen=True) class User: #: Identifier ("subject") for the user sub: str #: Additional claims to be included in the ID token and ``user_info`` endpoint #: response. claims: dict[str, object] = field(default_factory=dict[str, object])
@dataclass(kw_only=True, frozen=True) class AuthorizationCode(authlib.oidc.core.AuthorizationCodeMixin): code: str client_id: str redirect_uri: str user_id: str scope: str nonce: str | None # Implement AuthorizationCodeMixin @override def get_redirect_uri(self): return self.redirect_uri @override def get_scope(self): return self.scope @override def get_nonce(self) -> str | None: return self.nonce @override def get_auth_time(self) -> int | None: return None @dataclass(kw_only=True, frozen=True) class AccessToken(authlib.oauth2.rfc6749.TokenMixin): token: str user_id: str scope: str expires_at: datetime def get_user(self) -> User: user = storage.get_user(self.user_id) if user is None: raise RuntimeError(f"Missing user {self.user_id} for access toke") return user # Implement `TokenMixin` @override def check_client(self, client: Client) -> bool: # Required only for revocation and refresh token endpoints. raise NotImplementedError() @override def is_expired(self): return datetime.now(UTC) >= self.expires_at @override def is_revoked(self): return False @override def get_scope(self) -> str: return self.scope @dataclass(kw_only=True, frozen=True) class RefreshToken(AccessToken): client_id: str access_token: str @override def check_client(self, client: Client) -> bool: return self.client_id == client.id class Storage: jwk: joserfc.jwk.RSAKey _clients: dict[str, Client] _users: dict[str, User] _authorization_codes: dict[str, AuthorizationCode] _access_tokens: dict[str, AccessToken] _refresh_tokens: dict[str, RefreshToken] _nonces: set[str] _recent_subjects: deque[str] def __init__(self) -> None: self.jwk = joserfc.jwk.RSAKey.generate_key(private=True) self._clients = {} self._users = {} self._authorization_codes = {} self._access_tokens = {} self._refresh_tokens = {} self._nonces = set() self._recent_subjects = deque() # User def get_user(self, sub: str) -> User | None: return self._users.get(sub) def store_user(self, user: User): self._users[user.sub] = user def get_recent_subjects(self) -> Sequence[str]: """Get a sequence of the 20 most recently recorded subjects, starting with the most recent one. """ return self._recent_subjects def record_subject(self, sub: str) -> None: try: self._recent_subjects.remove(sub) except ValueError: pass self._recent_subjects.appendleft(sub) if len(self._recent_subjects) > 20: self._recent_subjects.pop() # AuthorizationCodes def get_authorization_code(self, code: str) -> AuthorizationCode | None: return self._authorization_codes.get(code) def store_authorization_code(self, code: AuthorizationCode): self._authorization_codes[code.code] = code def remove_authorization_code(self, code: str) -> AuthorizationCode | None: return self._authorization_codes.pop(code, None) # AccessTokens def get_access_token(self, token: str) -> AccessToken | None: return self._access_tokens.get(token) def store_access_token(self, access_token: AccessToken): self._access_tokens[access_token.token] = access_token def remove_access_token(self, access_token: str) -> AccessToken | None: return self._access_tokens.pop(access_token, None) def access_tokens(self) -> Iterable[AccessToken]: return list(self._access_tokens.values()) # RefreshTokens def get_refresh_token(self, token: str) -> RefreshToken | None: return self._refresh_tokens.get(token) def store_refresh_token(self, refresh_token: RefreshToken): self._refresh_tokens[refresh_token.token] = refresh_token def remove_refresh_token(self, token: str) -> RefreshToken | None: return self._refresh_tokens.pop(token, None) def refresh_tokens(self) -> Iterable[RefreshToken]: return list(self._refresh_tokens.values()) # Client def get_client(self, id: str) -> Client | None: return self._clients.get(id) def store_client(self, client: Client): self._clients[client.id] = client # Nonce def add_nonce(self, nonce: str): self._nonces.add(nonce) def exists_nonce(self, nonce: str) -> bool: return nonce in self._nonces storage = cast( "Storage", werkzeug.local.LocalProxy(lambda: flask.g.oidc_provider_mock_storage) )