功能细节优化

This commit is contained in:
2026-06-15 01:26:39 +08:00
parent e524a3589a
commit 964c17c200
33 changed files with 6990 additions and 210 deletions

View File

@@ -9,14 +9,30 @@ import datetime as dt
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, Query
from fastapi import FastAPI, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from sqlalchemy.exc import SQLAlchemyError
from fastapi.staticfiles import StaticFiles
from sqlalchemy import select, func, desc
from pydantic import BaseModel
import akshare_service as svc
import redis_cache
from redis_cache import cache
import auth
from auth import get_current_user, require_auth, require_admin
import init_auth
import exceptions
from exceptions import (
BusinessException,
DataSourceException,
business_exception_handler,
validation_exception_handler,
sqlalchemy_exception_handler,
general_exception_handler
)
import config
import scheduler
import backtest as bt
@@ -37,6 +53,11 @@ import sentiment_monitor as sentiment
import event_driven as events
import financial_analysis as fin
import limit_analysis as limit_up
import watchlist_manager as wl
import position_cost as pc
import trade_calendar as cal
import data_manager as dm
import paper_trading as paper
from db import init_db, get_session
from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade,
@@ -47,8 +68,11 @@ from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
async def lifespan(app: FastAPI):
try:
init_db()
init_auth.init_default_admin()
wl.init_default_groups()
paper.ensure_default_account()
scheduler.start_scheduler()
print("[startup] db + scheduler ready")
print("[startup] db + scheduler + auth ready")
except Exception as e:
print("[startup] WARN:", repr(e)[:160])
yield
@@ -56,6 +80,12 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan)
# 注册异常处理器
app.add_exception_handler(BusinessException, business_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler)
app.add_exception_handler(Exception, general_exception_handler)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -84,10 +114,113 @@ def save_watch(symbols):
json.dump(symbols, f, ensure_ascii=False)
# ============ 认证相关 API ============
class LoginRequest(BaseModel):
username: str
password: str
@app.post("/api/auth/login")
def login(req: LoginRequest, db = Depends(get_session)):
"""用户登录"""
user = auth.authenticate_user(db, req.username, req.password)
if not user:
raise exceptions.AuthException("用户名或密码错误")
access_token = auth.create_access_token(data={"sub": user.username})
return {
"ok": True,
"access_token": access_token,
"token_type": "bearer",
"username": user.username,
"is_admin": user.is_admin
}
@app.get("/api/auth/me")
def get_me(current_user = Depends(require_auth)):
"""获取当前用户信息"""
return {
"ok": True,
"username": current_user.username,
"is_admin": current_user.is_admin
}
class ChangePasswordRequest(BaseModel):
old_password: str
new_password: str
@app.post("/api/auth/change-password")
def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)):
"""修改密码"""
if not auth.verify_password(req.old_password, current_user.hashed_password):
raise exceptions.AuthException("原密码错误")
current_user.hashed_password = auth.get_password_hash(req.new_password)
db.commit()
return {"ok": True, "msg": "密码修改成功"}
# ============ 用户管理 ============
class CreateUserRequest(BaseModel):
username: str
password: str
is_admin: bool = False
from models import User as UserModel
@app.get("/api/users")
def list_users(current_user = Depends(require_admin)):
with get_session() as s:
rows = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all()
return {"ok": True, "users": [{"id": r.id, "username": r.username,
"is_admin": r.is_admin, "is_active": r.is_active,
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]}
@app.post("/api/users")
def create_user(req: CreateUserRequest, current_user = Depends(require_admin)):
with get_session() as s:
if s.execute(select(UserModel).where(UserModel.username == req.username)).scalar_one_or_none():
return {"ok": False, "msg": "用户名已存在"}
user = UserModel(username=req.username,
hashed_password=auth.get_password_hash(req.password), is_admin=req.is_admin)
s.add(user); s.commit()
return {"ok": True, "id": user.id}
@app.delete("/api/users/{uid}")
def delete_user(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能删除自己"}
with get_session() as s:
u = s.get(UserModel, uid)
if u: s.delete(u); s.commit()
return {"ok": True}
@app.put("/api/users/{uid}/toggle_admin")
def toggle_admin(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能修改自己的权限"}
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.is_admin = not u.is_admin; s.commit()
return {"ok": True, "is_admin": u.is_admin}
@app.put("/api/users/{uid}/reset_password")
def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)):
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.hashed_password = auth.get_password_hash(req.new_password); s.commit()
return {"ok": True}
# ============ API ============
@app.get("/api/health")
def health():
return {"ok": True, "akshare": svc.AK_OK}
return {
"ok": True,
"akshare": svc.AK_OK,
"redis": cache.enabled,
"auth": True
}
@app.get("/api/indices")
@@ -106,10 +239,76 @@ def sentiment():
@app.get("/api/treemap")
def treemap(mode: str = Query("sector")):
def treemap(mode: str = Query("sector"), date: str = Query(None)):
if mode == "sector" and date:
# 从数据库读历史板块数据
try:
target = dt.date.fromisoformat(date)
except Exception:
return svc.get_treemap(mode)
with get_session() as s:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == target)
.order_by(SectorDaily.pct.desc())
).scalars().all()
if not rows:
# 找最近有数据的日期
latest = s.execute(select(func.max(SectorDaily.date))).scalar()
if latest:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == latest)
.order_by(SectorDaily.pct.desc())
).scalars().all()
target = latest
if rows:
items = [{"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)} for r in rows]
return {"source": "db", "mode": "sector", "date": target.isoformat(), "items": items}
return svc.get_treemap(mode)
@app.get("/api/treemap/us")
def treemap_us():
return svc.get_us_treemap()
@app.get("/api/treemap/hk")
def treemap_hk():
return svc.get_hk_treemap()
@app.get("/api/treemap/sector_stocks")
def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)):
return svc.get_sector_stocks(name, limit)
@app.get("/api/treemap/all_leaders")
def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)):
from models import SectorLeader
# 优先从数据库读
with get_session() as s:
target = None
if date:
try: target = dt.date.fromisoformat(date)
except Exception: pass
if not target:
target = s.execute(select(func.max(SectorLeader.date))).scalar()
if target:
rows = s.execute(
select(SectorLeader).where(SectorLeader.date == target)
.order_by(SectorLeader.sector, SectorLeader.rank)
).scalars().all()
if rows:
sectors = {}
for r in rows:
sectors.setdefault(r.sector, []).append({
"code": r.code, "name": r.name, "pct": r.pct,
"price": r.price, "amount": r.amount
})
return {"source": "db", "date": target.isoformat(), "sectors": sectors}
# 降级到实时
return svc.get_all_sector_leaders(top_n)
@app.get("/api/fundflow")
def fundflow():
return svc.get_fund_flow()
@@ -151,9 +350,105 @@ def watch_del(code: str):
return {"ok": True, "list": w}
# ============ 自选股分组管理 ============
class CreateGroupRequest(BaseModel):
name: str
description: str = ""
color: str = "blue"
@app.get("/api/watchlist/groups")
def list_groups():
"""获取所有分组"""
return {"ok": True, "groups": wl.get_all_groups()}
@app.post("/api/watchlist/groups")
def create_group(req: CreateGroupRequest):
"""创建新分组"""
return wl.create_group(req.name, req.description, req.color)
class UpdateGroupRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
color: Optional[str] = None
@app.put("/api/watchlist/groups/{group_id}")
def update_group(group_id: int, req: UpdateGroupRequest):
"""更新分组信息"""
return wl.update_group(group_id, req.name, req.description, req.color)
@app.delete("/api/watchlist/groups/{group_id}")
def delete_group(group_id: int):
"""删除分组"""
return wl.delete_group(group_id)
class ReorderGroupsRequest(BaseModel):
group_ids: List[int]
@app.post("/api/watchlist/groups/reorder")
def reorder_groups(req: ReorderGroupsRequest):
"""重新排序分组"""
return wl.reorder_groups(req.group_ids)
@app.get("/api/watchlist/groups/{group_id}/stocks")
def get_group_stocks(group_id: int, with_quotes: bool = Query(True)):
"""获取分组内的股票"""
return wl.get_group_stocks(group_id, with_quotes)
class AddStockRequest(BaseModel):
code: str
note: str = ""
@app.post("/api/watchlist/groups/{group_id}/stocks")
def add_stock_to_group(group_id: int, req: AddStockRequest):
"""添加股票到分组"""
return wl.add_stock_to_group(group_id, req.code, req.note)
@app.delete("/api/watchlist/stocks/{item_id}")
def remove_stock(item_id: int):
"""从分组中移除股票"""
return wl.remove_stock_from_group(item_id)
class MoveStockRequest(BaseModel):
target_group_id: int
@app.post("/api/watchlist/stocks/{item_id}/move")
def move_stock(item_id: int, req: MoveStockRequest):
"""移动股票到另一个分组"""
return wl.move_stock_to_group(item_id, req.target_group_id)
class BatchAddRequest(BaseModel):
codes: List[str]
@app.post("/api/watchlist/groups/{group_id}/stocks/batch")
def batch_add_stocks(group_id: int, req: BatchAddRequest):
"""批量添加股票"""
return wl.batch_add_stocks(group_id, req.codes)
class UpdateNoteRequest(BaseModel):
note: str
@app.put("/api/watchlist/stocks/{item_id}/note")
def update_stock_note(item_id: int, req: UpdateNoteRequest):
"""更新股票备注"""
return wl.update_stock_note(item_id, req.note)
class ReorderStocksRequest(BaseModel):
item_ids: List[int]
@app.post("/api/watchlist/stocks/reorder")
def reorder_stocks(req: ReorderStocksRequest):
"""重新排序股票"""
return wl.reorder_stocks(req.item_ids)
@app.get("/api/watchlist/search")
def search_stocks(keyword: str = Query(..., min_length=1)):
"""跨分组搜索股票"""
return {"ok": True, "results": wl.search_stocks_across_groups(keyword)}
# ============ 数据中台 ============
@app.get("/api/admin/status")
def admin_status():
def admin_status(current_user = Depends(require_admin)):
counts, last_dates = {}, {}
with get_session() as s:
for label, model in [("securities", Security), ("quotes_daily", DailyQuote),
@@ -175,14 +470,14 @@ def admin_status():
@app.post("/api/admin/ingest")
def admin_ingest():
def admin_ingest(current_user = Depends(require_admin)):
if scheduler.is_running():
return {"started": False, "msg": "已有入库任务在执行"}
return scheduler.trigger_async()
@app.post("/api/admin/ingest_all")
def admin_ingest_all():
def admin_ingest_all(current_user = Depends(require_admin)):
return scheduler.trigger_all_async()
@@ -488,6 +783,16 @@ def ai_today():
return ai.today_strategy()
@app.get("/api/ai/trend_analysis")
def ai_trend_analysis(
symbol: str = Query(...),
date: str = Query(""),
period: str = Query("daily")
):
"""走势分析右键K线条形时调用分析暴涨/暴跌原因"""
return ai.trend_analysis(symbol, date, period)
# ============ 可回溯:信号历史胜率 + 实测准确率 ============
@app.get("/api/ai/signal_stats")
def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)):
@@ -582,6 +887,144 @@ def portfolio_equity():
return pf.equity_curve()
# ============ 持仓成本可视化增强 ============
@app.get("/api/portfolio/cost_line/{code}")
def get_cost_line(code: str):
"""获取个股持仓成本线用于K线图标注"""
return pc.get_position_cost_lines(code)
@app.get("/api/portfolio/cost_distribution")
def get_cost_distribution():
"""获取持仓成本分布(盈亏区间图)"""
return pc.get_position_cost_distribution()
class EstimateCostRequest(BaseModel):
code: str
price: float
qty: int
side: str = "buy"
@app.post("/api/portfolio/estimate_cost")
def estimate_cost(req: EstimateCostRequest):
"""估算交易成本(下单前预估)"""
return pc.estimate_trade_cost(req.code, req.price, req.qty, req.side)
@app.get("/api/portfolio/cost_breakdown/{code}")
def get_cost_breakdown(code: str):
"""获取持仓的详细成本拆解"""
return pc.get_cost_breakdown_for_position(code)
# ============ 交易日历与关键事件 ============
@app.get("/api/calendar/events")
def calendar_events(days: int = Query(30, ge=7, le=90)):
"""获取所有即将到来的关键事件(综合视图)"""
return cal.get_all_upcoming_events(days)
@app.get("/api/calendar/dividends")
def calendar_dividends(days: int = Query(30, ge=7, le=90)):
"""除权除息日历(持仓股优先)"""
return cal.get_upcoming_dividends(days)
@app.get("/api/calendar/unlock")
def calendar_unlock(days: int = Query(90, ge=7, le=180)):
"""限售解禁日历"""
return cal.get_unlock_calendar(days)
@app.get("/api/calendar/earnings")
def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)):
"""财报披露日历"""
return cal.get_earnings_calendar(days, holding_only)
@app.post("/api/calendar/check_alerts")
def calendar_check_alerts(current_user = Depends(require_admin)):
"""手动触发日历事件预警推送"""
return cal.check_and_push_calendar_alerts()
# ============ 数据修正与回填增强 ============
class UpdateQuoteRequest(BaseModel):
open: Optional[float] = None
high: Optional[float] = None
low: Optional[float] = None
close: Optional[float] = None
volume: Optional[int] = None
amount: Optional[float] = None
@app.delete("/api/data/quote/{code}/{date}")
def delete_quote(code: str, date: str, current_user = Depends(require_admin)):
"""删除指定股票指定日期的日线"""
return dm.delete_quote(code, date)
@app.put("/api/data/quote/{code}/{date}")
def update_quote(code: str, date: str, req: UpdateQuoteRequest,
current_user = Depends(require_admin)):
"""修正指定日线数据"""
return dm.update_quote(code, date, req.model_dump(exclude_none=True))
class DeleteRangeRequest(BaseModel):
start: str
end: str
@app.delete("/api/data/quotes/{code}/range")
def delete_quotes_range(code: str, req: DeleteRangeRequest,
current_user = Depends(require_admin)):
"""删除指定股票日期范围内的日线数据"""
return dm.delete_quotes_range(code, req.start, req.end)
@app.post("/api/data/refetch/{code}")
def refetch_quote(code: str, days: int = Query(60, ge=5, le=500),
current_user = Depends(require_admin)):
"""重新抓取指定股票日线(覆盖更新)"""
return dm.refetch_quote(code, days)
@app.get("/api/data/integrity")
def check_integrity(days: int = Query(30, ge=7, le=90),
current_user = Depends(require_admin)):
"""数据完整性检查"""
return dm.check_data_integrity(days=days)
@app.post("/api/data/auto_fix")
def auto_fix_missing(limit: int = Query(50, ge=10, le=200),
current_user = Depends(require_admin)):
"""自动补齐缺失数据"""
t = __import__("threading").Thread(
target=dm.auto_fix_missing, kwargs={"limit": limit}, daemon=True
)
t.start()
return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"}
@app.get("/api/data/refill_progress")
def refill_progress(task_id: str = Query("default")):
"""获取回填进度"""
return dm.get_refill_progress(task_id)
@app.post("/api/data/refill_resume")
def refill_resume(days: int = Query(250, ge=30, le=1000),
task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""带断点续传的全市场回填(后台执行)"""
import threading
t = threading.Thread(
target=dm.start_refill_with_resume,
kwargs={"days": days, "task_id": task_id},
daemon=True
)
t.start()
return {"ok": True, "msg": f"已启动断点续传回填,天数={days}任务ID={task_id}"}
@app.delete("/api/data/refill_progress")
def clear_refill_progress(task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""清除回填进度(从头开始)"""
return dm.clear_refill_progress(task_id)
@app.get("/api/data/quality_report")
def data_quality_report(current_user = Depends(require_admin)):
"""数据质量报告"""
return dm.get_data_quality_report()
@app.get("/api/portfolio/attribution")
def portfolio_attribution():
"""持仓归因分析"""
@@ -761,6 +1204,22 @@ def limit_squad(days: int = Query(30, ge=10, le=90), min_limits: int = Query(5,
"""涨停敢死队排行"""
return limit_up.get_limit_squad_rankings(days, min_limits)
@app.get("/api/limit/consecutive_calendar")
def consecutive_calendar(days: int = Query(60, ge=20, le=120)):
"""连板日历:记录连板历史,分析几进几出规律"""
return limit_up.get_consecutive_calendar(days)
@app.get("/api/limit/post_break")
def post_break_performance(days: int = Query(90, ge=30, le=180)):
"""炸板后 1/3/5 日走势统计"""
return limit_up.analyze_post_break_performance(days)
@app.get("/api/limit/reasons")
def limit_reasons(date: Optional[str] = None):
"""涨停原因分类(情绪/题材/业绩/政策等)"""
d = dt.date.fromisoformat(date) if date else None
return limit_up.classify_limit_reasons(d)
# ============ 推送通知 ============
@app.get("/api/notify/status")
@@ -1142,6 +1601,51 @@ def delete_selector_alert(aid: int):
return {"ok": True}
# ============ 模拟盘 ============
class PaperAccountIn(BaseModel):
name: str
initial_cash: float = 1_000_000.0
@app.get("/api/paper/accounts")
def paper_list_accounts():
return {"ok": True, "accounts": paper.list_accounts()}
@app.post("/api/paper/accounts")
def paper_create_account(req: PaperAccountIn):
return paper.create_account(req.name, req.initial_cash)
@app.post("/api/paper/accounts/{account_id}/reset")
def paper_reset_account(account_id: int, initial_cash: Optional[float] = None):
return paper.reset_account(account_id, initial_cash)
class PaperOrderIn(BaseModel):
code: str
side: str # buy / sell
qty: int
price: Optional[float] = None
reason: str = ""
@app.post("/api/paper/accounts/{account_id}/order")
def paper_place_order(account_id: int, req: PaperOrderIn):
return paper.place_order(account_id, req.code, req.side, req.qty, req.price, req.reason)
@app.get("/api/paper/accounts/{account_id}/portfolio")
def paper_get_portfolio(account_id: int):
return paper.get_portfolio(account_id)
@app.get("/api/paper/accounts/{account_id}/trades")
def paper_get_trades(account_id: int, limit: int = Query(100, le=500)):
return {"ok": True, "trades": paper.get_trades(account_id, limit)}
# ============ 静态前端 ============
FRONTEND_DIR = os.path.join(os.path.dirname(BASE_DIR), "prototype")
if os.path.isdir(FRONTEND_DIR):