Initial commit: stock market platform
This commit is contained in:
16
backend/Dockerfile
Normal file
16
backend/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc libpq-dev curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
backend/app/api/__init__.py
Normal file
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
35
backend/app/api/deps.py
Normal file
35
backend/app/api/deps.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
|
||||
bearer = HTTPBearer(auto_error=True)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
|
||||
if payload is None or payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效或已过期",
|
||||
)
|
||||
|
||||
user_id = int(payload["sub"])
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在或已被禁用",
|
||||
)
|
||||
|
||||
return user
|
||||
8
backend/app/api/v1/__init__.py
Normal file
8
backend/app/api/v1/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1 import auth, stocks, watchlist, alerts
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(stocks.router)
|
||||
api_router.include_router(watchlist.router)
|
||||
api_router.include_router(alerts.router)
|
||||
77
backend/app/api/v1/alerts.py
Normal file
77
backend/app/api/v1/alerts.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.alert import Alert
|
||||
from app.schemas.stock import AlertCreate, AlertOut
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[AlertOut])
|
||||
async def get_alerts(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert)
|
||||
.where(Alert.user_id == current_user.id)
|
||||
.order_by(Alert.id.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=AlertOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_alert(
|
||||
payload: AlertCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
alert = Alert(
|
||||
user_id=current_user.id,
|
||||
symbol=payload.symbol,
|
||||
name=payload.name,
|
||||
alert_type=payload.alert_type,
|
||||
threshold=payload.threshold,
|
||||
)
|
||||
db.add(alert)
|
||||
await db.flush()
|
||||
await db.refresh(alert)
|
||||
return alert
|
||||
|
||||
|
||||
@router.delete("/{alert_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_alert(
|
||||
alert_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
|
||||
)
|
||||
alert = result.scalar_one_or_none()
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="预警不存在")
|
||||
|
||||
await db.delete(alert)
|
||||
|
||||
|
||||
@router.patch("/{alert_id}/toggle", response_model=AlertOut)
|
||||
async def toggle_alert(
|
||||
alert_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
|
||||
)
|
||||
alert = result.scalar_one_or_none()
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="预警不存在")
|
||||
|
||||
alert.is_active = not alert.is_active
|
||||
await db.flush()
|
||||
await db.refresh(alert)
|
||||
return alert
|
||||
78
backend/app/api/v1/auth.py
Normal file
78
backend/app/api/v1/auth.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.security import (
|
||||
hash_password, verify_password,
|
||||
create_access_token, create_refresh_token, decode_token,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import (
|
||||
RegisterRequest, LoginRequest, TokenResponse,
|
||||
RefreshRequest, UserOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
async def register(payload: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
(User.username == payload.username) | (User.email == payload.email)
|
||||
)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="用户名或邮箱已被使用")
|
||||
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hash_password(payload.password),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.username == payload.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(payload.password, user.hashed_password):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="账号已被禁用")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
token_data = decode_token(payload.refresh_token)
|
||||
|
||||
if token_data is None or token_data.get("type") != "refresh":
|
||||
raise HTTPException(status_code=401, detail="Refresh token 无效或已过期")
|
||||
|
||||
user_id = int(token_data["sub"])
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(__import__("app.api.deps", fromlist=["get_current_user"]).get_current_user)):
|
||||
return current_user
|
||||
133
backend/app/api/v1/stocks.py
Normal file
133
backend/app/api/v1/stocks.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.redis import get_redis
|
||||
from app.services import stock_service
|
||||
|
||||
router = APIRouter(prefix="/stocks", tags=["stocks"])
|
||||
|
||||
CACHE_TTL = 30 # seconds
|
||||
|
||||
|
||||
# ── market overview ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/overview")
|
||||
async def market_overview(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:overview"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_market_overview()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── heatmap ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/heatmap")
|
||||
async def market_heatmap(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:heatmap"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_all_stocks_spot()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── sector ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/market/sectors")
|
||||
async def market_sectors(current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = "market:sectors"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_sector_spot()
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── single stock quote ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/quote")
|
||||
async def stock_quote(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"quote:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_stock_quote(symbol)
|
||||
if data:
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data or {}
|
||||
|
||||
|
||||
# ── K-line ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/kline")
|
||||
async def stock_kline(
|
||||
symbol: str,
|
||||
period: str = Query("daily", pattern="^(daily|weekly|monthly)$"),
|
||||
adjust: str = Query("qfq", pattern="^(qfq|hfq|)$"),
|
||||
limit: int = Query(250, ge=10, le=1000),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
redis = await get_redis()
|
||||
cache_key = f"kline:{symbol}:{period}:{adjust}:{limit}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_kline(symbol, period, adjust, limit)
|
||||
await redis.setex(cache_key, 300, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── intraday ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/intraday")
|
||||
async def stock_intraday(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"intraday:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_intraday(symbol)
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── 5-day ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{symbol}/fiveday")
|
||||
async def stock_fiveday(symbol: str, current_user: User = Depends(get_current_user)):
|
||||
redis = await get_redis()
|
||||
cache_key = f"fiveday:{symbol}"
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
return json.loads(cached)
|
||||
|
||||
data = await stock_service.get_five_day(symbol)
|
||||
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
|
||||
return data
|
||||
|
||||
|
||||
# ── search ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/search")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=1),
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await stock_service.search_stocks(q, limit)
|
||||
63
backend/app/api/v1/watchlist.py
Normal file
63
backend/app/api/v1/watchlist.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.watchlist import Watchlist
|
||||
from app.schemas.stock import WatchlistItem, WatchlistAddRequest
|
||||
|
||||
router = APIRouter(prefix="/watchlist", tags=["watchlist"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[WatchlistItem])
|
||||
async def get_watchlist(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Watchlist)
|
||||
.where(Watchlist.user_id == current_user.id)
|
||||
.order_by(Watchlist.sort_order, Watchlist.id)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=WatchlistItem, status_code=status.HTTP_201_CREATED)
|
||||
async def add_to_watchlist(
|
||||
payload: WatchlistAddRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
existing = await db.execute(
|
||||
select(Watchlist).where(
|
||||
Watchlist.user_id == current_user.id,
|
||||
Watchlist.symbol == payload.symbol,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail="已在自选股中")
|
||||
|
||||
item = Watchlist(
|
||||
user_id=current_user.id,
|
||||
symbol=payload.symbol,
|
||||
name=payload.name,
|
||||
)
|
||||
db.add(item)
|
||||
await db.flush()
|
||||
await db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/{symbol}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_from_watchlist(
|
||||
symbol: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await db.execute(
|
||||
delete(Watchlist).where(
|
||||
Watchlist.user_id == current_user.id,
|
||||
Watchlist.symbol == symbol,
|
||||
)
|
||||
)
|
||||
41
backend/app/core/config.py
Normal file
41
backend/app/core/config.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import field_validator
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# App
|
||||
APP_NAME: str = "Stock Platform API"
|
||||
DEBUG: bool = False
|
||||
API_V1_PREFIX: str = "/api/v1"
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "postgresql+asyncpg://stock:stockpass@localhost:5432/stockdb"
|
||||
|
||||
# Redis
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||
|
||||
# Auth
|
||||
SECRET_KEY: str = "change-me-in-production"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# CORS
|
||||
ALLOWED_ORIGINS: str = "http://localhost:3000,http://localhost:5173"
|
||||
|
||||
@property
|
||||
def cors_origins(self) -> List[str]:
|
||||
return [o.strip() for o in self.ALLOWED_ORIGINS.split(",")]
|
||||
|
||||
# Tushare (optional)
|
||||
TUSHARE_TOKEN: str = ""
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
37
backend/app/core/database.py
Normal file
37
backend/app/core/database.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
22
backend/app/core/redis.py
Normal file
22
backend/app/core/redis.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import redis.asyncio as aioredis
|
||||
from app.core.config import settings
|
||||
|
||||
_redis_pool: aioredis.Redis | None = None
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
global _redis_pool
|
||||
if _redis_pool is None:
|
||||
_redis_pool = aioredis.from_url(
|
||||
settings.REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
return _redis_pool
|
||||
|
||||
|
||||
async def close_redis() -> None:
|
||||
global _redis_pool
|
||||
if _redis_pool:
|
||||
await _redis_pool.aclose()
|
||||
_redis_pool = None
|
||||
37
backend/app/core/security.py
Normal file
37
backend/app/core/security.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(subject: str | int, expires_delta: Optional[timedelta] = None) -> str:
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
)
|
||||
payload = {"sub": str(subject), "exp": expire, "type": "access"}
|
||||
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(subject: str | int) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
payload = {"sub": str(subject), "exp": expire, "type": "refresh"}
|
||||
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
89
backend/app/main.py
Normal file
89
backend/app/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.core.redis import close_redis
|
||||
from app.api.v1 import api_router
|
||||
from app.websocket.manager import manager
|
||||
from app.services import stock_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
|
||||
# ── background task: push heatmap every 5s ───────────────────────────────────
|
||||
|
||||
async def _heatmap_pusher():
|
||||
while True:
|
||||
try:
|
||||
data = await stock_service.get_all_stocks_spot()
|
||||
await manager.broadcast_all(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"heatmap pusher error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting up...")
|
||||
await init_db()
|
||||
task = asyncio.create_task(_heatmap_pusher())
|
||||
yield
|
||||
task.cancel()
|
||||
await close_redis()
|
||||
logger.info("Shut down.")
|
||||
|
||||
|
||||
# ── app ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
docs_url="/api/docs",
|
||||
redoc_url="/api/redoc",
|
||||
openapi_url="/api/openapi.json",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_PREFIX)
|
||||
|
||||
|
||||
# ── WebSocket endpoints ───────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/ws/heatmap")
|
||||
async def ws_heatmap(websocket: WebSocket):
|
||||
"""Subscribe to full market heatmap updates."""
|
||||
await manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
|
||||
|
||||
@app.websocket("/ws/quote/{symbol}")
|
||||
async def ws_quote(websocket: WebSocket, symbol: str):
|
||||
"""Subscribe to real-time quote updates for a single stock."""
|
||||
await manager.connect(websocket, symbol)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(3)
|
||||
data = await stock_service.get_stock_quote(symbol)
|
||||
if data:
|
||||
await manager.broadcast_quote(symbol, data)
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, symbol)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
5
backend/app/models/__init__.py
Normal file
5
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from app.models.user import User
|
||||
from app.models.watchlist import Watchlist
|
||||
from app.models.alert import Alert, AlertType
|
||||
|
||||
__all__ = ["User", "Watchlist", "Alert", "AlertType"]
|
||||
29
backend/app/models/alert.py
Normal file
29
backend/app/models/alert.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Float, Boolean, ForeignKey, DateTime, func, Enum
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
import enum
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class AlertType(str, enum.Enum):
|
||||
PRICE_ABOVE = "price_above"
|
||||
PRICE_BELOW = "price_below"
|
||||
CHANGE_PCT_ABOVE = "change_pct_above"
|
||||
CHANGE_PCT_BELOW = "change_pct_below"
|
||||
|
||||
|
||||
class Alert(Base):
|
||||
__tablename__ = "alerts"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
symbol: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
alert_type: Mapped[AlertType] = mapped_column(Enum(AlertType), nullable=False)
|
||||
threshold: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
triggered: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
triggered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="alerts")
|
||||
22
backend/app/models/user.py
Normal file
22
backend/app/models/user.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Boolean, DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True, index=True, nullable=False)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
watchlist: Mapped[list["Watchlist"]] = relationship(back_populates="user", lazy="select")
|
||||
alerts: Mapped[list["Alert"]] = relationship(back_populates="user", lazy="select")
|
||||
18
backend/app/models/watchlist.py
Normal file
18
backend/app/models/watchlist.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Integer, ForeignKey, DateTime, func, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class Watchlist(Base):
|
||||
__tablename__ = "watchlist"
|
||||
__table_args__ = (UniqueConstraint("user_id", "symbol", name="uq_user_symbol"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
symbol: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="watchlist")
|
||||
1
backend/app/schemas/__init__.py
Normal file
1
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
47
backend/app/schemas/auth.py
Normal file
47
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
import re
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def username_alphanumeric(cls, v: str) -> str:
|
||||
if not re.match(r"^[a-zA-Z0-9_]{3,20}$", v):
|
||||
raise ValueError("用户名只能包含字母、数字、下划线,长度3-20位")
|
||||
return v
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_min_length(cls, v: str) -> str:
|
||||
if len(v) < 6:
|
||||
raise ValueError("密码至少6位")
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserOut(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str
|
||||
is_active: bool
|
||||
is_admin: bool
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
91
backend/app/schemas/stock.py
Normal file
91
backend/app/schemas/stock.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class StockQuote(BaseModel):
|
||||
symbol: str
|
||||
name: str
|
||||
price: float
|
||||
change: float
|
||||
change_pct: float
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
prev_close: float
|
||||
volume: float
|
||||
amount: float
|
||||
market_cap: Optional[float] = None
|
||||
pe_ratio: Optional[float] = None
|
||||
sector: Optional[str] = None
|
||||
|
||||
|
||||
class StockSearchResult(BaseModel):
|
||||
symbol: str
|
||||
name: str
|
||||
market: str
|
||||
|
||||
|
||||
class KLineBar(BaseModel):
|
||||
date: str
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
amount: Optional[float] = None
|
||||
change_pct: Optional[float] = None
|
||||
|
||||
|
||||
class IntraDayBar(BaseModel):
|
||||
time: str
|
||||
price: float
|
||||
volume: float
|
||||
amount: Optional[float] = None
|
||||
avg_price: Optional[float] = None
|
||||
|
||||
|
||||
class MarketOverview(BaseModel):
|
||||
index_code: str
|
||||
index_name: str
|
||||
current: float
|
||||
change: float
|
||||
change_pct: float
|
||||
|
||||
|
||||
class SectorData(BaseModel):
|
||||
sector: str
|
||||
change_pct: float
|
||||
stocks: list[StockQuote] = []
|
||||
|
||||
|
||||
class WatchlistItem(BaseModel):
|
||||
id: int
|
||||
symbol: str
|
||||
name: str
|
||||
sort_order: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class WatchlistAddRequest(BaseModel):
|
||||
symbol: str
|
||||
name: str
|
||||
|
||||
|
||||
class AlertCreate(BaseModel):
|
||||
symbol: str
|
||||
name: str
|
||||
alert_type: str
|
||||
threshold: float
|
||||
|
||||
|
||||
class AlertOut(BaseModel):
|
||||
id: int
|
||||
symbol: str
|
||||
name: str
|
||||
alert_type: str
|
||||
threshold: float
|
||||
is_active: bool
|
||||
triggered: bool
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
332
backend/app/services/stock_service.py
Normal file
332
backend/app/services/stock_service.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Stock data service — wraps AKShare with Redis caching.
|
||||
All AKShare calls are blocking I/O; run them in a thread pool via asyncio.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import akshare as ak
|
||||
AK_AVAILABLE = True
|
||||
except Exception:
|
||||
AK_AVAILABLE = False
|
||||
logger.warning("AKShare not available, using mock data")
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _run_sync(fn, *args, **kwargs):
|
||||
"""Run a sync blocking function in the default thread pool."""
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.run_in_executor(None, lambda: fn(*args, **kwargs))
|
||||
|
||||
|
||||
# ── market overview ───────────────────────────────────────────────────────────
|
||||
|
||||
MAJOR_INDICES = [
|
||||
("sh000001", "上证指数"),
|
||||
("sz399001", "深证成指"),
|
||||
("sz399006", "创业板指"),
|
||||
("sh000688", "科创50"),
|
||||
("sh000300", "沪深300"),
|
||||
]
|
||||
|
||||
|
||||
async def get_market_overview() -> list[dict]:
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_market_overview()
|
||||
|
||||
try:
|
||||
df = await _run_sync(ak.stock_zh_index_spot_em)
|
||||
result = []
|
||||
code_map = {code: name for code, name in MAJOR_INDICES}
|
||||
for _, row in df.iterrows():
|
||||
code = str(row.get("代码", ""))
|
||||
if code in code_map:
|
||||
result.append({
|
||||
"index_code": code,
|
||||
"index_name": code_map[code],
|
||||
"current": float(row.get("最新价", 0)),
|
||||
"change": float(row.get("涨跌额", 0)),
|
||||
"change_pct": float(row.get("涨跌幅", 0)),
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_market_overview error: {e}")
|
||||
return _mock_market_overview()
|
||||
|
||||
|
||||
# ── real-time quotes (A-share spot) ──────────────────────────────────────────
|
||||
|
||||
async def get_all_stocks_spot() -> list[dict]:
|
||||
"""All A-share real-time quotes — used for heatmap."""
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_heatmap_data()
|
||||
|
||||
try:
|
||||
df = await _run_sync(ak.stock_zh_a_spot_em)
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
pct = float(row.get("涨跌幅", 0) or 0)
|
||||
result.append({
|
||||
"symbol": str(row.get("代码", "")),
|
||||
"name": str(row.get("名称", "")),
|
||||
"price": float(row.get("最新价", 0) or 0),
|
||||
"change": float(row.get("涨跌额", 0) or 0),
|
||||
"change_pct": pct,
|
||||
"open": float(row.get("今开", 0) or 0),
|
||||
"high": float(row.get("最高", 0) or 0),
|
||||
"low": float(row.get("最低", 0) or 0),
|
||||
"prev_close": float(row.get("昨收", 0) or 0),
|
||||
"volume": float(row.get("成交量", 0) or 0),
|
||||
"amount": float(row.get("成交额", 0) or 0),
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_all_stocks_spot error: {e}")
|
||||
return _mock_heatmap_data()
|
||||
|
||||
|
||||
async def get_stock_quote(symbol: str) -> Optional[dict]:
|
||||
"""Single stock real-time quote."""
|
||||
all_stocks = await get_all_stocks_spot()
|
||||
for s in all_stocks:
|
||||
if s["symbol"] == symbol:
|
||||
return s
|
||||
return None
|
||||
|
||||
|
||||
# ── K-line data ───────────────────────────────────────────────────────────────
|
||||
|
||||
PERIOD_MAP = {
|
||||
"daily": "daily",
|
||||
"weekly": "weekly",
|
||||
"monthly": "monthly",
|
||||
}
|
||||
|
||||
|
||||
async def get_kline(symbol: str, period: str = "daily", adjust: str = "qfq", limit: int = 250) -> list[dict]:
|
||||
"""
|
||||
period: daily | weekly | monthly
|
||||
adjust: qfq (前复权) | hfq (后复权) | "" (不复权)
|
||||
"""
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_kline(limit)
|
||||
|
||||
ak_period = PERIOD_MAP.get(period, "daily")
|
||||
|
||||
try:
|
||||
df = await _run_sync(
|
||||
ak.stock_zh_a_hist,
|
||||
symbol=symbol,
|
||||
period=ak_period,
|
||||
adjust=adjust,
|
||||
)
|
||||
df = df.tail(limit)
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
result.append({
|
||||
"date": str(row.get("日期", "")),
|
||||
"open": float(row.get("开盘", 0)),
|
||||
"high": float(row.get("最高", 0)),
|
||||
"low": float(row.get("最低", 0)),
|
||||
"close": float(row.get("收盘", 0)),
|
||||
"volume": float(row.get("成交量", 0)),
|
||||
"amount": float(row.get("成交额", 0)),
|
||||
"change_pct": float(row.get("涨跌幅", 0)),
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_kline {symbol} {period} error: {e}")
|
||||
return _mock_kline(limit)
|
||||
|
||||
|
||||
async def get_intraday(symbol: str) -> list[dict]:
|
||||
"""Today's minute-level data."""
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_intraday()
|
||||
|
||||
try:
|
||||
df = await _run_sync(ak.stock_zh_a_hist_min_em, symbol=symbol, period="1", adjust="")
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
result.append({
|
||||
"time": str(row.get("时间", "")),
|
||||
"price": float(row.get("收盘", 0)),
|
||||
"volume": float(row.get("成交量", 0)),
|
||||
"amount": float(row.get("成交额", 0)),
|
||||
"avg_price": float(row.get("均价", 0) or 0),
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_intraday {symbol} error: {e}")
|
||||
return _mock_intraday()
|
||||
|
||||
|
||||
async def get_five_day(symbol: str) -> list[dict]:
|
||||
"""5-day minute-level data."""
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_intraday(days=5)
|
||||
|
||||
try:
|
||||
df = await _run_sync(ak.stock_zh_a_hist_min_em, symbol=symbol, period="1", adjust="")
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
result.append({
|
||||
"time": str(row.get("时间", "")),
|
||||
"price": float(row.get("收盘", 0)),
|
||||
"volume": float(row.get("成交量", 0)),
|
||||
"amount": float(row.get("成交额", 0) or 0),
|
||||
"avg_price": float(row.get("均价", 0) or 0),
|
||||
})
|
||||
return result[-5 * 240:]
|
||||
except Exception as e:
|
||||
logger.error(f"get_five_day {symbol} error: {e}")
|
||||
return _mock_intraday(days=5)
|
||||
|
||||
|
||||
# ── search ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_stock_list_cached() -> list[dict]:
|
||||
"""Cache the full stock list in memory (refreshed on process restart)."""
|
||||
if not AK_AVAILABLE:
|
||||
return []
|
||||
try:
|
||||
import akshare as ak
|
||||
df = ak.stock_info_a_code_name()
|
||||
return [{"symbol": str(r["code"]), "name": str(r["name"]), "market": "A股"} for _, r in df.iterrows()]
|
||||
except Exception as e:
|
||||
logger.error(f"_get_stock_list_cached error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def search_stocks(query: str, limit: int = 20) -> list[dict]:
|
||||
stock_list = await _run_sync(_get_stock_list_cached)
|
||||
query = query.strip().lower()
|
||||
results = [
|
||||
s for s in stock_list
|
||||
if query in s["symbol"].lower() or query in s["name"].lower()
|
||||
]
|
||||
return results[:limit]
|
||||
|
||||
|
||||
# ── sector data for heatmap ───────────────────────────────────────────────────
|
||||
|
||||
async def get_sector_spot() -> list[dict]:
|
||||
"""Board/sector change pct for heatmap grouping."""
|
||||
if not AK_AVAILABLE:
|
||||
return _mock_sectors()
|
||||
|
||||
try:
|
||||
df = await _run_sync(ak.stock_board_industry_name_em)
|
||||
result = []
|
||||
for _, row in df.iterrows():
|
||||
result.append({
|
||||
"sector": str(row.get("板块名称", "")),
|
||||
"change_pct": float(row.get("涨跌幅", 0) or 0),
|
||||
"volume": float(row.get("成交量", 0) or 0),
|
||||
"amount": float(row.get("成交额", 0) or 0),
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_sector_spot error: {e}")
|
||||
return _mock_sectors()
|
||||
|
||||
|
||||
# ── mock data (fallback when market is closed or AKShare unavailable) ─────────
|
||||
|
||||
import random
|
||||
import math
|
||||
from datetime import date, timedelta
|
||||
|
||||
|
||||
def _mock_market_overview() -> list[dict]:
|
||||
return [
|
||||
{"index_code": "sh000001", "index_name": "上证指数", "current": 3312.46, "change": -28.84, "change_pct": -0.86},
|
||||
{"index_code": "sz399001", "index_name": "深证成指", "current": 10573.99, "change": -93.16, "change_pct": -0.87},
|
||||
{"index_code": "sz399006", "index_name": "创业板指", "current": 2105.37, "change": -18.42, "change_pct": -0.87},
|
||||
{"index_code": "sh000688", "index_name": "科创50", "current": 968.12, "change": -9.56, "change_pct": -0.98},
|
||||
{"index_code": "sh000300", "index_name": "沪深300", "current": 3843.20, "change": -30.11, "change_pct": -0.78},
|
||||
]
|
||||
|
||||
|
||||
def _mock_heatmap_data() -> list[dict]:
|
||||
sectors = ["银行", "电力设备", "食品饮料", "医药生物", "电子", "汽车", "非银金融", "计算机", "有色金属", "化工"]
|
||||
stocks = []
|
||||
for i in range(80):
|
||||
pct = round(random.uniform(-5, 5), 2)
|
||||
sector = sectors[i % len(sectors)]
|
||||
stocks.append({
|
||||
"symbol": f"{600000 + i:06d}",
|
||||
"name": f"测试股票{i+1:02d}",
|
||||
"price": round(random.uniform(5, 100), 2),
|
||||
"change": round(pct * 0.1, 2),
|
||||
"change_pct": pct,
|
||||
"open": round(random.uniform(5, 100), 2),
|
||||
"high": round(random.uniform(5, 100), 2),
|
||||
"low": round(random.uniform(5, 100), 2),
|
||||
"prev_close": round(random.uniform(5, 100), 2),
|
||||
"volume": random.randint(100000, 10000000),
|
||||
"amount": random.randint(1000000, 100000000),
|
||||
"sector": sector,
|
||||
})
|
||||
return stocks
|
||||
|
||||
|
||||
def _mock_kline(limit: int = 250) -> list[dict]:
|
||||
bars = []
|
||||
price = 20.0
|
||||
today = date.today()
|
||||
for i in range(limit):
|
||||
d = today - timedelta(days=limit - i)
|
||||
pct = random.uniform(-0.05, 0.05)
|
||||
close = round(price * (1 + pct), 2)
|
||||
high = round(max(price, close) * random.uniform(1.0, 1.03), 2)
|
||||
low = round(min(price, close) * random.uniform(0.97, 1.0), 2)
|
||||
bars.append({
|
||||
"date": d.isoformat(),
|
||||
"open": round(price, 2),
|
||||
"high": high,
|
||||
"low": low,
|
||||
"close": close,
|
||||
"volume": random.randint(500000, 5000000),
|
||||
"amount": random.randint(5000000, 50000000),
|
||||
"change_pct": round(pct * 100, 2),
|
||||
})
|
||||
price = close
|
||||
return bars
|
||||
|
||||
|
||||
def _mock_intraday(days: int = 1) -> list[dict]:
|
||||
bars = []
|
||||
price = 20.0
|
||||
for d in range(days):
|
||||
for minute in range(240):
|
||||
h = 9 + minute // 60
|
||||
m = (minute % 60) + (30 if d == 0 and minute < 60 else 0)
|
||||
if h == 11 and m >= 30:
|
||||
continue
|
||||
pct = random.uniform(-0.01, 0.01)
|
||||
price = round(price * (1 + pct), 2)
|
||||
bars.append({
|
||||
"time": f"2026-06-0{d+1} {h:02d}:{m % 60:02d}",
|
||||
"price": price,
|
||||
"volume": random.randint(10000, 500000),
|
||||
"amount": random.randint(100000, 5000000),
|
||||
"avg_price": price,
|
||||
})
|
||||
return bars
|
||||
|
||||
|
||||
def _mock_sectors() -> list[dict]:
|
||||
sectors = [
|
||||
"银行", "电力设备", "食品饮料", "医药生物", "电子",
|
||||
"汽车", "非银金融", "计算机", "有色金属", "化工",
|
||||
"机械设备", "建筑材料", "传媒", "房地产", "交通运输",
|
||||
]
|
||||
return [{"sector": s, "change_pct": round(random.uniform(-3, 3), 2), "volume": 0, "amount": 0}
|
||||
for s in sectors]
|
||||
1
backend/app/websocket/__init__.py
Normal file
1
backend/app/websocket/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
58
backend/app/websocket/manager.py
Normal file
58
backend/app/websocket/manager.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
WebSocket connection manager — broadcasts real-time stock quotes to subscribers.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
# symbol -> set of WebSocket connections
|
||||
self._subs: dict[str, set[WebSocket]] = defaultdict(set)
|
||||
self._all: set[WebSocket] = set()
|
||||
|
||||
async def connect(self, ws: WebSocket, symbol: str | None = None):
|
||||
await ws.accept()
|
||||
self._all.add(ws)
|
||||
if symbol:
|
||||
self._subs[symbol].add(ws)
|
||||
logger.info(f"WS connected symbol={symbol}, total={len(self._all)}")
|
||||
|
||||
def disconnect(self, ws: WebSocket, symbol: str | None = None):
|
||||
self._all.discard(ws)
|
||||
if symbol:
|
||||
self._subs[symbol].discard(ws)
|
||||
else:
|
||||
for s in list(self._subs.keys()):
|
||||
self._subs[s].discard(ws)
|
||||
logger.info(f"WS disconnected, total={len(self._all)}")
|
||||
|
||||
async def broadcast_quote(self, symbol: str, data: dict):
|
||||
"""Send quote update to all subscribers of a specific symbol."""
|
||||
message = json.dumps({"type": "quote", "symbol": symbol, "data": data})
|
||||
dead = set()
|
||||
for ws in list(self._subs.get(symbol, [])):
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
dead.add(ws)
|
||||
for ws in dead:
|
||||
self.disconnect(ws, symbol)
|
||||
|
||||
async def broadcast_all(self, data: list[dict]):
|
||||
"""Broadcast heatmap snapshot to all connected clients."""
|
||||
message = json.dumps({"type": "heatmap", "data": data})
|
||||
dead = set()
|
||||
for ws in list(self._all):
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
dead.add(ws)
|
||||
for ws in dead:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
1
backend/celery_app/__init__.py
Normal file
1
backend/celery_app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
backend/celery_app/tasks/__init__.py
Normal file
1
backend/celery_app/tasks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
36
backend/celery_app/tasks/market_tasks.py
Normal file
36
backend/celery_app/tasks/market_tasks.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
import asyncio
|
||||
import redis as sync_redis
|
||||
from celery_app.worker import app
|
||||
from app.core.config import settings
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def _sync_redis():
|
||||
return sync_redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
|
||||
|
||||
@app.task(name="celery_app.tasks.market_tasks.refresh_heatmap_cache")
|
||||
def refresh_heatmap_cache():
|
||||
"""Pull all A-share spot quotes and push to Redis cache."""
|
||||
try:
|
||||
from app.services.stock_service import get_all_stocks_spot
|
||||
data = asyncio.run(get_all_stocks_spot())
|
||||
r = _sync_redis()
|
||||
r.setex("market:heatmap", 60, json.dumps(data))
|
||||
logger.info(f"Heatmap cache refreshed: {len(data)} stocks")
|
||||
except Exception as e:
|
||||
logger.error(f"refresh_heatmap_cache error: {e}")
|
||||
|
||||
|
||||
@app.task(name="celery_app.tasks.market_tasks.refresh_market_overview")
|
||||
def refresh_market_overview():
|
||||
"""Pull major index data and push to Redis cache."""
|
||||
try:
|
||||
from app.services.stock_service import get_market_overview
|
||||
data = asyncio.run(get_market_overview())
|
||||
r = _sync_redis()
|
||||
r.setex("market:overview", 120, json.dumps(data))
|
||||
logger.info(f"Market overview refreshed: {len(data)} indices")
|
||||
except Exception as e:
|
||||
logger.error(f"refresh_market_overview error: {e}")
|
||||
31
backend/celery_app/worker.py
Normal file
31
backend/celery_app/worker.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
import os
|
||||
import sys
|
||||
|
||||
app = Celery(
|
||||
"stock_worker",
|
||||
broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/1"),
|
||||
backend=os.getenv("REDIS_URL", "redis://localhost:6379/0"),
|
||||
include=["celery_app.tasks.market_tasks"],
|
||||
)
|
||||
|
||||
app.conf.timezone = "Asia/Shanghai"
|
||||
app.conf.enable_utc = False
|
||||
|
||||
# Windows does not support the default prefork pool (fork syscall unavailable)
|
||||
if sys.platform == "win32":
|
||||
app.conf.worker_pool = "solo"
|
||||
|
||||
app.conf.beat_schedule = {
|
||||
# Refresh heatmap cache every 30s during trading hours
|
||||
"refresh-heatmap-30s": {
|
||||
"task": "celery_app.tasks.market_tasks.refresh_heatmap_cache",
|
||||
"schedule": 30.0,
|
||||
},
|
||||
# Refresh index data every minute
|
||||
"refresh-indices-1m": {
|
||||
"task": "celery_app.tasks.market_tasks.refresh_market_overview",
|
||||
"schedule": 60.0,
|
||||
},
|
||||
}
|
||||
34
backend/requirements.txt
Normal file
34
backend/requirements.txt
Normal file
@@ -0,0 +1,34 @@
|
||||
# Web framework
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.32.1
|
||||
python-multipart==0.0.20
|
||||
|
||||
# Database
|
||||
sqlalchemy==2.0.36
|
||||
asyncpg==0.30.0
|
||||
alembic==1.14.0
|
||||
|
||||
# Redis & Celery
|
||||
redis==5.2.1
|
||||
celery==5.4.0
|
||||
celery[redis]==5.4.0
|
||||
|
||||
# Auth
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
|
||||
# Stock data
|
||||
akshare==1.18.64
|
||||
|
||||
# HTTP client
|
||||
httpx==0.28.1
|
||||
aiohttp==3.11.11
|
||||
|
||||
# Validation & settings
|
||||
pydantic==2.10.3
|
||||
pydantic-settings==2.7.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
loguru==0.7.3
|
||||
tenacity==9.0.0
|
||||
Reference in New Issue
Block a user