Refactor SSO authentication to support multiple providers and enhance error handling
This commit is contained in:
parent
862518a45f
commit
c4cef0d9b5
10 changed files with 233 additions and 66 deletions
16
auth/providers/__init__.py
Normal file
16
auth/providers/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from auth.providers.base import SSOProvider
|
||||
from auth.providers.kit import KITProvider
|
||||
|
||||
# Registry of available SSO providers
|
||||
PROVIDERS: dict[str, type[SSOProvider]] = {
|
||||
"kit": KITProvider,
|
||||
}
|
||||
|
||||
|
||||
def get_provider(name: str) -> type[SSOProvider]:
|
||||
"""Get an SSO provider class by name."""
|
||||
provider = PROVIDERS.get(name.lower())
|
||||
if not provider:
|
||||
available = ", ".join(PROVIDERS.keys())
|
||||
raise ValueError(f"Unknown SSO provider: {name}. Available: {available}")
|
||||
return provider
|
||||
35
auth/providers/base.py
Normal file
35
auth/providers/base.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import requests
|
||||
|
||||
|
||||
class SSOProvider(ABC):
|
||||
"""Base class for SSO authentication providers."""
|
||||
|
||||
# Override these in subclasses
|
||||
name: str = "base"
|
||||
domain: str = ""
|
||||
|
||||
def __init__(self, username: str, password: str):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.session: requests.Session = None
|
||||
self.redirect_response: requests.Response = None
|
||||
self.saml_response_html: str = None
|
||||
|
||||
def set_session(self, session: requests.Session):
|
||||
"""Set the shared session from AnnySession."""
|
||||
self.session = session
|
||||
|
||||
def set_redirect_response(self, response: requests.Response):
|
||||
"""Set the redirect response from Anny SSO initiation."""
|
||||
self.redirect_response = response
|
||||
|
||||
@abstractmethod
|
||||
def authenticate(self) -> str:
|
||||
"""
|
||||
Perform institution-specific authentication.
|
||||
|
||||
Returns:
|
||||
The HTML containing the SAML response, or raises an exception on failure.
|
||||
"""
|
||||
pass
|
||||
36
auth/providers/kit.py
Normal file
36
auth/providers/kit.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import html
|
||||
from auth.providers.base import SSOProvider
|
||||
from utils.helpers import extract_html_value
|
||||
|
||||
|
||||
class KITProvider(SSOProvider):
|
||||
"""SSO provider for Karlsruhe Institute of Technology (KIT)."""
|
||||
|
||||
name = "KIT"
|
||||
domain = "kit.edu"
|
||||
|
||||
def authenticate(self) -> str:
|
||||
self.session.headers.pop('x-requested-with', None)
|
||||
self.session.headers.pop('x-inertia', None)
|
||||
self.session.headers.pop('x-inertia-version', None)
|
||||
|
||||
csrf_token = extract_html_value(
|
||||
self.redirect_response.text,
|
||||
r'name="csrf_token" value="([^"]+)"'
|
||||
)
|
||||
|
||||
response = self.session.post(
|
||||
'https://idp.scc.kit.edu/idp/profile/SAML2/Redirect/SSO?execution=e1s1',
|
||||
data={
|
||||
'csrf_token': csrf_token,
|
||||
'j_username': self.username,
|
||||
'j_password': self.password,
|
||||
'_eventId_proceed': '',
|
||||
'fudis_web_authn_assertion_input': '',
|
||||
}
|
||||
)
|
||||
|
||||
if "/consume" not in html.unescape(response.text):
|
||||
raise ValueError("KIT authentication failed - invalid credentials or SSO error")
|
||||
|
||||
return response.text
|
||||
|
|
@ -1,27 +1,38 @@
|
|||
import requests
|
||||
import urllib.parse
|
||||
import re
|
||||
import html
|
||||
from config.constants import AUTH_BASE_URL, ANNY_BASE_URL, DEFAULT_HEADERS
|
||||
from utils.helpers import extract_html_value
|
||||
from auth.providers import get_provider, SSOProvider
|
||||
|
||||
|
||||
class AnnySession:
|
||||
def __init__(self, username, password):
|
||||
def __init__(self, username: str, password: str, provider_name: str = "kit"):
|
||||
self.session = requests.Session()
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
# Initialize the SSO provider
|
||||
provider_class = get_provider(provider_name)
|
||||
self.provider: SSOProvider = provider_class(username, password)
|
||||
|
||||
def login(self):
|
||||
try:
|
||||
self._init_headers()
|
||||
self._sso_login()
|
||||
self._kit_auth()
|
||||
self._provider_auth()
|
||||
self._consume_saml()
|
||||
print("✅ Login successful.")
|
||||
print(f"✅ Login successful via {self.provider.name}.")
|
||||
return self.session.cookies
|
||||
except Exception as e:
|
||||
except requests.RequestException as e:
|
||||
print(f"[Login Error] Network error: {type(e).__name__}")
|
||||
return None
|
||||
except ValueError as e:
|
||||
print(f"[Login Error] {e}")
|
||||
return None
|
||||
except KeyError as e:
|
||||
print(f"[Login Error] Missing expected field: {e}")
|
||||
return None
|
||||
|
||||
def _init_headers(self):
|
||||
self.session.headers.update({
|
||||
|
|
@ -45,32 +56,17 @@ class AnnySession:
|
|||
'x-inertia-version': x_inertia_version
|
||||
})
|
||||
|
||||
r2 = self.session.post(f"{AUTH_BASE_URL}/login/sso", json={"domain": "kit.edu"})
|
||||
r2 = self.session.post(f"{AUTH_BASE_URL}/login/sso", json={"domain": self.provider.domain})
|
||||
redirect_url = r2.headers['x-inertia-location']
|
||||
self.redirect_response = self.session.get(redirect_url)
|
||||
redirect_response = self.session.get(redirect_url)
|
||||
|
||||
def _kit_auth(self):
|
||||
self.session.headers.pop('x-requested-with', None)
|
||||
self.session.headers.pop('x-inertia', None)
|
||||
self.session.headers.pop('x-inertia-version', None)
|
||||
# Pass session and redirect response to provider
|
||||
self.provider.set_session(self.session)
|
||||
self.provider.set_redirect_response(redirect_response)
|
||||
|
||||
csrf_token = extract_html_value(self.redirect_response.text, r'name="csrf_token" value="([^"]+)"')
|
||||
|
||||
r4 = self.session.post(
|
||||
'https://idp.scc.kit.edu/idp/profile/SAML2/Redirect/SSO?execution=e1s1',
|
||||
data={
|
||||
'csrf_token': csrf_token,
|
||||
'j_username': self.username,
|
||||
'j_password': self.password,
|
||||
'_eventId_proceed': '',
|
||||
'fudis_web_authn_assertion_input': '',
|
||||
}
|
||||
)
|
||||
|
||||
if "/consume" not in html.unescape(r4.text):
|
||||
raise Exception("KIT authentication failed")
|
||||
|
||||
self.saml_response_html = r4.text
|
||||
def _provider_auth(self):
|
||||
"""Delegate authentication to the SSO provider."""
|
||||
self.saml_response_html = self.provider.authenticate()
|
||||
|
||||
def _consume_saml(self):
|
||||
consume_url = extract_html_value(self.saml_response_html, r'form action="([^"]+)"')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue