Initial commit: stock market platform

This commit is contained in:
admin
2026-06-11 01:41:47 +08:00
commit 63718906e9
62 changed files with 8962 additions and 0 deletions

16
backend/Dockerfile Normal file
View File

@@ -0,0 +1,16 @@
FROM python:3.12-slim
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc libpq-dev curl \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

35
backend/app/api/deps.py Normal file
View File

@@ -0,0 +1,35 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.security import decode_token
from app.models.user import User
bearer = HTTPBearer(auto_error=True)
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer),
db: AsyncSession = Depends(get_db),
) -> User:
token = credentials.credentials
payload = decode_token(token)
if payload is None or payload.get("type") != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效或已过期",
)
user_id = int(payload["sub"])
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在或已被禁用",
)
return user

View File

@@ -0,0 +1,8 @@
from fastapi import APIRouter
from app.api.v1 import auth, stocks, watchlist, alerts
api_router = APIRouter()
api_router.include_router(auth.router)
api_router.include_router(stocks.router)
api_router.include_router(watchlist.router)
api_router.include_router(alerts.router)

View File

@@ -0,0 +1,77 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from app.core.database import get_db
from app.api.deps import get_current_user
from app.models.user import User
from app.models.alert import Alert
from app.schemas.stock import AlertCreate, AlertOut
router = APIRouter(prefix="/alerts", tags=["alerts"])
@router.get("", response_model=list[AlertOut])
async def get_alerts(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Alert)
.where(Alert.user_id == current_user.id)
.order_by(Alert.id.desc())
)
return result.scalars().all()
@router.post("", response_model=AlertOut, status_code=status.HTTP_201_CREATED)
async def create_alert(
payload: AlertCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
alert = Alert(
user_id=current_user.id,
symbol=payload.symbol,
name=payload.name,
alert_type=payload.alert_type,
threshold=payload.threshold,
)
db.add(alert)
await db.flush()
await db.refresh(alert)
return alert
@router.delete("/{alert_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_alert(
alert_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
)
alert = result.scalar_one_or_none()
if not alert:
raise HTTPException(status_code=404, detail="预警不存在")
await db.delete(alert)
@router.patch("/{alert_id}/toggle", response_model=AlertOut)
async def toggle_alert(
alert_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Alert).where(Alert.id == alert_id, Alert.user_id == current_user.id)
)
alert = result.scalar_one_or_none()
if not alert:
raise HTTPException(status_code=404, detail="预警不存在")
alert.is_active = not alert.is_active
await db.flush()
await db.refresh(alert)
return alert

View File

@@ -0,0 +1,78 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.security import (
hash_password, verify_password,
create_access_token, create_refresh_token, decode_token,
)
from app.models.user import User
from app.schemas.auth import (
RegisterRequest, LoginRequest, TokenResponse,
RefreshRequest, UserOut,
)
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(payload: RegisterRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(
select(User).where(
(User.username == payload.username) | (User.email == payload.email)
)
)
if result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="用户名或邮箱已被使用")
user = User(
username=payload.username,
email=payload.email,
hashed_password=hash_password(payload.password),
)
db.add(user)
await db.flush()
await db.refresh(user)
return user
@router.post("/login", response_model=TokenResponse)
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.username == payload.username))
user = result.scalar_one_or_none()
if not user or not verify_password(payload.password, user.hashed_password):
raise HTTPException(status_code=401, detail="用户名或密码错误")
if not user.is_active:
raise HTTPException(status_code=403, detail="账号已被禁用")
return TokenResponse(
access_token=create_access_token(user.id),
refresh_token=create_refresh_token(user.id),
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
token_data = decode_token(payload.refresh_token)
if token_data is None or token_data.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Refresh token 无效或已过期")
user_id = int(token_data["sub"])
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(status_code=401, detail="用户不存在")
return TokenResponse(
access_token=create_access_token(user.id),
refresh_token=create_refresh_token(user.id),
)
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(__import__("app.api.deps", fromlist=["get_current_user"]).get_current_user)):
return current_user

