import logging
import secrets
import textwrap
import warnings
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from http import HTTPStatus
from typing import Never, cast, override
from urllib.parse import parse_qs, urlencode, urljoin, urlparse, urlunparse
from uuid import uuid4
import authlib.deprecate
import authlib.oauth2.rfc6749
import authlib.oauth2.rfc6749.errors
import authlib.oauth2.rfc6750
import authlib.oidc.core
import flask
import flask.typing
import joserfc.jwk
import pydantic
import werkzeug.exceptions
import werkzeug.local
from authlib.integrations import flask_oauth2
from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request
from authlib.oauth2 import OAuth2Error, OAuth2Request
from werkzeug.middleware.proxy_fix import ProxyFix
from . import _client
from ._storage import (
AccessToken,
AuthorizationCode,
Client,
ClientAllowAny,
ClientAuthMethod,
RefreshToken,
Storage,
User,
storage,
)
# The logger is named after the package, so this module must be imported as
# part of one.
assert __package__
_logger = logging.getLogger(__package__)

# Advertised as the ID token signing algorithm in the discovery document and
# passed to authlib's JWT config on authlib < 1.7.
_JWS_ALG = "RS256"

# (major, minor) of the installed authlib; ``OpenIDCode`` uses it to decide
# which signing hooks to define.
_authlib_version = tuple(map(int, authlib.__version__.split(".")[:2]))
class TokenValidator(authlib.oauth2.rfc6750.BearerTokenValidator):
    """Bearer token validator backed by the provider's token storage.

    Registered on ``require_oauth`` so protected endpoints such as
    ``/userinfo`` resolve the presented bearer token through ``storage``.
    """

    @override
    def authenticate_token(self, token_string: str):
        """Return the stored access token matching ``token_string``.

        :raises authlib.oauth2.rfc6749.AccessDeniedError: if storage holds no
            access token for ``token_string``.
        """
        access_token = storage.get_access_token(token_string)
        if access_token:
            return access_token
        # NOTE(review): RFC 6750 §3.1 suggests ``invalid_token`` for unknown
        # tokens; presumably authlib reports that if None is returned — verify
        # before changing, callers may rely on ``access_denied``.
        raise authlib.oauth2.rfc6749.AccessDeniedError()
