auth.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """
  2. Authentication and security utilities.
  3. Provides password hashing, token encryption, and session management.
  4. """
  5. from datetime import datetime, timedelta
  6. from typing import Optional
  7. from fastapi import Depends, HTTPException, status, Request, Response
  8. from sqlalchemy.ext.asyncio import AsyncSession
  9. from sqlalchemy import select
  10. from passlib.hash import bcrypt
  11. from cryptography.fernet import Fernet
  12. from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
  13. import base64
  14. import os
  15. from app.models import User
  16. from app.database import get_db
  17. from app.config import get_settings
  # Session cookie configuration: cookie name and lifetime in seconds.
  SESSION_COOKIE_NAME = "session"
  SESSION_MAX_AGE = 30 * 24 * 60 * 60  # 30 days, expressed in seconds
  def get_password_hash(password: str) -> str:
      """
      Return the bcrypt hash of ``password`` for storage.

      NOTE(review): bcrypt is commonly limited to the first 72 bytes of input;
      confirm how the installed passlib version treats longer passwords.
      """
      hashed = bcrypt.hash(password)
      return hashed
  def verify_password(plain_password: str, hashed_password: str) -> bool:
      """Return True when ``plain_password`` matches the bcrypt ``hashed_password``."""
      is_match = bcrypt.verify(plain_password, hashed_password)
      return is_match
  def get_fernet_key() -> bytes:
      """
      Build the Fernet key used to encrypt stored API tokens.

      The key is derived deterministically from ``settings.secret_key``:
      the UTF-8 bytes of the secret are cut to at most 32 bytes, right-padded
      with ASCII spaces to exactly 32 bytes, then URL-safe base64 encoded
      (Fernet takes a URL-safe base64-encoded 32-byte key). The same secret
      therefore always yields the same key.

      NOTE(review): this is not a real key-derivation function. Bytes past
      position 32 of the secret are ignored, and short secrets are padded with
      predictable spaces. Changing the derivation would make previously
      encrypted tokens undecryptable, so any migration must keep the old key
      available for decryption (e.g. via ``MultiFernet``).
      """
      secret_bytes = get_settings().secret_key.encode()
      # Truncating before padding is equivalent to padding before truncating.
      raw_key = secret_bytes[:32].ljust(32, b" ")
      return base64.urlsafe_b64encode(raw_key)
  def encrypt_token(token: str) -> str:
      """Encrypt an API token with Fernet and return the ciphertext as text."""
      cipher = Fernet(get_fernet_key())
      ciphertext = cipher.encrypt(token.encode())
      return ciphertext.decode()
  def decrypt_token(encrypted_token: str) -> str:
      """
      Decrypt a Fernet-encrypted API token back to plain text.

      Errors from ``Fernet.decrypt`` are not caught here. For example,
      cryptography's ``InvalidToken`` is raised for tampered input or input
      encrypted under a different key.
      """
      cipher = Fernet(get_fernet_key())
      plaintext = cipher.decrypt(encrypted_token.encode())
      return plaintext.decode()
  def get_serializer() -> URLSafeTimedSerializer:
      """Return a timestamped, URL-safe serializer keyed on ``settings.secret_key``.

      Session tokens are signed and verified with this serializer.
      """
      secret = get_settings().secret_key
      return URLSafeTimedSerializer(secret)
  def create_session_token(user_id: int) -> str:
      """Return a signed, timestamped token whose payload is ``{"user_id": user_id}``."""
      payload = {"user_id": user_id}
      return get_serializer().dumps(payload)
  def verify_session_token(token: str, max_age: int = SESSION_MAX_AGE) -> Optional[int]:
      """
      Verify a session token and return the user_id it carries.

      Args:
          token: Signed token, as produced by ``create_session_token``.
          max_age: Maximum accepted token age in seconds.

      Returns:
          The integer user_id, or None if the token is invalid, expired,
          undecodable, or does not carry an integer ``user_id``.
      """
      # BadData is the common base of BadSignature, SignatureExpired and
      # BadPayload, so every itsdangerous rejection maps to "no session".
      from itsdangerous import BadData

      try:
          data = get_serializer().loads(token, max_age=max_age)
      except BadData:
          return None
      # Accept only the payload shape create_session_token produces. Any other
      # value signed with the same secret must not crash the request (a
      # non-dict has no .get) or reach the DB query as a non-integer id.
      if not isinstance(data, dict):
          return None
      user_id = data.get("user_id")
      # bool is a subclass of int; reject it explicitly.
      if isinstance(user_id, bool) or not isinstance(user_id, int):
          return None
      return user_id
  def set_session_cookie(response: Response, user_id: int, *, secure: bool = False):
      """
      Attach a signed session cookie for ``user_id`` to ``response``.

      Args:
          response: Outgoing response to set the cookie on.
          user_id: Id embedded in the signed session token.
          secure: When True, the cookie is flagged ``Secure`` so browsers send
              it only over HTTPS. Defaults to False to keep the previous
              behavior. Pass True in deployments served over HTTPS.
      """
      token = create_session_token(user_id)
      response.set_cookie(
          key=SESSION_COOKIE_NAME,
          value=token,
          max_age=SESSION_MAX_AGE,
          httponly=True,  # keep the cookie out of reach of page JavaScript
          samesite="lax",
          secure=secure,
      )
  def clear_session_cookie(response: Response):
      """Tell the client to delete the session cookie (logs the user out)."""
      response.delete_cookie(SESSION_COOKIE_NAME)
  async def get_current_user(
      request: Request,
      db: AsyncSession = Depends(get_db)
  ) -> User:
      """
      Resolve the logged-in, active user from the session cookie.

      Raises:
          HTTPException: 401 when the cookie is missing, the token fails
              verification, or no active user matches the token's user_id.
      """
      def _unauthorized(detail: str) -> HTTPException:
          # All failure modes share the same status; only the detail differs.
          return HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=detail,
          )

      token = request.cookies.get(SESSION_COOKIE_NAME)
      if not token:
          raise _unauthorized("Not authenticated")

      user_id = verify_session_token(token)
      if user_id is None:
          raise _unauthorized("Invalid or expired session")

      query = select(User).where(
          User.id == user_id,
          User.is_active == True,  # noqa: E712 - builds a SQL comparison
      )
      user = (await db.execute(query)).scalar_one_or_none()
      if not user:
          raise _unauthorized("User not found or inactive")
      return user
  async def get_current_user_optional(
      request: Request,
      db: AsyncSession = Depends(get_db)
  ) -> Optional[User]:
      """
      Same lookup as ``get_current_user``, but any HTTPException it raises is
      swallowed and None is returned instead. Use this for routes that work
      for both anonymous and logged-in visitors.
      """
      try:
          user = await get_current_user(request, db)
      except HTTPException:
          return None
      return user
  # Lazily computed bcrypt hash. It lets authenticate_user spend the same
  # verification time on unknown usernames as on known ones.
  _DUMMY_PASSWORD_HASH: Optional[str] = None


  def _get_dummy_password_hash() -> str:
      """Return a throwaway bcrypt hash, computing it on first use."""
      global _DUMMY_PASSWORD_HASH
      if _DUMMY_PASSWORD_HASH is None:
          _DUMMY_PASSWORD_HASH = get_password_hash("timing-equalization-placeholder")
      return _DUMMY_PASSWORD_HASH


  async def authenticate_user(
      db: AsyncSession,
      username: str,
      password: str
  ) -> Optional[User]:
      """
      Authenticate a user by username and password.

      On success the user's ``last_login`` is set to ``datetime.now()`` (naive
      local time) and the session is committed.

      Args:
          db: Database session.
          username: Username to look up.
          password: Plain-text password to check.

      Returns:
          The User on success. None if the username is unknown, the password
          is wrong, or the account is inactive.
      """
      result = await db.execute(
          select(User).where(User.username == username)
      )
      user = result.scalar_one_or_none()
      if not user:
          # Still run a bcrypt verification so an unknown username takes about
          # as long as a wrong password. Otherwise response timing reveals
          # which usernames exist.
          verify_password(password, _get_dummy_password_hash())
          return None
      if not verify_password(password, user.hashed_password):
          return None
      if not user.is_active:
          return None
      user.last_login = datetime.now()
      await db.commit()
      return user
  async def create_user(
      db: AsyncSession,
      username: str,
      email: str,
      password: str,
      abs_url: str,
      abs_api_token: str,
      display_name: Optional[str] = None
  ) -> User:
      """
      Create and persist a new, active user account.

      The password is stored hashed (``get_password_hash``) and
      ``abs_api_token`` is stored encrypted (``encrypt_token``).
      ``display_name`` falls back to ``username`` when it is None or empty.

      Raises:
          HTTPException: 400 if the username or email is already registered.
              This includes a concurrent insert of the same value between the
              existence checks and the commit.
      """
      from sqlalchemy.exc import IntegrityError

      # Check if username already exists
      result = await db.execute(
          select(User).where(User.username == username)
      )
      if result.scalar_one_or_none():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Username already registered"
          )
      # Check if email already exists
      result = await db.execute(
          select(User).where(User.email == email)
      )
      if result.scalar_one_or_none():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Email already registered"
          )
      user = User(
          username=username,
          email=email,
          hashed_password=get_password_hash(password),
          abs_url=abs_url,
          abs_api_token=encrypt_token(abs_api_token),
          display_name=display_name or username,
          created_at=datetime.now(),
          is_active=True
      )
      db.add(user)
      try:
          await db.commit()
      except IntegrityError as exc:
          # The checks above are not atomic with the insert. A concurrent
          # registration can claim the username/email in between; if the
          # table enforces uniqueness, the commit fails here. Report it as the
          # same client error instead of an unhandled 500.
          # NOTE(review): this also catches other integrity violations; confirm
          # that uniqueness is the only constraint a valid insert can trip.
          await db.rollback()
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Username or email already registered"
          ) from exc
      await db.refresh(user)
      return user