View File

@@ -0,0 +1,133 @@
import json
from fastapi import APIRouter, Depends, Query
from app.api.deps import get_current_user
from app.models.user import User
from app.core.redis import get_redis
from app.services import stock_service
router = APIRouter(prefix="/stocks", tags=["stocks"])
CACHE_TTL = 30 # seconds
# ── market overview ───────────────────────────────────────────────────────────
@router.get("/market/overview")
async def market_overview(current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = "market:overview"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_market_overview()
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data
# ── heatmap ───────────────────────────────────────────────────────────────────
@router.get("/market/heatmap")
async def market_heatmap(current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = "market:heatmap"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_all_stocks_spot()
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data
# ── sector ────────────────────────────────────────────────────────────────────
@router.get("/market/sectors")
async def market_sectors(current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = "market:sectors"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_sector_spot()
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data
# ── single stock quote ────────────────────────────────────────────────────────
@router.get("/{symbol}/quote")
async def stock_quote(symbol: str, current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = f"quote:{symbol}"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_stock_quote(symbol)
if data:
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data or {}
# ── K-line ────────────────────────────────────────────────────────────────────
@router.get("/{symbol}/kline")
async def stock_kline(
symbol: str,
period: str = Query("daily", pattern="^(daily|weekly|monthly)$"),
adjust: str = Query("qfq", pattern="^(qfq|hfq|)$"),
limit: int = Query(250, ge=10, le=1000),
current_user: User = Depends(get_current_user),
):
redis = await get_redis()
cache_key = f"kline:{symbol}:{period}:{adjust}:{limit}"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_kline(symbol, period, adjust, limit)
await redis.setex(cache_key, 300, json.dumps(data))
return data
# ── intraday ──────────────────────────────────────────────────────────────────
@router.get("/{symbol}/intraday")
async def stock_intraday(symbol: str, current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = f"intraday:{symbol}"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_intraday(symbol)
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data
# ── 5-day ─────────────────────────────────────────────────────────────────────
@router.get("/{symbol}/fiveday")
async def stock_fiveday(symbol: str, current_user: User = Depends(get_current_user)):
redis = await get_redis()
cache_key = f"fiveday:{symbol}"
cached = await redis.get(cache_key)
if cached:
return json.loads(cached)
data = await stock_service.get_five_day(symbol)
await redis.setex(cache_key, CACHE_TTL, json.dumps(data))
return data
# ── search ────────────────────────────────────────────────────────────────────
@router.get("/search")
async def search(
q: str = Query(..., min_length=1),
limit: int = Query(20, ge=1, le=50),
current_user: User = Depends(get_current_user),
):
return await stock_service.search_stocks(q, limit)

View File

@@ -0,0 +1,63 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from app.core.database import get_db
from app.api.deps import get_current_user
from app.models.user import User
from app.models.watchlist import Watchlist
from app.schemas.stock import WatchlistItem, WatchlistAddRequest
router = APIRouter(prefix="/watchlist", tags=["watchlist"])
@router.get("", response_model=list[WatchlistItem])
async def get_watchlist(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Watchlist)
.where(Watchlist.user_id == current_user.id)
.order_by(Watchlist.sort_order, Watchlist.id)
)
return result.scalars().all()
@router.post("", response_model=WatchlistItem, status_code=status.HTTP_201_CREATED)
async def add_to_watchlist(
payload: WatchlistAddRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
existing = await db.execute(
select(Watchlist).where(
Watchlist.user_id == current_user.id,
Watchlist.symbol == payload.symbol,
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail="已在自选股中")
item = Watchlist(
user_id=current_user.id,
symbol=payload.symbol,
name=payload.name,
)
db.add(item)
await db.flush()
await db.refresh(item)
return item
@router.delete("/{symbol}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_from_watchlist(
symbol: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await db.execute(
delete(Watchlist).where(
Watchlist.user_id == current_user.id,
Watchlist.symbol == symbol,
)
)

View File

@@ -0,0 +1,41 @@
from pydantic_settings import BaseSettings
from pydantic import field_validator
from typing import List
import os
class Settings(BaseSettings):
# App
APP_NAME: str = "Stock Platform API"
DEBUG: bool = False
API_V1_PREFIX: str = "/api/v1"
# Database
DATABASE_URL: str = "postgresql+asyncpg://stock:stockpass@localhost:5432/stockdb"
# Redis
REDIS_URL: str = "redis://localhost:6379/0"
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
# Auth
SECRET_KEY: str = "change-me-in-production"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
# CORS
ALLOWED_ORIGINS: str = "http://localhost:3000,http://localhost:5173"
@property
def cors_origins(self) -> List[str]:
return [o.strip() for o in self.ALLOWED_ORIGINS.split(",")]
# Tushare (optional)
TUSHARE_TOKEN: str = ""
class Config:
env_file = ".env"
extra = "ignore"
settings = Settings()

View File

@@ -0,0 +1,37 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
)
AsyncSessionLocal = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
class Base(DeclarativeBase):
pass
async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def init_db() -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

22
backend/app/core/redis.py Normal file
View File

@@ -0,0 +1,22 @@
import redis.asyncio as aioredis
from app.core.config import settings
_redis_pool: aioredis.Redis | None = None
async def get_redis() -> aioredis.Redis:
global _redis_pool
if _redis_pool is None:
_redis_pool = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
return _redis_pool
async def close_redis() -> None:
global _redis_pool
if _redis_pool:
await _redis_pool.aclose()
_redis_pool = None

View File

@@ -0,0 +1,37 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(subject: str | int, expires_delta: Optional[timedelta] = None) -> str:
expire = datetime.now(timezone.utc) + (
expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
)
payload = {"sub": str(subject), "exp": expire, "type": "access"}
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_refresh_token(subject: str | int) -> str:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
payload = {"sub": str(subject), "exp": expire, "type": "refresh"}
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str) -> Optional[dict]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
return payload
except JWTError:
return None

89
backend/app/main.py Normal file
View File

@@ -0,0 +1,89 @@
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from app.core.config import settings
from app.core.database import init_db
from app.core.redis import close_redis
from app.api.v1 import api_router
from app.websocket.manager import manager
from app.services import stock_service
from app.api.deps import get_current_user
# ── background task: push heatmap every 5s ───────────────────────────────────
async def _heatmap_pusher():
while True:
try:
data = await stock_service.get_all_stocks_spot()
await manager.broadcast_all(data)
except Exception as e:
logger.warning(f"heatmap pusher error: {e}")
await asyncio.sleep(5)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up...")
await init_db()
task = asyncio.create_task(_heatmap_pusher())
yield
task.cancel()
await close_redis()
logger.info("Shut down.")
# ── app ───────────────────────────────────────────────────────────────────────
app = FastAPI(
title=settings.APP_NAME,
docs_url="/api/docs",
redoc_url="/api/redoc",
openapi_url="/api/openapi.json",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router, prefix=settings.API_V1_PREFIX)
# ── WebSocket endpoints ───────────────────────────────────────────────────────
@app.websocket("/ws/heatmap")
async def ws_heatmap(websocket: WebSocket):
"""Subscribe to full market heatmap updates."""
await manager.connect(websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
manager.disconnect(websocket)
@app.websocket("/ws/quote/{symbol}")
async def ws_quote(websocket: WebSocket, symbol: str):
"""Subscribe to real-time quote updates for a single stock."""
await manager.connect(websocket, symbol)
try:
while True:
await asyncio.sleep(3)
data = await stock_service.get_stock_quote(symbol)
if data:
await manager.broadcast_quote(symbol, data)
except WebSocketDisconnect:
manager.disconnect(websocket, symbol)
@app.get("/api/health")
async def health():
return {"status": "ok"}

View File

@@ -0,0 +1,5 @@
from app.models.user import User
from app.models.watchlist import Watchlist
from app.models.alert import Alert, AlertType
__all__ = ["User", "Watchlist", "Alert", "AlertType"]

View File

@@ -0,0 +1,29 @@
from datetime import datetime
from sqlalchemy import String, Float, Boolean, ForeignKey, DateTime, func, Enum
from sqlalchemy.orm import Mapped, mapped_column, relationship
import enum
from app.core.database import Base
class AlertType(str, enum.Enum):
PRICE_ABOVE = "price_above"
PRICE_BELOW = "price_below"
CHANGE_PCT_ABOVE = "change_pct_above"
CHANGE_PCT_BELOW = "change_pct_below"
class Alert(Base):
__tablename__ = "alerts"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
symbol: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
name: Mapped[str] = mapped_column(String(50), nullable=False)
alert_type: Mapped[AlertType] = mapped_column(Enum(AlertType), nullable=False)
threshold: Mapped[float] = mapped_column(Float, nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
triggered: Mapped[bool] = mapped_column(Boolean, default=False)
triggered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
user: Mapped["User"] = relationship(back_populates="alerts")

View File

@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import String, Boolean, DateTime, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.database import Base
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
username: Mapped[str] = mapped_column(String(50), unique=True, index=True, nullable=False)
email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
hashed_password: Mapped[str] = mapped_column(String(128), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
watchlist: Mapped[list["Watchlist"]] = relationship(back_populates="user", lazy="select")
alerts: Mapped[list["Alert"]] = relationship(back_populates="user", lazy="select")

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, DateTime, func, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.database import Base
class Watchlist(Base):
__tablename__ = "watchlist"
__table_args__ = (UniqueConstraint("user_id", "symbol", name="uq_user_symbol"),)
id: Mapped[int] = mapped_column(primary_key=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
symbol: Mapped[str] = mapped_column(String(20), nullable=False)
name: Mapped[str] = mapped_column(String(50), nullable=False)
sort_order: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
user: Mapped["User"] = relationship(back_populates="watchlist")

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,47 @@
from pydantic import BaseModel, EmailStr, field_validator
import re
class RegisterRequest(BaseModel):
username: str
email: EmailStr
password: str
@field_validator("username")
@classmethod
def username_alphanumeric(cls, v: str) -> str:
if not re.match(r"^[a-zA-Z0-9_]{3,20}$", v):
raise ValueError("用户名只能包含字母、数字、下划线长度3-20位")
return v
@field_validator("password")
@classmethod
def password_min_length(cls, v: str) -> str:
if len(v) < 6:
raise ValueError("密码至少6位")
return v
class LoginRequest(BaseModel):
username: str
password: str
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
class RefreshRequest(BaseModel):
refresh_token: str
class UserOut(BaseModel):
id: int
username: str
email: str
is_active: bool
is_admin: bool
model_config = {"from_attributes": True}

View File

@@ -0,0 +1,91 @@
from pydantic import BaseModel
from typing import Optional
class StockQuote(BaseModel):
symbol: str
name: str
price: float
change: float
change_pct: float
open: float
high: float
low: float
prev_close: float
volume: float
amount: float
market_cap: Optional[float] = None
pe_ratio: Optional[float] = None
sector: Optional[str] = None
class StockSearchResult(BaseModel):
symbol: str
name: str
market: str
class KLineBar(BaseModel):
date: str
open: float
high: float
low: float
close: float
volume: float
amount: Optional[float] = None
change_pct: Optional[float] = None
class IntraDayBar(BaseModel):
time: str
price: float
volume: float
amount: Optional[float] = None
avg_price: Optional[float] = None
class MarketOverview(BaseModel):
index_code: str
index_name: str
current: float
change: float
change_pct: float
class SectorData(BaseModel):
sector: str
change_pct: float
stocks: list[StockQuote] = []
class WatchlistItem(BaseModel):
id: int
symbol: str
name: str
sort_order: int
model_config = {"from_attributes": True}
class WatchlistAddRequest(BaseModel):
symbol: str
name: str
class AlertCreate(BaseModel):
symbol: str
name: str
alert_type: str
threshold: float
class AlertOut(BaseModel):
id: int
symbol: str
name: str
alert_type: str
threshold: float
is_active: bool
triggered: bool
model_config = {"from_attributes": True}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,332 @@
"""
Stock data service — wraps AKShare with Redis caching.
All AKShare calls are blocking I/O; run them in a thread pool via asyncio.
"""
import asyncio
import json
from functools import lru_cache
from typing import Optional
from loguru import logger
try:
import akshare as ak
AK_AVAILABLE = True
except Exception:
AK_AVAILABLE = False
logger.warning("AKShare not available, using mock data")
# ── helpers ──────────────────────────────────────────────────────────────────
def _run_sync(fn, *args, **kwargs):
"""Run a sync blocking function in the default thread pool."""
loop = asyncio.get_running_loop()
return loop.run_in_executor(None, lambda: fn(*args, **kwargs))
# ── market overview ───────────────────────────────────────────────────────────
MAJOR_INDICES = [
("sh000001", "上证指数"),
("sz399001", "深证成指"),
("sz399006", "创业板指"),
("sh000688", "科创50"),
("sh000300", "沪深300"),
]
async def get_market_overview() -> list[dict]:
if not AK_AVAILABLE:
return _mock_market_overview()
try:
df = await _run_sync(ak.stock_zh_index_spot_em)
result = []
code_map = {code: name for code, name in MAJOR_INDICES}
for _, row in df.iterrows():
code = str(row.get("代码", ""))
if code in code_map:
result.append({
"index_code": code,
"index_name": code_map[code],
"current": float(row.get("最新价", 0)),
"change": float(row.get("涨跌额", 0)),
"change_pct": float(row.get("涨跌幅", 0)),
})
return result
except Exception as e:
logger.error(f"get_market_overview error: {e}")
return _mock_market_overview()
# ── real-time quotes (A-share spot) ──────────────────────────────────────────
async def get_all_stocks_spot() -> list[dict]:
"""All A-share real-time quotes — used for heatmap."""
if not AK_AVAILABLE:
return _mock_heatmap_data()
try:
df = await _run_sync(ak.stock_zh_a_spot_em)
result = []
for _, row in df.iterrows():
pct = float(row.get("涨跌幅", 0) or 0)
result.append({
"symbol": str(row.get("代码", "")),
"name": str(row.get("名称", "")),
"price": float(row.get("最新价", 0) or 0),
"change": float(row.get("涨跌额", 0) or 0),
"change_pct": pct,
"open": float(row.get("今开", 0) or 0),
"high": float(row.get("最高", 0) or 0),
"low": float(row.get("最低", 0) or 0),
"prev_close": float(row.get("昨收", 0) or 0),
"volume": float(row.get("成交量", 0) or 0),
"amount": float(row.get("成交额", 0) or 0),
})
return result
except Exception as e:
logger.error(f"get_all_stocks_spot error: {e}")
return _mock_heatmap_data()
async def get_stock_quote(symbol: str) -> Optional[dict]:
"""Single stock real-time quote."""
all_stocks = await get_all_stocks_spot()
for s in all_stocks:
if s["symbol"] == symbol:
return s
return None
# ── K-line data ───────────────────────────────────────────────────────────────
PERIOD_MAP = {
"daily": "daily",
"weekly": "weekly",
"monthly": "monthly",
}
async def get_kline(symbol: str, period: str = "daily", adjust: str = "qfq", limit: int = 250) -> list[dict]:
"""
period: daily | weekly | monthly
adjust: qfq (前复权) | hfq (后复权) | "" (不复权)
"""
if not AK_AVAILABLE:
return _mock_kline(limit)
ak_period = PERIOD_MAP.get(period, "daily")
try:
df = await _run_sync(
ak.stock_zh_a_hist,
symbol=symbol,
period=ak_period,
adjust=adjust,
)
df = df.tail(limit)
result = []
for _, row in df.iterrows():
result.append({
"date": str(row.get("日期", "")),
"open": float(row.get("开盘", 0)),
"high": float(row.get("最高", 0)),
"low": float(row.get("最低", 0)),
"close": float(row.get("收盘", 0)),
"volume": float(row.get("成交量", 0)),
"amount": float(row.get("成交额", 0)),
"change_pct": float(row.get("涨跌幅", 0)),
})
return result
except Exception as e:
logger.error(f"get_kline {symbol} {period} error: {e}")
return _mock_kline(limit)
async def get_intraday(symbol: str) -> list[dict]:
"""Today's minute-level data."""
if not AK_AVAILABLE:
return _mock_intraday()
try:
df = await _run_sync(ak.stock_zh_a_hist_min_em, symbol=symbol, period="1", adjust="")
result = []
for _, row in df.iterrows():
result.append({
"time": str(row.get("时间", "")),
"price": float(row.get("收盘", 0)),
"volume": float(row.get("成交量", 0)),
"amount": float(row.get("成交额", 0)),
"avg_price": float(row.get("均价", 0) or 0),
})
return result
except Exception as e:
logger.error(f"get_intraday {symbol} error: {e}")
return _mock_intraday()
async def get_five_day(symbol: str) -> list[dict]:
"""5-day minute-level data."""
if not AK_AVAILABLE:
return _mock_intraday(days=5)
try:
df = await _run_sync(ak.stock_zh_a_hist_min_em, symbol=symbol, period="1", adjust="")
result = []
for _, row in df.iterrows():
result.append({
"time": str(row.get("时间", "")),
"price": float(row.get("收盘", 0)),
"volume": float(row.get("成交量", 0)),
"amount": float(row.get("成交额", 0) or 0),
"avg_price": float(row.get("均价", 0) or 0),
})
return result[-5 * 240:]
except Exception as e:
logger.error(f"get_five_day {symbol} error: {e}")
return _mock_intraday(days=5)
# ── search ────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=1)
def _get_stock_list_cached() -> list[dict]:
"""Cache the full stock list in memory (refreshed on process restart)."""
if not AK_AVAILABLE:
return []
try:
import akshare as ak
df = ak.stock_info_a_code_name()
return [{"symbol": str(r["code"]), "name": str(r["name"]), "market": "A股"} for _, r in df.iterrows()]
except Exception as e:
logger.error(f"_get_stock_list_cached error: {e}")
return []
async def search_stocks(query: str, limit: int = 20) -> list[dict]:
stock_list = await _run_sync(_get_stock_list_cached)
query = query.strip().lower()
results = [
s for s in stock_list
if query in s["symbol"].lower() or query in s["name"].lower()
]
return results[:limit]
# ── sector data for heatmap ───────────────────────────────────────────────────
async def get_sector_spot() -> list[dict]:
"""Board/sector change pct for heatmap grouping."""
if not AK_AVAILABLE:
return _mock_sectors()
try:
df = await _run_sync(ak.stock_board_industry_name_em)
result = []
for _, row in df.iterrows():
result.append({
"sector": str(row.get("板块名称", "")),
"change_pct": float(row.get("涨跌幅", 0) or 0),
"volume": float(row.get("成交量", 0) or 0),
"amount": float(row.get("成交额", 0) or 0),
})
return result
except Exception as e:
logger.error(f"get_sector_spot error: {e}")
return _mock_sectors()
# ── mock data (fallback when market is closed or AKShare unavailable) ─────────
import random
import math
from datetime import date, timedelta
def _mock_market_overview() -> list[dict]:
return [
{"index_code": "sh000001", "index_name": "上证指数", "current": 3312.46, "change": -28.84, "change_pct": -0.86},
{"index_code": "sz399001", "index_name": "深证成指", "current": 10573.99, "change": -93.16, "change_pct": -0.87},
{"index_code": "sz399006", "index_name": "创业板指", "current": 2105.37, "change": -18.42, "change_pct": -0.87},
{"index_code": "sh000688", "index_name": "科创50", "current": 968.12, "change": -9.56, "change_pct": -0.98},
{"index_code": "sh000300", "index_name": "沪深300", "current": 3843.20, "change": -30.11, "change_pct": -0.78},
]
def _mock_heatmap_data() -> list[dict]:
sectors = ["银行", "电力设备", "食品饮料", "医药生物", "电子", "汽车", "非银金融", "计算机", "有色金属", "化工"]
stocks = []
for i in range(80):
pct = round(random.uniform(-5, 5), 2)
sector = sectors[i % len(sectors)]
stocks.append({
"symbol": f"{600000 + i:06d}",
"name": f"测试股票{i+1:02d}",
"price": round(random.uniform(5, 100), 2),
"change": round(pct * 0.1, 2),
"change_pct": pct,
"open": round(random.uniform(5, 100), 2),
"high": round(random.uniform(5, 100), 2),
"low": round(random.uniform(5, 100), 2),
"prev_close": round(random.uniform(5, 100), 2),
"volume": random.randint(100000, 10000000),
"amount": random.randint(1000000, 100000000),
"sector": sector,
})
return stocks
def _mock_kline(limit: int = 250) -> list[dict]:
bars = []
price = 20.0
today = date.today()
for i in range(limit):
d = today - timedelta(days=limit - i)
pct = random.uniform(-0.05, 0.05)
close = round(price * (1 + pct), 2)
high = round(max(price, close) * random.uniform(1.0, 1.03), 2)
low = round(min(price, close) * random.uniform(0.97, 1.0), 2)
bars.append({
"date": d.isoformat(),
"open": round(price, 2),
"high": high,
"low": low,
"close": close,
"volume": random.randint(500000, 5000000),
"amount": random.randint(5000000, 50000000),
"change_pct": round(pct * 100, 2),
})
price = close
return bars
def _mock_intraday(days: int = 1) -> list[dict]:
bars = []
price = 20.0
for d in range(days):
for minute in range(240):
h = 9 + minute // 60
m = (minute % 60) + (30 if d == 0 and minute < 60 else 0)
if h == 11 and m >= 30:
continue
pct = random.uniform(-0.01, 0.01)
price = round(price * (1 + pct), 2)
bars.append({
"time": f"2026-06-0{d+1} {h:02d}:{m % 60:02d}",
"price": price,
"volume": random.randint(10000, 500000),
"amount": random.randint(100000, 5000000),
"avg_price": price,
})
return bars
def _mock_sectors() -> list[dict]:
sectors = [
"银行", "电力设备", "食品饮料", "医药生物", "电子",
"汽车", "非银金融", "计算机", "有色金属", "化工",
"机械设备", "建筑材料", "传媒", "房地产", "交通运输",
]
return [{"sector": s, "change_pct": round(random.uniform(-3, 3), 2), "volume": 0, "amount": 0}
for s in sectors]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,58 @@
"""
WebSocket connection manager — broadcasts real-time stock quotes to subscribers.
"""
import asyncio
import json
from collections import defaultdict
from fastapi import WebSocket
from loguru import logger
class ConnectionManager:
def __init__(self):
# symbol -> set of WebSocket connections
self._subs: dict[str, set[WebSocket]] = defaultdict(set)
self._all: set[WebSocket] = set()
async def connect(self, ws: WebSocket, symbol: str | None = None):
await ws.accept()
self._all.add(ws)
if symbol:
self._subs[symbol].add(ws)
logger.info(f"WS connected symbol={symbol}, total={len(self._all)}")
def disconnect(self, ws: WebSocket, symbol: str | None = None):
self._all.discard(ws)
if symbol:
self._subs[symbol].discard(ws)
else:
for s in list(self._subs.keys()):
self._subs[s].discard(ws)
logger.info(f"WS disconnected, total={len(self._all)}")
async def broadcast_quote(self, symbol: str, data: dict):
"""Send quote update to all subscribers of a specific symbol."""
message = json.dumps({"type": "quote", "symbol": symbol, "data": data})
dead = set()
for ws in list(self._subs.get(symbol, [])):
try:
await ws.send_text(message)
except Exception:
dead.add(ws)
for ws in dead:
self.disconnect(ws, symbol)
async def broadcast_all(self, data: list[dict]):
"""Broadcast heatmap snapshot to all connected clients."""
message = json.dumps({"type": "heatmap", "data": data})
dead = set()
for ws in list(self._all):
try:
await ws.send_text(message)
except Exception:
dead.add(ws)
for ws in dead:
self.disconnect(ws)
manager = ConnectionManager()

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,36 @@
import json
import asyncio
import redis as sync_redis
from celery_app.worker import app
from app.core.config import settings
from loguru import logger
def _sync_redis():
return sync_redis.from_url(settings.REDIS_URL, decode_responses=True)
@app.task(name="celery_app.tasks.market_tasks.refresh_heatmap_cache")
def refresh_heatmap_cache():
"""Pull all A-share spot quotes and push to Redis cache."""
try:
from app.services.stock_service import get_all_stocks_spot
data = asyncio.run(get_all_stocks_spot())
r = _sync_redis()
r.setex("market:heatmap", 60, json.dumps(data))
logger.info(f"Heatmap cache refreshed: {len(data)} stocks")
except Exception as e:
logger.error(f"refresh_heatmap_cache error: {e}")
@app.task(name="celery_app.tasks.market_tasks.refresh_market_overview")
def refresh_market_overview():
"""Pull major index data and push to Redis cache."""
try:
from app.services.stock_service import get_market_overview
data = asyncio.run(get_market_overview())
r = _sync_redis()
r.setex("market:overview", 120, json.dumps(data))
logger.info(f"Market overview refreshed: {len(data)} indices")
except Exception as e:
logger.error(f"refresh_market_overview error: {e}")

View File

@@ -0,0 +1,31 @@
from celery import Celery
from celery.schedules import crontab
import os
import sys
app = Celery(
"stock_worker",
broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/1"),
backend=os.getenv("REDIS_URL", "redis://localhost:6379/0"),
include=["celery_app.tasks.market_tasks"],
)
app.conf.timezone = "Asia/Shanghai"
app.conf.enable_utc = False
# Windows does not support the default prefork pool (fork syscall unavailable)
if sys.platform == "win32":
app.conf.worker_pool = "solo"
app.conf.beat_schedule = {
# Refresh heatmap cache every 30s during trading hours
"refresh-heatmap-30s": {
"task": "celery_app.tasks.market_tasks.refresh_heatmap_cache",
"schedule": 30.0,
},
# Refresh index data every minute
"refresh-indices-1m": {
"task": "celery_app.tasks.market_tasks.refresh_market_overview",
"schedule": 60.0,
},
}

34
backend/requirements.txt Normal file
View File

@@ -0,0 +1,34 @@
# Web framework
fastapi==0.115.5
uvicorn[standard]==0.32.1
python-multipart==0.0.20
# Database
sqlalchemy==2.0.36
asyncpg==0.30.0
alembic==1.14.0
# Redis & Celery
redis==5.2.1
celery==5.4.0
celery[redis]==5.4.0
# Auth
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
# Stock data
akshare==1.18.64
# HTTP client
httpx==0.28.1
aiohttp==3.11.11
# Validation & settings
pydantic==2.10.3
pydantic-settings==2.7.0
# Utilities
python-dotenv==1.0.1
loguru==0.7.3
tenacity==9.0.0