class AuthorizationCodeGrant(authlib.oauth2.rfc6749.AuthorizationCodeGrant):
    """Authorization code grant that keeps issued codes in ``storage``.

    Besides the two client-secret methods, public clients (``none``) may
    authenticate at the token endpoint.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]

    @override
    def query_authorization_code(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, code: str, client: Client
    ) -> AuthorizationCode | None:
        """Return the stored ``code`` only if it was issued to ``client``."""
        stored = storage.get_authorization_code(code)
        if not stored or stored.client_id != client.get_client_id():
            return None
        return stored

    @override
    def delete_authorization_code(self, authorization_code: AuthorizationCode):
        """Remove ``authorization_code`` from storage."""
        storage.remove_authorization_code(authorization_code.code)

    @override
    def authenticate_user(self, authorization_code: AuthorizationCode) -> User | None:
        """Return the stored user the code was issued for, if any."""
        return storage.get_user(authorization_code.user_id)

    @override
    def save_authorization_code(self, code: str, request: object):
        """Persist a newly issued ``code`` together with its request details.

        Stores the user, client, redirect URI, scope and nonce of ``request``.
        """
        assert isinstance(request, OAuth2Request)
        assert isinstance(request.user, User)
        client = cast("Client", request.client)
        with warnings.catch_warnings():
            # The ``OAuth2Request`` properties read below are deprecated in
            # authlib; keep their warnings out of the logs.
            warnings.simplefilter("ignore", authlib.deprecate.AuthlibDeprecationWarning)
            redirect_uri = request.redirect_uri  # pyright: ignore[reportDeprecated]
            assert isinstance(redirect_uri, str)
            storage.store_authorization_code(
                AuthorizationCode(
                    code=code,
                    user_id=request.user.sub,
                    client_id=client.get_client_id(),
                    redirect_uri=redirect_uri,
                    scope=request.scope,  # pyright: ignore[reportDeprecated]
                    nonce=request.data.get("nonce"),  # pyright: ignore[reportDeprecated]
                )
            )
class OpenIDCode(authlib.oidc.core.OpenIDCode):
    """OpenID Connect extension for the authorization code grant.

    Signs ID tokens with the provider key from ``storage``, sets the issuer
    from the request's host URL, and fills user info with the claims allowed
    by the granted scope.

    :param require_nonce: Forwarded to authlib; when true the authorization
        request must include a ``nonce``.
    :param token_max_age: Lifetime of issued ID tokens.
    """

    def __init__(self, require_nonce: bool, token_max_age: timedelta):
        super().__init__(require_nonce)
        self._token_max_age = token_max_age

    @override
    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Return whether ``storage`` already knows ``nonce``."""
        return storage.exists_nonce(nonce)

    # The signing hooks differ between authlib versions: >= 1.7 uses separate
    # key/claims hooks (absolute ``exp``), older versions ``get_jwt_config``
    # (``exp`` as a lifetime in seconds).
    if _authlib_version >= (1, 7):

        def resolve_client_private_key(self, client: object):
            """Return the key set used to sign ID tokens."""
            return joserfc.jwk.KeySet([storage.jwk])

        def get_client_claims(self, client: object):
            """Return the issuer and absolute expiry timestamp for ID tokens."""
            return {
                "iss": flask.request.host_url.rstrip("/"),
                "exp": int((datetime.now(UTC) + self._token_max_age).timestamp()),
            }

    else:

        @override
        def get_jwt_config(  # pyright: ignore[reportIncompatibleMethodOverride]
            self, grant: authlib.oauth2.rfc6749.BaseGrant, client: object = None
        ):
            """Return signing key, algorithm, lifetime and issuer for ID tokens."""
            return {
                "key": storage.jwk.as_dict(is_private=True),
                "alg": _JWS_ALG,
                "exp": int(self._token_max_age.total_seconds()),
                "iss": flask.request.host_url.rstrip("/"),
            }

    @override
    def generate_user_info(self, user: User, scope: str):  # pyright: ignore[reportIncompatibleMethodOverride]
        """Return the claims of ``user`` that ``scope`` permits."""
        return _user_claims_for_scope(user, scope)
class RefreshTokenGrant(authlib.oauth2.rfc6749.RefreshTokenGrant):
    """Refresh token grant backed by ``storage``.

    Public clients (``none``) are accepted alongside the client-secret methods.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]

    @override
    def authenticate_refresh_token(self, refresh_token: str):
        """Return the stored refresh token for ``refresh_token``.

        :raises authlib.oauth2.rfc6749.InvalidGrantError: if it is unknown.
        """
        stored = storage.get_refresh_token(refresh_token)
        if stored:
            return stored
        raise authlib.oauth2.rfc6749.InvalidGrantError("invalid refresh token")

    def authenticate_user(self, refresh_token: RefreshToken):
        """Return the user the refresh token was issued to."""
        return storage.get_user(refresh_token.user_id)

    def revoke_old_credential(self, refresh_token: authlib.oauth2.rfc6749.TokenMixin):
        """Remove the access token that was issued together with ``refresh_token``.

        The refresh token itself stays in storage.
        """
        assert isinstance(refresh_token, RefreshToken)
        storage.remove_access_token(refresh_token.access_token)
def _user_claims_for_scope(user: User, scope: str) -> dict[str, object]:
    """Return the claims of ``user`` that may be released for ``scope``.

    ``scope`` is split on single spaces. A standard OIDC claim (one listed in
    ``_SCOPES_TO_CLAIMS``) is included only if a requested scope grants it;
    any other claim is always included. ``sub`` is always ``user.sub``,
    overriding a ``sub`` entry in ``user.claims``.
    """
    granted: set[str] = set()
    for scope_name in scope.split(" "):
        granted.update(_SCOPES_TO_CLAIMS.get(scope_name, []))

    claims: dict[str, object] = {}
    for name, value in user.claims.items():
        if name in _STANDARD_CLAIMS and name not in granted:
            continue
        claims[name] = value
    claims["sub"] = user.sub
    return claims
# https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims
_SCOPES_TO_CLAIMS: dict[str, Sequence[str]] = {
"profile": [
"name",
"family_name",
"given_name",
"middle_name",
"nickname",
"preferred_username",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"updated_at",
],
"email": ["email", "email_verified"],
"address": ["address"],
"phone": ["phone_number", "phone_number_verified"],
}
_STANDARD_CLAIMS = {claim for claims in _SCOPES_TO_CLAIMS.values() for claim in claims}
# Guards resource endpoints (e.g. ``/userinfo``) with bearer tokens.
require_oauth = flask_oauth2.ResourceProtector()


def _current_authorization_server():
    # ``setup`` creates one authorization server per app and puts it on
    # ``flask.g`` before each request.
    return flask.g._authlib_authorization_server


# Request-scoped proxy to the current app's authorization server.
authorization = cast(
    "flask_oauth2.AuthorizationServer",
    werkzeug.local.LocalProxy(_current_authorization_server),
)

blueprint = flask.Blueprint("oidc-provider-mock", __name__)
@blueprint.after_request
def add_cors_headers(response: flask.Response) -> flask.Response:
    """Add permissive CORS headers to every response except ``authorize``'s."""
    authorize_endpoint = f"{blueprint.name}.{authorize.__name__}"
    if flask.request.endpoint != authorize_endpoint:
        # ``*`` in Allow-Headers does not cover ``Authorization``, so it is
        # listed explicitly.
        cors_headers = {
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*, Authorization",
            "Access-Control-Allow-Methods": "GET, POST, PUT, OPTIONS",
        }
        for header, value in cors_headers.items():
            response.headers[header] = value
    return response
