Refactor SSO authentication to support multiple providers and enhance error handling
This commit is contained in:
parent
862518a45f
commit
c4cef0d9b5
10 changed files with 233 additions and 66 deletions
16
auth/providers/__init__.py
Normal file
16
auth/providers/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from auth.providers.base import SSOProvider
|
||||
from auth.providers.kit import KITProvider
|
||||
|
||||
# Registry of available SSO providers
|
||||
PROVIDERS: dict[str, type[SSOProvider]] = {
|
||||
"kit": KITProvider,
|
||||
}
|
||||
|
||||
|
||||
def get_provider(name: str) -> type[SSOProvider]:
|
||||
"""Get an SSO provider class by name."""
|
||||
provider = PROVIDERS.get(name.lower())
|
||||
if not provider:
|
||||
available = ", ".join(PROVIDERS.keys())
|
||||
raise ValueError(f"Unknown SSO provider: {name}. Available: {available}")
|
||||
return provider
|
||||
35
auth/providers/base.py
Normal file
35
auth/providers/base.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import requests
|
||||
|
||||
|
||||
class SSOProvider(ABC):
|
||||
"""Base class for SSO authentication providers."""
|
||||
|
||||
# Override these in subclasses
|
||||
name: str = "base"
|
||||
domain: str = ""
|
||||
|
||||
def __init__(self, username: str, password: str):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.session: requests.Session = None
|
||||
self.redirect_response: requests.Response = None
|
||||
self.saml_response_html: str = None
|
||||
|
||||
def set_session(self, session: requests.Session):
|
||||
"""Set the shared session from AnnySession."""
|
||||
self.session = session
|
||||
|
||||
def set_redirect_response(self, response: requests.Response):
|
||||
"""Set the redirect response from Anny SSO initiation."""
|
||||
self.redirect_response = response
|
||||
|
||||
@abstractmethod
|
||||
def authenticate(self) -> str:
|
||||
"""
|
||||
Perform institution-specific authentication.
|
||||
|
||||
Returns:
|
||||
The HTML containing the SAML response, or raises an exception on failure.
|
||||
"""
|
||||
pass
|
||||
36
auth/providers/kit.py
Normal file
36
auth/providers/kit.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import html
|
||||
from auth.providers.base import SSOProvider
|
||||
from utils.helpers import extract_html_value
|
||||
|
||||
|
||||
class KITProvider(SSOProvider):
|
||||
"""SSO provider for Karlsruhe Institute of Technology (KIT)."""
|
||||
|
||||
name = "KIT"
|
||||
domain = "kit.edu"
|
||||
|
||||
def authenticate(self) -> str:
|
||||
self.session.headers.pop('x-requested-with', None)
|
||||
self.session.headers.pop('x-inertia', None)
|
||||
self.session.headers.pop('x-inertia-version', None)
|
||||
|
||||
csrf_token = extract_html_value(
|
||||
self.redirect_response.text,
|
||||
r'name="csrf_token" value="([^"]+)"'
|
||||
)
|
||||
|
||||
response = self.session.post(
|
||||
'https://idp.scc.kit.edu/idp/profile/SAML2/Redirect/SSO?execution=e1s1',
|
||||
data={
|
||||
'csrf_token': csrf_token,
|
||||
'j_username': self.username,
|
||||
'j_password': self.password,
|
||||
'_eventId_proceed': '',
|
||||
'fudis_web_authn_assertion_input': '',
|
||||
}
|
||||
)
|
||||
|
||||
if "/consume" not in html.unescape(response.text):
|
||||
raise ValueError("KIT authentication failed - invalid credentials or SSO error")
|
||||
|
||||
return response.text
|
||||
Loading…
Add table
Add a link
Reference in a new issue