auth.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. Authentication and security utilities.
  3. Provides password hashing, token encryption, and session management.
  4. """
  5. from datetime import datetime, timedelta
  6. from typing import Optional
  7. from fastapi import Depends, HTTPException, status, Request, Response
  8. from sqlalchemy.ext.asyncio import AsyncSession
  9. from sqlalchemy import select
  10. import bcrypt
  11. from cryptography.fernet import Fernet
  12. from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
  13. import base64
  14. import os
  15. from app.models import User
  16. from app.database import get_db
  17. from app.config import get_settings
  18. # Session configuration
  19. SESSION_COOKIE_NAME = "session"
  20. SESSION_MAX_AGE = 60 * 60 * 24 * 30 # 30 days
  21. def get_password_hash(password: str) -> str:
  22. """Hash a password using bcrypt."""
  23. # Bcrypt requires bytes
  24. password_bytes = password.encode('utf-8')
  25. salt = bcrypt.gensalt()
  26. hashed = bcrypt.hashpw(password_bytes, salt)
  27. # Return as string for storage
  28. return hashed.decode('utf-8')
  29. def verify_password(plain_password: str, hashed_password: str) -> bool:
  30. """Verify a password against its hash."""
  31. password_bytes = plain_password.encode('utf-8')
  32. hashed_bytes = hashed_password.encode('utf-8')
  33. return bcrypt.checkpw(password_bytes, hashed_bytes)
  34. def get_fernet_key() -> bytes:
  35. """
  36. Get or generate Fernet encryption key for API tokens.
  37. Uses SECRET_KEY from settings to derive a consistent encryption key.
  38. """
  39. settings = get_settings()
  40. # Derive a 32-byte key from SECRET_KEY
  41. key = base64.urlsafe_b64encode(settings.secret_key.encode().ljust(32)[:32])
  42. return key
  43. def encrypt_token(token: str) -> str:
  44. """Encrypt an API token using Fernet."""
  45. fernet = Fernet(get_fernet_key())
  46. return fernet.encrypt(token.encode()).decode()
  47. def decrypt_token(encrypted_token: str) -> str:
  48. """Decrypt an API token using Fernet."""
  49. fernet = Fernet(get_fernet_key())
  50. return fernet.decrypt(encrypted_token.encode()).decode()
  51. def get_serializer() -> URLSafeTimedSerializer:
  52. """Get session serializer."""
  53. settings = get_settings()
  54. return URLSafeTimedSerializer(settings.secret_key)
  55. def create_session_token(user_id: int) -> str:
  56. """Create a signed session token for a user."""
  57. serializer = get_serializer()
  58. return serializer.dumps({"user_id": user_id})
  59. def verify_session_token(token: str, max_age: int = SESSION_MAX_AGE) -> Optional[int]:
  60. """
  61. Verify a session token and return the user_id.
  62. Returns None if token is invalid or expired.
  63. """
  64. serializer = get_serializer()
  65. try:
  66. data = serializer.loads(token, max_age=max_age)
  67. return data.get("user_id")
  68. except (BadSignature, SignatureExpired):
  69. return None
  70. def set_session_cookie(response: Response, user_id: int):
  71. """Set session cookie on response."""
  72. token = create_session_token(user_id)
  73. response.set_cookie(
  74. key=SESSION_COOKIE_NAME,
  75. value=token,
  76. max_age=SESSION_MAX_AGE,
  77. path="/",
  78. httponly=True,
  79. samesite="lax",
  80. # Set secure=True in production with HTTPS
  81. secure=False
  82. )
  83. def clear_session_cookie(response: Response):
  84. """Clear session cookie."""
  85. response.delete_cookie(key=SESSION_COOKIE_NAME, path="/")
  86. async def get_current_user(
  87. request: Request,
  88. db: AsyncSession = Depends(get_db)
  89. ) -> User:
  90. """
  91. Get the current authenticated user from session cookie.
  92. Raises 401 Unauthorized if not authenticated.
  93. """
  94. # Get session token from cookie
  95. token = request.cookies.get(SESSION_COOKIE_NAME)
  96. if not token:
  97. raise HTTPException(
  98. status_code=status.HTTP_401_UNAUTHORIZED,
  99. detail="Not authenticated"
  100. )
  101. # Verify token and get user_id
  102. user_id = verify_session_token(token)
  103. if user_id is None:
  104. raise HTTPException(
  105. status_code=status.HTTP_401_UNAUTHORIZED,
  106. detail="Invalid or expired session"
  107. )
  108. # Get user from database
  109. result = await db.execute(
  110. select(User).where(User.id == user_id, User.is_active == True)
  111. )
  112. user = result.scalar_one_or_none()
  113. if not user:
  114. raise HTTPException(
  115. status_code=status.HTTP_401_UNAUTHORIZED,
  116. detail="User not found or inactive"
  117. )
  118. return user
  119. async def get_current_user_optional(
  120. request: Request,
  121. db: AsyncSession = Depends(get_db)
  122. ) -> Optional[User]:
  123. """
  124. Get the current authenticated user from session cookie.
  125. Returns None if not authenticated (does not raise exception).
  126. """
  127. try:
  128. return await get_current_user(request, db)
  129. except HTTPException:
  130. return None
  131. async def authenticate_user(
  132. db: AsyncSession,
  133. username: str,
  134. password: str
  135. ) -> Optional[User]:
  136. """
  137. Authenticate a user by username and password.
  138. Returns User if authentication succeeds, None otherwise.
  139. """
  140. # Find user by username
  141. result = await db.execute(
  142. select(User).where(User.username == username)
  143. )
  144. user = result.scalar_one_or_none()
  145. if not user:
  146. return None
  147. # Verify password
  148. if not verify_password(password, user.hashed_password):
  149. return None
  150. # Check if user is active
  151. if not user.is_active:
  152. return None
  153. # Update last login
  154. user.last_login = datetime.now()
  155. await db.commit()
  156. return user
  157. async def create_user(
  158. db: AsyncSession,
  159. username: str,
  160. email: str,
  161. password: str,
  162. abs_url: str,
  163. abs_api_token: str,
  164. display_name: Optional[str] = None
  165. ) -> User:
  166. """
  167. Create a new user account.
  168. Raises HTTPException if username or email already exists.
  169. """
  170. # Check if username already exists
  171. result = await db.execute(
  172. select(User).where(User.username == username)
  173. )
  174. if result.scalar_one_or_none():
  175. raise HTTPException(
  176. status_code=status.HTTP_400_BAD_REQUEST,
  177. detail="Username already registered"
  178. )
  179. # Check if email already exists
  180. result = await db.execute(
  181. select(User).where(User.email == email)
  182. )
  183. if result.scalar_one_or_none():
  184. raise HTTPException(
  185. status_code=status.HTTP_400_BAD_REQUEST,
  186. detail="Email already registered"
  187. )
  188. # Create new user
  189. user = User(
  190. username=username,
  191. email=email,
  192. hashed_password=get_password_hash(password),
  193. abs_url=abs_url,
  194. abs_api_token=encrypt_token(abs_api_token),
  195. display_name=display_name or username,
  196. created_at=datetime.now(),
  197. is_active=True
  198. )
  199. db.add(user)
  200. await db.commit()
  201. await db.refresh(user)
  202. return user