@dataclass(kw_only=True, frozen=True)
class Config:
    """Provider settings handed to the blueprint via ``register_blueprint``.

    Mirrors the keyword arguments of ``init_app``; see its docstring for
    details on each field.
    """

    # If false, unknown client IDs are accepted with a permissive client.
    require_client_registration: bool = False
    # Forwarded to ``OpenIDCode``: require a ``nonce`` on authorization requests.
    require_nonce: bool = False
    # Whether the token endpoint also issues refresh tokens.
    issue_refresh_token: bool = True
    # Lifetime of issued access tokens and ID tokens.
    access_token_max_age: timedelta = timedelta(hours=1)
    # Users stored at setup and offered on the authorization form.
    user_claims: Sequence[User] = ()
@blueprint.record
def setup(setup_state: flask.blueprints.BlueprintSetupState):
    """Create the per-app provider state when the blueprint is registered.

    Configures authlib's token settings on the app, creates a fresh
    ``Storage`` and ``AuthorizationServer``, stores the predefined users,
    exposes storage, config and server on ``flask.g`` before every request,
    and registers client lookup, token persistence and the supported grants.
    """
    assert isinstance(setup_state.app, flask.Flask)
    config = setup_state.options["config"]
    assert isinstance(config, Config)

    access_token_seconds = int(config.access_token_max_age.total_seconds())
    # authlib looks up the access token lifetime by grant type and falls back
    # to its own default for grant types not listed here. The refresh grant is
    # listed too, so refreshed access tokens also honor ``access_token_max_age``.
    setup_state.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
        "authorization_code": access_token_seconds,
        "refresh_token": access_token_seconds,
    }
    # True enables authlib's default refresh token generator.
    setup_state.app.config["OAUTH2_REFRESH_TOKEN_GENERATOR"] = (
        config.issue_refresh_token
    )

    authorization = flask_oauth2.AuthorizationServer()
    storage = Storage()
    for user in config.user_claims:
        storage.store_user(user)

    @setup_state.app.before_request
    def set_globals():
        # Registered on the app, so it runs for every request. It lets the
        # module-level ``authorization`` proxy (and presumably the imported
        # ``storage`` proxy — verify in ``_storage``) resolve to this app's objects.
        flask.g.oidc_provider_mock_storage = storage
        flask.g.oidc_provider_mock_config = config
        flask.g._authlib_authorization_server = authorization

    def query_client(id: str) -> Client | None:
        """Return the client with ``id``.

        If registration is not required, an unknown ID gets a permissive
        client (``ClientAllowAny`` for secret, redirect URIs and auth method).
        """
        client = storage.get_client(id)
        if not client and not config.require_client_registration:
            client = Client(
                id=id,
                secret=ClientAllowAny(),
                redirect_uris=ClientAllowAny(),
                allowed_scopes=Client.SCOPES_SUPPORTED,
                token_endpoint_auth_method=ClientAllowAny(),
            )
        return client

    def save_token(token: dict[str, object], request: OAuth2Request):
        """Store the access token in ``token`` and its refresh token, if any."""
        assert token["token_type"] == "Bearer"
        assert isinstance(token["access_token"], str)
        assert isinstance(token["expires_in"], int)
        assert isinstance(request.user, User)
        # request.scope may actually be None, so take the scope from the token.
        scope = token.get("scope", "")
        assert isinstance(scope, str)
        # The refresh token is stored with the same expiry as the access token.
        expires_at = datetime.now(UTC) + timedelta(seconds=token["expires_in"])
        storage.store_access_token(
            AccessToken(
                token=token["access_token"],
                user_id=request.user.sub,
                scope=scope,
                expires_at=expires_at,
            )
        )

        if "refresh_token" in token:
            assert isinstance(token["refresh_token"], str)
            assert isinstance(request.client, Client)
            storage.store_refresh_token(
                RefreshToken(
                    access_token=token["access_token"],
                    token=token["refresh_token"],
                    user_id=request.user.sub,
                    scope=scope,
                    expires_at=expires_at,
                    client_id=request.client.id,
                )
            )

    authorization.init_app(  # type: ignore
        setup_state.app,
        query_client=query_client,
        save_token=save_token,
    )
    authorization.register_grant(
        AuthorizationCodeGrant,
        [
            OpenIDCode(
                require_nonce=config.require_nonce,
                token_max_age=config.access_token_max_age,
            )
        ],
    )
    authorization.register_grant(RefreshTokenGrant)
