Initial commit: stock market platform
This commit is contained in:
1
backend/app/api/__init__.py
Normal file
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
35
backend/app/api/deps.py
Normal file
35
backend/app/api/deps.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
|
||||
bearer = HTTPBearer(auto_error=True)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
|
||||
if payload is None or payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效或已过期",
|
||||
)
|
||||
|
||||
user_id = int(payload["sub"])
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在或已被禁用",
|
||||
)
|
||||
|
||||
return user
|
||||
8
backend/app/api/v1/__init__.py
Normal file
8
backend/app/api/v1/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1 import auth, stocks, watchlist, alerts
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(stocks.router)
|
||||
api_router.include_router(watchlist.router)
|
||||
api_router.include_router(alerts.router)
|
||||
77
backend/app/api/v1/alerts.py
Normal file
77
backend/app/api/v1/alerts.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.alert import Alert
|
||||
from app.schemas.stock import AlertCreate, AlertOut
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[AlertOut])
|
||||
async def get_alerts(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert)
|
||||
.where(Alert.user_id == current_user.id)
|
||||
.order_by(Alert.id.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=AlertOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_alert(
|
||||
payload: AlertCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
alert = Alert(
|
||||
user_id=current_user.id,
|
||||
symbol=payload.symbol,
|
||||
name=payload.name,
|
||||
alert_type=payload.alert_type,
|
||||
threshold=payload.threshold,
|
||||
)
|
||||
db.add(alert)
|
||||
await db.flush()
|
||||
await db.refresh(alert)
|
||||
return alert
|
||||
|
||||
|
||||
@router.delete("/{alert_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_alert(
|
||||
alert_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
|
||||
)
|
||||
alert = result.scalar_one_or_none()
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="预警不存在")
|
||||
|
||||
await db.delete(alert)
|
||||
|
||||
|
||||
@router.patch("/{alert_id}/toggle", response_model=AlertOut)
|
||||
async def toggle_alert(
|
||||
alert_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
|
||||
)
|
||||
alert = result.scalar_one_or_none()
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="预警不存在")
|
||||
|
||||
alert.is_active = not alert.is_active
|
||||
await db.flush()
|
||||
await db.refresh(alert)
|
||||
return alert
|
||||
78
backend/app/api/v1/auth.py
Normal file
78
backend/app/api/v1/auth.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.security import (
|
||||
hash_password, verify_password,
|
||||
create_access_token, create_refresh_token, decode_token,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import (
|
||||
RegisterRequest, LoginRequest, TokenResponse,
|
||||
RefreshRequest, UserOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
async def register(payload: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
(User.username == payload.username) | (User.email == payload.email)
|
||||
)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="用户名或邮箱已被使用")
|
||||
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hash_password(payload.password),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.username == payload.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(payload.password, user.hashed_password):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="账号已被禁用")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
token_data = decode_token(payload.refresh_token)
|
||||
|
||||
if token_data is None or token_data.get("type") != "refresh":
|
||||
raise HTTPException(status_code=401, detail="Refresh token 无效或已过期")
|
||||
|
||||
user_id = int(token_data["sub"])
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(__import__("app.api.deps", fromlist=["get_current_user"]).get_current_user)):
|
||||
return current_user
|
||||
133
backend/app/api/v1/stocks.py
Normal file
133
backend/app/api/v1/stocks.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.redis import get_redis
|
||||
from app.services import stock_service
|
||||
|
||||
router = APIRouter(prefix="/stocks", tags=["stocks"])
|
||||
|
||||
CACHE_TTL = 30 # seconds
|
||||
|
||||
|
||||
# ── market overview ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/overview")
|
||||
async def market_overview(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:overview"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_market_overview()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── heatmap ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/heatmap")
|
||||
async def market_heatmap(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:heatmap"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_all_stocks_spot()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── sector ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/sectors")
|
||||
async def market_sectors(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:sectors"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_sector_spot()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── single stock quote ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/quote")
|
||||
async def stock_quote(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"quote:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_stock_quote(symbol)
|
||||
if data:
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data or {}
|
||||
|
||||
|
||||
# ── K-line ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/kline")
|
||||
async def stock_kline(
|
||||
symbol: str,
|
||||
period: str = Query("daily", pattern="^(daily|weekly|monthly)$"),
|
||||
adjust: str = Query("qfq", pattern="^(qfq|hfq|)$"),
|
||||
limit: int = Query(250, ge=10, le=1000),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
redis = await get_redis()
|
||||
cache_key = f"kline:{symbol}:{period}:{adjust}:{limit}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_kline(symbol, period, adjust, limit)
|
||||
await redis.setex(cache_key, 300, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── intraday ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/intraday")
|
||||
async def stock_intraday(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"intraday:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_intraday(symbol)
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── 5-day ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/fiveday")
|
||||
async def stock_fiveday(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"fiveday:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_five_day(symbol)
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── search ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/search")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=1),
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await stock_service.search_stocks(q, limit)
|
||||
63
backend/app/api/v1/watchlist.py
Normal file
63
backend/app/api/v1/watchlist.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.watchlist import Watchlist
|
||||
from app.schemas.stock import WatchlistItem, WatchlistAddRequest
|
||||
|
||||
router = APIRouter(prefix="/watchlist", tags=["watchlist"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[WatchlistItem])
|
||||
async def get_watchlist(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Watchlist)
|
||||
.where(Watchlist.user_id == current_user.id)
|
||||
.order_by(Watchlist.sort_order, Watchlist.id)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=WatchlistItem, status_code=status.HTTP_201_CREATED)
|
||||
async def add_to_watchlist(
|
||||
payload: WatchlistAddRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
existing = await db.execute(
|
||||
select(Watchlist).where(
|
||||
Watchlist.user_id == current_user.id,
|
||||
Watchlist.symbol == payload.symbol,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail="已在自选股中")
|
||||
|
||||
item = Watchlist(
|
||||
user_id=current_user.id,
|
||||
symbol=payload.symbol,
|
||||
name=payload.name,
|
||||
)
|
||||
db.add(item)
|
||||
await db.flush()
|
||||
await db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/{symbol}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_from_watchlist(
|
||||
symbol: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await db.execute(
|
||||
delete(Watchlist).where(
|
||||
Watchlist.user_id == current_user.id,
|
||||
Watchlist.symbol == symbol,
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user