Source code for fmn.api.auth
# SPDX-FileCopyrightText: Contributors to the Fedora Project
#
# SPDX-License-Identifier: MIT
import logging
import time
from typing import Any
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient
from pydantic import BaseModel
from ..core.config import get_settings
log = logging.getLogger(__name__)
[docs]class TokenExpired(ValueError):
pass
[docs]class Identity(BaseModel):
_client = None
_token_to_identities_cache = {}
_cache_next_gc_after = None
name: str
admin: bool
expires_at: float
user_info: dict[str, Any]
[docs] class Config:
extra = "ignore"
[docs] @classmethod
def client(cls) -> AsyncClient:
settings = get_settings()
if not cls._client:
cls._client = AsyncClient(
base_url=settings.oidc_provider_url,
timeout=None,
)
return cls._client
@classmethod
def _cache_collect_garbage(cls, force: bool = False) -> None:
id_cache_gc_interval = get_settings().id_cache_gc_interval
now = time.time()
then = now + id_cache_gc_interval
if not force:
if not cls._cache_next_gc_after:
cls._cache_next_gc_after = then
return
if now < cls._cache_next_gc_after:
return
cls._token_to_identities_cache = {
k: v for k, v in cls._token_to_identities_cache.items() if v.expires_at > now
}
cls._cache_next_gc_after = then
[docs] @classmethod
async def from_oidc_token(cls, token: str) -> "Identity":
identity = cls._token_to_identities_cache.get(token)
if not identity:
settings = get_settings()
token_info_response = await cls.client().post(
settings.oidc_token_info_url,
data={
"token": token,
"client_id": settings.oidc_client_id,
"client_secret": settings.oidc_client_secret,
},
)
token_info_response.raise_for_status()
token_info_result = token_info_response.json()
user_info_response = await cls.client().post(
settings.oidc_user_info_url, data={"access_token": token}
)
user_info_response.raise_for_status()
user_info_result = user_info_response.json()
identity = cls(
name=token_info_result["username"],
admin=any(
g in get_settings().admin_groups for g in user_info_result.get("groups", [])
),
expires_at=float(token_info_result["exp"]),
user_info=user_info_result,
)
if identity.expires_at < time.time():
cls._cache_collect_garbage(force=True)
raise TokenExpired(token)
else:
cls._cache_collect_garbage()
cls._token_to_identities_cache[token] = identity
return identity
[docs]class IdentityFactory:
def __init__(self, optional=False):
self.optional = optional
[docs] async def process_oidc_auth(
self, creds: HTTPAuthorizationCredentials | None
) -> Identity | None:
if not creds:
return None
return await Identity.from_oidc_token(creds.credentials)
async def __call__(
self, bearer: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
) -> Identity | None:
try:
identity = await self.process_oidc_auth(bearer)
except TokenExpired as exc:
if self.optional:
return None
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Token expired") from exc
if identity is None and not self.optional:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
return identity
get_identity = IdentityFactory(optional=False)
get_identity_optional = IdentityFactory(optional=True)
[docs]async def get_identity_admin(identity: Identity = Depends(get_identity)):
if not identity.admin:
raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Not an admin")
return identity