@blueprint.record_once
def setup_once(setup_state: flask.blueprints.BlueprintSetupState):
    """Register the bearer token validator on the shared ``require_oauth``.

    Runs only on the blueprint's first registration with an app
    (``record_once``).
    """
    validator = TokenValidator()
    require_oauth.register_token_validator(validator)  # pyright: ignore[reportUnknownMemberType]
def app(
    *,
    require_client_registration: bool = False,
    require_nonce: bool = False,
    issue_refresh_token: bool = True,
    access_token_max_age: timedelta = timedelta(hours=1),
    user_claims: Sequence[User] = (),
) -> flask.Flask:
    """Create a Flask app running the OpenID provider.

    Call ``app().run()`` (see `flask.Flask.run`) to start the server.

    See ``init_app`` for documentation of parameters.
    """
    flask_app = flask.Flask(__name__)
    init_app(
        flask_app,
        require_client_registration=require_client_registration,
        require_nonce=require_nonce,
        issue_refresh_token=issue_refresh_token,
        access_token_max_age=access_token_max_age,
        user_claims=user_claims,
    )
    # A fresh random key per created app; it is not persisted anywhere.
    flask_app.secret_key = secrets.token_bytes(16)
    if isinstance(flask_app.json, flask.json.provider.DefaultJSONProvider):
        # Make it easier to debug responses
        flask_app.json.compact = False
    return flask_app
def init_app(
    app: flask.Flask,
    *,
    require_client_registration: bool = False,
    require_nonce: bool = False,
    issue_refresh_token: bool = True,
    access_token_max_age: timedelta = timedelta(hours=1),
    user_claims: Sequence[User] = (),
):
    """Add the OpenID provider and its endpoints to the flask ``app``.

    :param app: The Flask application to add the provider to. It is modified
        in place and also returned.
    :param require_client_registration: If false (the default) any client ID and
        secret can be used to authenticate with the token endpoint. If true,
        clients have to be registered using the `OAuth 2.0 Dynamic Client
        Registration Protocol <https://datatracker.ietf.org/doc/html/rfc7591>`_.
    :param require_nonce: If true, the authorization request must include the
        `nonce parameter`_ to prevent replay attacks. If the parameter is not
        provided the authorization request will fail.
    :param issue_refresh_token: If true (the default), the token endpoint response
        will include a refresh token.
    :param access_token_max_age: Max age of access and ID tokens after which
        they expire.
    :param user_claims: Predefined users that can be authorized with one click.
    :returns: ``app``, with debug mode enabled and its WSGI app wrapped in
        ``ProxyFix``.

    .. _nonce parameter: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
    """
    app.register_blueprint(
        blueprint,
        config=Config(
            require_client_registration=require_client_registration,
            require_nonce=require_nonce,
            issue_refresh_token=issue_refresh_token,
            access_token_max_age=access_token_max_age,
            user_claims=user_claims,
        ),
    )
    app.register_blueprint(_client.blueprint)
    # NOTE(review): this enables debug mode on the caller's app as well.
    app.debug = True
    # Trust one proxy hop's X-Forwarded-Host/Proto/Port so ``host_url`` (used
    # for the issuer and endpoint URLs) reflects the public address.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_host=1, x_proto=1, x_port=1)
    return app
