79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from app.core.database import get_db
|
|
from app.core.security import (
|
|
hash_password, verify_password,
|
|
create_access_token, create_refresh_token, decode_token,
|
|
)
|
|
from app.models.user import User
|
|
from app.schemas.auth import (
|
|
RegisterRequest, LoginRequest, TokenResponse,
|
|
RefreshRequest, UserOut,
|
|
)
|
|
|
|
# Router collecting the authentication endpoints; every route below is
# registered under the "/auth" prefix and tagged "auth".
router = APIRouter(
    prefix="/auth",
    tags=["auth"],
)
|
|
|
|
|
|
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(payload: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    Rejects the request when the requested username or email already belongs
    to an existing user; otherwise stores a new ``User`` with a hashed
    password and returns it.

    Args:
        payload: Registration data; ``username``, ``email`` and ``password``
            are read from it.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The newly created ``User`` (serialized through ``UserOut``).

    Raises:
        HTTPException: 400 if the username or the email is already in use.
    """
    # The username and the email may belong to two *different* existing users,
    # so this lookup can match more than one row. ``scalar_one_or_none()``
    # raises ``MultipleResultsFound`` in that case; we only need to know
    # whether any match exists, so fetch at most one row.
    result = await db.execute(
        select(User)
        .where((User.username == payload.username) | (User.email == payload.email))
        .limit(1)
    )
    if result.scalars().first() is not None:
        raise HTTPException(status_code=400, detail="用户名或邮箱已被使用")

    user = User(
        username=payload.username,
        email=payload.email,
        hashed_password=hash_password(payload.password),
    )
    db.add(user)
    # Flush to emit the INSERT, then refresh so database-generated column
    # values are loaded onto the instance before it is returned.
    # NOTE(review): no commit happens here — presumably ``get_db`` commits at
    # the end of the request; confirm.
    await db.flush()
    await db.refresh(user)
    return user
|
|
|
|
|
|
@router.post("/login", response_model=TokenResponse)
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate a user by username and password and issue tokens.

    Args:
        payload: Login data; ``username`` and ``password`` are read from it.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A ``TokenResponse`` carrying a fresh access token and refresh token
        for the authenticated user.

    Raises:
        HTTPException: 401 if no user has that username or the password does
            not verify; 403 if the matched user is not active.
    """
    query = select(User).where(User.username == payload.username)
    account = (await db.execute(query)).scalar_one_or_none()

    # An unknown username and a wrong password produce the same response.
    if not (account and verify_password(payload.password, account.hashed_password)):
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    # The active check runs only after the credentials have been verified.
    if not account.is_active:
        raise HTTPException(status_code=403, detail="账号已被禁用")

    access = create_access_token(account.id)
    renewal = create_refresh_token(account.id)
    return TokenResponse(access_token=access, refresh_token=renewal)
|
|
|
|
|
|
@router.post("/refresh", response_model=TokenResponse)
async def refresh(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for a new access/refresh token pair.

    Args:
        payload: Request data; ``refresh_token`` is read from it.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A ``TokenResponse`` with newly created access and refresh tokens for
        the user identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if ``decode_token`` returns ``None``, the token's
            ``type`` is not ``"refresh"``, its ``sub`` claim is missing or not
            an integer, or the referenced user does not exist or is inactive.
    """
    token_data = decode_token(payload.refresh_token)
    if token_data is None or token_data.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Refresh token 无效或已过期")

    # A token that decodes but carries a missing or non-numeric ``sub`` is
    # rejected like any other invalid token, rather than escaping as an
    # unhandled KeyError/TypeError/ValueError (HTTP 500).
    try:
        user_id = int(token_data["sub"])
    except (KeyError, TypeError, ValueError):
        raise HTTPException(status_code=401, detail="Refresh token 无效或已过期") from None

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    # Missing and inactive users are rejected with the same response.
    if not user or not user.is_active:
        raise HTTPException(status_code=401, detail="用户不存在")

    return TokenResponse(
        access_token=create_access_token(user.id),
        refresh_token=create_refresh_token(user.id),
    )
|
|
|
|
|
|
# Imported here, directly before its only use, instead of via an inline
# ``__import__`` call in the default argument. Keeping the import at this
# position means ``app.api.deps`` is loaded at the same point as before.
# NOTE(review): if there is no circular-import concern, this can move to the
# top-of-file import block.
from app.api.deps import get_current_user  # noqa: E402


@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency.

    Args:
        current_user: The ``User`` produced by ``get_current_user``.

    Returns:
        ``current_user`` unchanged (serialized through ``UserOut``).
    """
    return current_user
|