"""
Authentication and security utilities.

Provides password hashing, token encryption, and session management.
"""
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Optional

from fastapi import Depends, HTTPException, status, Request, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from passlib.hash import bcrypt
from cryptography.fernet import Fernet
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import base64
import os

from app.models import User
from app.database import get_db
from app.config import get_settings


# Session configuration
SESSION_COOKIE_NAME = "session"
SESSION_MAX_AGE = 60 * 60 * 24 * 30  # 30 days


def get_password_hash(password: str) -> str:
    """Hash a password using bcrypt."""
    return bcrypt.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a plaintext password against its bcrypt hash."""
    return bcrypt.verify(plain_password, hashed_password)


@lru_cache(maxsize=1)
def _dummy_password_hash() -> str:
    """
    Return a throwaway bcrypt hash, computed once on first use.

    authenticate_user verifies against this hash when the username does not
    exist, so that lookups for unknown users cost roughly the same bcrypt
    time as lookups for known users and response timing does not reveal
    which usernames are registered.
    """
    return get_password_hash("timing-equalisation-dummy-password")


def get_fernet_key() -> bytes:
    """
    Get the Fernet encryption key used for API tokens.

    The key is derived deterministically from ``settings.secret_key``: its
    UTF-8 bytes are space-padded / truncated to exactly 32 bytes and then
    URL-safe base64 encoded, which is the key format Fernet expects.

    NOTE(review): this is padding/truncation, not a real key-derivation
    function -- only the first 32 bytes of the secret contribute, and a
    short secret yields a low-entropy key. It is intentionally left
    unchanged because tokens already stored in the database were encrypted
    with this key; moving to a hash-based key would need a rotation step
    (e.g. cryptography's MultiFernet) to keep old ciphertexts readable.
    The same secret is also used to sign session cookies.
    """
    settings = get_settings()
    raw_key = settings.secret_key.encode().ljust(32)[:32]
    return base64.urlsafe_b64encode(raw_key)


def encrypt_token(token: str) -> str:
    """Encrypt an API token using Fernet and return the ciphertext as text."""
    fernet = Fernet(get_fernet_key())
    return fernet.encrypt(token.encode()).decode()


def decrypt_token(encrypted_token: str) -> str:
    """
    Decrypt an API token previously produced by encrypt_token.

    Raises cryptography.fernet.InvalidToken if the ciphertext was tampered
    with or was encrypted under a different key.
    """
    fernet = Fernet(get_fernet_key())
    return fernet.decrypt(encrypted_token.encode()).decode()


def get_serializer() -> URLSafeTimedSerializer:
    """Get the timestamped serializer used to sign session tokens."""
    settings = get_settings()
    return URLSafeTimedSerializer(settings.secret_key)


def create_session_token(user_id: int) -> str:
    """Create a signed, timestamped session token for a user."""
    serializer = get_serializer()
    return serializer.dumps({"user_id": user_id})


def verify_session_token(token: str, max_age: int = SESSION_MAX_AGE) -> Optional[int]:
    """
    Verify a session token and return the user_id it carries.

    Returns None if the token is invalid, expired, or has a payload that is
    not the ``{"user_id": <int>}`` shape written by create_session_token.
    """
    serializer = get_serializer()
    try:
        data = serializer.loads(token, max_age=max_age)
    except (BadSignature, SignatureExpired):
        return None

    # A correctly signed but foreign payload (anything else serialized with
    # the same secret and default salt) must not crash the request with an
    # AttributeError or feed a non-integer id into the user query.
    if not isinstance(data, dict):
        return None
    user_id = data.get("user_id")
    if not isinstance(user_id, int) or isinstance(user_id, bool):
        return None
    return user_id


def set_session_cookie(response: Response, user_id: int, *, secure: bool = False):
    """
    Set the session cookie on a response.

    Args:
        response: Response to attach the cookie to.
        user_id: Id of the authenticated user to encode in the token.
        secure: Mark the cookie Secure (HTTPS-only). Defaults to False to
            preserve the previous behavior; pass True in production over HTTPS.
    """
    token = create_session_token(user_id)
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        max_age=SESSION_MAX_AGE,
        httponly=True,  # not readable from JavaScript
        samesite="lax",
        secure=secure,
    )


def clear_session_cookie(response: Response):
    """Clear the session cookie on a response."""
    response.delete_cookie(key=SESSION_COOKIE_NAME)


async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Get the current authenticated user from the session cookie.

    Raises:
        HTTPException(401): if the cookie is missing, the token is invalid or
            expired, or the user does not exist or is inactive.
    """
    # Get session token from cookie
    token = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated"
        )

    # Verify token and get user_id
    user_id = verify_session_token(token)
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired session"
        )

    # Get user from database; `== True` builds a SQL expression, so it must
    # not be replaced with a Python `is True` check.
    result = await db.execute(
        select(User).where(User.id == user_id, User.is_active == True)  # noqa: E712
    )
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive"
        )

    return user


async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Get the current authenticated user from the session cookie.

    Returns None instead of raising when the request is not authenticated.
    """
    try:
        return await get_current_user(request, db)
    except HTTPException:
        return None


async def authenticate_user(
    db: AsyncSession,
    username: str,
    password: str
) -> Optional[User]:
    """
    Authenticate a user by username and password.

    Returns the User on success (and records last_login), None otherwise.

    NOTE(review): bcrypt verification is CPU-bound and runs on the event
    loop thread here; consider offloading it to a thread pool under load.
    """
    # Find user by username
    result = await db.execute(
        select(User).where(User.username == username)
    )
    user = result.scalar_one_or_none()

    if user is None:
        # Spend the same bcrypt cost as a real check so response timing does
        # not reveal whether the username exists.
        verify_password(password, _dummy_password_hash())
        return None

    # Verify password
    if not verify_password(password, user.hashed_password):
        return None

    # Check if user is active (after the password check, so an inactive
    # account is only distinguishable with valid credentials)
    if not user.is_active:
        return None

    # Update last login.
    # NOTE(review): naive local time, kept as-is; confirm the column expects
    # naive datetimes before switching to timezone-aware values.
    user.last_login = datetime.now()
    await db.commit()

    return user


async def create_user(
    db: AsyncSession,
    username: str,
    email: str,
    password: str,
    abs_url: str,
    abs_api_token: str,
    display_name: Optional[str] = None
) -> User:
    """
    Create a new user account.

    The password is stored as a bcrypt hash and the API token is stored
    Fernet-encrypted.

    Raises:
        HTTPException(400): if the username or email is already registered,
            including when a concurrent registration wins the race between
            the existence checks below and the insert.
    """
    # Check if username already exists
    result = await db.execute(
        select(User).where(User.username == username)
    )
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    # Check if email already exists
    result = await db.execute(
        select(User).where(User.email == email)
    )
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Create new user
    user = User(
        username=username,
        email=email,
        hashed_password=get_password_hash(password),
        abs_url=abs_url,
        abs_api_token=encrypt_token(abs_api_token),
        display_name=display_name or username,
        created_at=datetime.now(),
        is_active=True
    )

    db.add(user)
    try:
        await db.commit()
    except IntegrityError:
        # The checks above are not atomic with the insert; a concurrent
        # registration can still trip a unique constraint at commit time.
        # Report it as the same client error instead of a 500.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username or email already registered"
        ) from None
    await db.refresh(user)

    return user