@blueprint.get("/")
def home():
    """Render the provider's start page (``index.html``)."""
    template_name = "index.html"
    return flask.render_template(template_name)
@blueprint.get("/.well-known/openid-configuration")
def openid_config():
    """Serve the OpenID Provider discovery document.

    All endpoint URLs are absolute, based on the request's host URL. See
    https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
    for information about the fields.
    """
    base_url = flask.request.host_url

    def absolute_url(view: Callable[..., object]) -> str:
        # Resolve the view within this blueprint, then make it absolute.
        return urljoin(base_url, flask.url_for(f".{view.__name__}"))

    metadata = {
        "issuer": base_url.rstrip("/"),
        "authorization_endpoint": absolute_url(authorize),
        "token_endpoint": absolute_url(issue_token),
        "userinfo_endpoint": absolute_url(userinfo),
        "registration_endpoint": absolute_url(register_client),
        "end_session_endpoint": absolute_url(end_session),
        "jwks_uri": absolute_url(jwks),
        "response_types_supported": Client.RESPONSE_TYPES_SUPPORTED,
        "response_modes_supported": ["query"],
        "grant_types_supported": Client.GRANT_TYPES_SUPPORTED,
        "scopes_supported": Client.SCOPES_SUPPORTED,
        "id_token_signing_alg_values_supported": [_JWS_ALG],
        "subject_types_supported": ["public"],
    }
    return flask.jsonify(metadata)
@blueprint.get("/jwks")
def jwks():
    """Serve the provider's signing key as a JWK Set, without private parts."""
    key_set = joserfc.jwk.KeySet([storage.jwk])
    public_keys = key_set.as_dict(private=False)
    return flask.jsonify(public_keys)
class RegisterClientBody(pydantic.BaseModel):
    """Request body accepted by ``register_client`` (a subset of RFC 7591 metadata)."""

    # Redirect URIs the client may use; each is validated as an HTTP(S) URL.
    redirect_uris: Sequence[pydantic.HttpUrl]
    # ``none`` registers a public client that receives no secret.
    token_endpoint_auth_method: ClientAuthMethod = "client_secret_basic"
    # Presumably a space-separated scope list (RFC 7591); when omitted the
    # client is allowed all of ``Client.SCOPES_SUPPORTED``.
    scope: str | None = None
@blueprint.post("/oauth2/clients")
def register_client():
    """Register a client (OAuth 2.0 Dynamic Client Registration, RFC 7591).

    Confidential clients get a random secret that never expires. RFC 7591
    §3.2.1 requires ``client_secret_expires_at`` whenever a secret is issued,
    with 0 meaning "never". Public clients (``token_endpoint_auth_method`` of
    ``none``) get no secret.

    :returns: The registered client's metadata with status 201 Created.
    """
    body = _validate_body(flask.request, RegisterClientBody)
    is_public = body.token_endpoint_auth_method == "none"
    secret = "" if is_public else secrets.token_urlsafe(16)
    client = Client(
        id=str(uuid4()),
        secret=secret,
        redirect_uris=[str(uri) for uri in body.redirect_uris],
        # NOTE(review): ``body.scope`` is a string while SCOPES_SUPPORTED is
        # presumably a list — confirm ``Client`` handles both shapes.
        allowed_scopes=body.scope or Client.SCOPES_SUPPORTED,
        token_endpoint_auth_method=body.token_endpoint_auth_method,
    )
    storage.store_client(client)

    response: dict[str, object] = {
        "client_id": client.id,
        "redirect_uris": client.redirect_uris,
        "token_endpoint_auth_method": body.token_endpoint_auth_method,
        "grant_types": Client.GRANT_TYPES_SUPPORTED,
        "response_types": Client.RESPONSE_TYPES_SUPPORTED,
    }
    if not is_public:
        response["client_secret"] = secret
        # Required by RFC 7591 when a secret is issued; 0 = never expires.
        response["client_secret_expires_at"] = 0
    return flask.jsonify(response), HTTPStatus.CREATED
@blueprint.route("/oauth2/authorize", methods=["GET", "POST"])
def authorize() -> flask.typing.ResponseReturnValue:
    """Authorization endpoint.

    ``GET`` validates the client's request and renders a form for choosing
    the subject to authorize as. ``POST`` handles that form. On "deny" it
    redirects to the client with an ``access_denied`` error. Otherwise it
    creates the user if unknown and responds with an authorization code.
    """
    request = FlaskOAuth2Request(flask.request)
    try:
        grant, redirect_uri = _validate_auth_request_client_params(flask.request)
        assert isinstance(grant.client, Client)  # pyright: ignore[reportUnknownMemberType]
    except _AuthorizationValidationException as exc:
        _logger.warning(f"invalid authorization request: {exc.description}")
        raise

    config = flask.g.oidc_provider_mock_config
    assert isinstance(config, Config)
    predefined_users = [user.sub for user in config.user_claims]
    recent_subjects = [
        sub for sub in storage.get_recent_subjects() if sub not in predefined_users
    ]
    scopes = flask.request.args.get("scope", "").split()

    if flask.request.method == "GET":
        return flask.render_template(
            "authorization_form.html",
            redirect_uri=redirect_uri,
            client_id=grant.client.id,
            scopes=scopes,
            recent_subjects=recent_subjects,
            predefined_users=predefined_users,
        )

    if flask.request.form.get("action") == "deny":
        # Use the redirect URI returned by validation: the raw ``redirect_uri``
        # query parameter may be absent, which used to raise KeyError. Also
        # echo ``state``, as RFC 6749 §4.1.2.1 requires for error responses.
        denied = authlib.oauth2.rfc6749.AccessDeniedError(
            redirect_uri=redirect_uri,
            state=flask.request.args.get("state"),
        )
        return authorization.handle_response(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            *denied()  # pyright: ignore[reportUnknownArgumentType]
        )

    sub = flask.request.form.get("sub")
    if not sub:
        raise _AuthorizationValidationException(
            authlib.oauth2.rfc6749.InvalidRequestError.error,
            "Missing 'sub' form parameter",
        )

    user = storage.get_user(sub)
    if not user:
        # Unknown subjects are created on the fly with their sub as email.
        user = User(sub=sub, claims={"email": sub})
        storage.store_user(user)

    try:
        response = grant.create_authorization_response(redirect_uri, user)  # pyright: ignore
        _logger.info(
            "issued authorization code",
            extra={"client": grant.client, "user": user},
        )
        storage.record_subject(sub)
        return authorization.handle_response(*response)  # pyright: ignore
    except authlib.oauth2.OAuth2Error as error:
        _logger.warning("invalid authorization request", exc_info=True)
        return authorization.handle_error_response(request, error)  # pyright: ignore
def _validate_auth_request_client_params(
    flask_request: flask.Request,
) -> tuple[AuthorizationCodeGrant | RefreshTokenGrant, str]:
    """Validate query parameters sent by the client to the authorization endpoint.

    Returns the consent grant and the redirect URI resolved by authlib.

    Raises ``_AuthorizationValidationException`` (an HTML 400 page) if the
    client is invalid, the response type is unsupported, or the redirect URI
    is not allowed for the client. Other OAuth errors are raised as a
    ``werkzeug.exceptions.HTTPException`` carrying authlib's error response.
    """
    request = FlaskOAuth2Request(flask_request)

    def error_response(error: OAuth2Error) -> werkzeug.exceptions.HTTPException:
        # Let authlib render the error and abort the request with it.
        return werkzeug.exceptions.HTTPException(
            response=flask.make_response(
                authorization.handle_error_response(request, error)
            )
        )

    try:
        grant = authorization.get_consent_grant()  # type: ignore
        assert isinstance(grant, AuthorizationCodeGrant)
        redirect_uri = grant.validate_authorization_request()
    except authlib.oauth2.rfc6749.InvalidClientError as e:
        raise _AuthorizationValidationException(
            authlib.oauth2.rfc6749.InvalidClientError.error,
            e.description,
        ) from e
    except authlib.oauth2.rfc6749.UnsupportedResponseTypeError as e:
        raise _AuthorizationValidationException(
            e.error,
            f"OAuth response_type {e.response_type} is not supported",
        ) from e
    except authlib.oauth2.rfc6749.InvalidRequestError as e:
        description = e.description or ""
        # FIXME: this is a brittle way of determining what the error is but
        # authlib does not raise a dedicated error in this case. Its message is
        # "Redirect URI {uri} is not supported by client." for any URI; the
        # previous exact comparison only matched the URI "foo".
        if description.startswith("Redirect URI ") and description.endswith(
            " is not supported by client."
        ):
            raise _AuthorizationValidationException(
                authlib.oauth2.rfc6749.InvalidClientError.error,
                description,
            ) from e
        raise error_response(e) from e
    except authlib.oauth2.OAuth2Error as e:
        raise error_response(e) from e
    return grant, redirect_uri
class _AuthorizationValidationException(werkzeug.exceptions.HTTPException):
    """Authorization request error rendered as an HTML error page (400).

    :param error: OAuth 2 error code, e.g. ``invalid_client``.
    :param description: Human-readable explanation shown on the page.
    """

    error: str

    def __init__(self, error: str, description: str):
        self.error = error
        self.description = description
        page = flask.render_template("error.html", name=error, description=description)
        super().__init__(response=flask.make_response(page, HTTPStatus.BAD_REQUEST))
        # ``HTTPException.code`` defaults to None; record the status explicitly.
        self.code = HTTPStatus.BAD_REQUEST
@blueprint.post("/oauth2/token")
def issue_token() -> flask.typing.ResponseReturnValue:
    """Token endpoint for the authorization code and refresh token grants.

    OAuth errors are logged and answered with authlib's error response.
    """
    request = FlaskOAuth2Request(flask.request)

    try:
        grant = authorization.get_token_grant(request)
    except authlib.oauth2.rfc6749.UnsupportedGrantTypeError as error:
        _logger.warning(
            "unsupported grant type for issuing token",
            extra={"grant_type": error.grant_type},
        )
        return authorization.handle_error_response(request, error)  # type: ignore

    assert isinstance(grant, AuthorizationCodeGrant | RefreshTokenGrant)

    try:
        grant.validate_token_request()
        return authorization.handle_response(*grant.create_token_response())  # type: ignore
    except OAuth2Error as error:
        if not error.error:
            _logger.warning("error while issuing token", exc_info=error)
        else:
            _logger.warning(
                f"token endpoint error {error.error}",
                extra={"description": error.description},
            )
        return authorization.handle_error_response(request, error)  # type: ignore
@blueprint.route("/userinfo", methods=["GET", "POST"])
@require_oauth()  # pyright: ignore[reportUntypedFunctionDecorator]
def userinfo():
    """Return the claims of the token's user that the token's scope allows.

    Requires a bearer access token (enforced by ``require_oauth``).
    """
    token = flask_oauth2.current_token
    assert isinstance(token, AccessToken)
    claims = _user_claims_for_scope(token.get_user(), token.scope)
    return flask.jsonify(claims)
# Request body for ``set_user``: a JSON object mapping claim names to values.
SetUserBody = pydantic.RootModel[dict[str, object]]
@blueprint.put("/users/<sub>")
def set_user(sub: str):
    """Store user ``sub`` with the claims from the JSON body.

    Presumably replaces an existing user with the same ``sub`` (depends on
    ``storage.store_user``).

    :returns: An empty 204 No Content response.
    """
    claims = _validate_body(flask.request, SetUserBody).root
    storage.store_user(User(sub=sub, claims=claims))
    return "", HTTPStatus.NO_CONTENT
@blueprint.post("/users/<sub>/revoke-tokens")
def revoke_user_tokens(sub: str):
    """Remove every access and refresh token issued to user ``sub``.

    :returns: An empty 204 No Content response.
    """
    # Collect matches before removing anything, so storage is never mutated
    # while one of its token iterators is still being consumed (a live dict
    # view would raise "changed size during iteration").
    access_tokens = [t.token for t in storage.access_tokens() if t.user_id == sub]
    for token in access_tokens:
        storage.remove_access_token(token)

    refresh_tokens = [t.token for t in storage.refresh_tokens() if t.user_id == sub]
    for token in refresh_tokens:
        storage.remove_refresh_token(token)

    return "", HTTPStatus.NO_CONTENT
@blueprint.route("/oauth2/end_session", methods=["GET", "POST"])
def end_session() -> flask.typing.ResponseReturnValue:
    """RP-initiated logout endpoint; renders a logout confirmation form.

    When both ``post_logout_redirect_uri`` and ``state`` are given, ``state``
    is set in the redirect URI's query string.
    See https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
    """
    params = flask.request.values
    id_token_hint = params.get("id_token_hint")
    post_logout_redirect_uri = params.get("post_logout_redirect_uri")
    state = params.get("state")
    # Not handled: client_id, logout_hint and ui_locales

    redirect_uri = post_logout_redirect_uri
    if post_logout_redirect_uri is not None and state is not None:
        parsed = urlparse(post_logout_redirect_uri)
        query = parse_qs(parsed.query, keep_blank_values=True)
        query["state"] = [state]
        redirect_uri = urlunparse(parsed._replace(query=urlencode(query, doseq=True)))

    return flask.render_template(
        "end_session_form.html",
        id_token_hint=id_token_hint,
        redirect_uri=redirect_uri,
        request_parameters=params,
        end_session_confirm_url=flask.url_for(f".{end_session_confirm.__name__}"),
    )
@blueprint.route("/oauth2/end_session/confirm", methods=["POST"])
def end_session_confirm() -> flask.typing.ResponseReturnValue:
    """Finish logout by redirecting to the submitted ``redirect_uri``.

    Without a ``redirect_uri`` form field, render a "session ended" page.
    """
    redirect_uri = flask.request.form.get("redirect_uri")
    if redirect_uri is None:
        return flask.render_template("end_session_confirm.html", session_ended=True)
    # NOTE(review): the target is not checked against registered URIs.
    return flask.redirect(redirect_uri)
class InsecureTransportError(Exception):
    """Raised in place of authlib's ``InsecureTransportError``.

    The message explains how to allow plain HTTP for local testing.
    """

    def __init__(self):
        # The trailing space matters: the two literals are concatenated.
        super().__init__(
            "OAuth 2 requires https. Set the environment variable "
            "`AUTHLIB_INSECURE_TRANSPORT=1` to disable this check"
        )
@blueprint.errorhandler(authlib.oauth2.rfc6749.errors.InsecureTransportError)
def _insecure_transport_error_handler(
    error: authlib.oauth2.rfc6749.errors.InsecureTransportError,
) -> Never:
    """Re-raise authlib's transport error as this module's ``InsecureTransportError``.

    The original error is kept as the ``__cause__``.
    """
    raise InsecureTransportError() from error
def _validate_body[Model: pydantic.BaseModel](
request: flask.Request, model: type[Model]
) -> Model:
try:
return model.model_validate(request.json, strict=True)
except pydantic.ValidationError as error:
_logger.info(
f"invalid request body {request.method} {request.url}\n{textwrap.indent(str(error), ' ')}",
extra={
"_msg": "invalid request body",
"method": request.method,
"url": request.url,
"error": error,
},
)
# TODO: support content type negotiation with html and json
msg = "Invalid body:\n"
for detail in error.errors():
loc = detail.get("loc")
if loc:
msg += f"- {_pydantic_loc_to_path(loc)}:"
msg += f" {detail.get('msg')}\n"
raise werkzeug.exceptions.HTTPException(
response=flask.make_response(
msg,
HTTPStatus.BAD_REQUEST,
{"content-type": "text/plain; charset=utf-8"},
)
) from error
def _pydantic_loc_to_path(loc: tuple[str | int, ...]) -> str:
path = ""
for i, x in enumerate(loc):
match x:
case str():
if i > 0:
path += "."
path += x
case int():
path += f"[{x}]"
return path