Files
stock_cursor_v0/backend/main.py
2026-06-15 01:26:39 +08:00

1658 lines
54 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Blackdata股票终端 — FastAPI 后端入口。
- /api/* : 数据接口(基于 AkShare带缓存与降级
- / : 托管前端原型prototype 目录)
"""
import os
import json
import datetime as dt
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from sqlalchemy.exc import SQLAlchemyError
from fastapi.staticfiles import StaticFiles
from sqlalchemy import select, func, desc
from pydantic import BaseModel
import akshare_service as svc
import redis_cache
from redis_cache import cache
import auth
from auth import get_current_user, require_auth, require_admin
import init_auth
import exceptions
from exceptions import (
BusinessException,
DataSourceException,
business_exception_handler,
validation_exception_handler,
sqlalchemy_exception_handler,
general_exception_handler
)
import config
import scheduler
import backtest as bt
import backtest_advanced as bta
import ai
import signals as sig
import report as rpt
import portfolio as pf
import llm
import alerts as al
import notifier
import intraday_radar as radar
import sector_rotation as sector
import smart_selector as selector
import attribution_analysis as attrib
import ai_chat
import sentiment_monitor as sentiment
import event_driven as events
import financial_analysis as fin
import limit_analysis as limit_up
import watchlist_manager as wl
import position_cost as pc
import trade_calendar as cal
import data_manager as dm
import paper_trading as paper
from db import init_db, get_session
from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade,
AlertRule, AlertEvent, SelectorStrategy, SelectorAlert)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the start-up steps once, then hand control to the application.

    The steps run in order: database init, default admin, default watchlist
    groups, default paper-trading account, scheduler start. A failure in any
    step is reported on stdout (truncated to 160 characters) and start-up
    continues, so the app still serves requests.
    """
    try:
        init_db()
        init_auth.init_default_admin()
        wl.init_default_groups()
        paper.ensure_default_account()
        scheduler.start_scheduler()
    except Exception as exc:
        warning = repr(exc)[:160]
        print("[startup] WARN:", warning)
    else:
        print("[startup] db + scheduler + auth ready")
    yield
# Application object; ``lifespan`` runs the start-up steps.
app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan)
# Register exception handlers
app.add_exception_handler(BusinessException, business_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler)
app.add_exception_handler(Exception, general_exception_handler)
# NOTE(review): CORS is fully open (any origin, method and header) - confirm
# this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Local file storage for the flat watchlist
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
WATCH_FILE = os.path.join(BASE_DIR, "watchlist.json")  # JSON file holding the symbol list
# Symbols used when no watchlist file can be loaded
DEFAULT_WATCH = ["600519", "300750", "002594", "688981", "300059", "601012"]
def load_watch():
    """Load the watchlist symbols from ``WATCH_FILE``.

    Returns:
        A list of symbols. When the file is missing, unreadable, not valid
        JSON, or does not hold a JSON array, a fresh copy of
        ``DEFAULT_WATCH`` is returned.
    """
    try:
        with open(WATCH_FILE, "r", encoding="utf-8") as f:
            symbols = json.load(f)
    except (OSError, ValueError):
        # Best effort: a missing or corrupt file falls back to the defaults.
        # ValueError covers both JSON and text-decoding errors.
        symbols = None
    if isinstance(symbols, list):
        return symbols
    # Return a copy: callers append to the result, and handing out the
    # module-level list would let them silently change the defaults.
    return list(DEFAULT_WATCH)
def save_watch(symbols):
    """Write ``symbols`` to ``WATCH_FILE`` as UTF-8 JSON, replacing the file."""
    serialized = json.dumps(symbols, ensure_ascii=False)
    with open(WATCH_FILE, "w", encoding="utf-8") as out:
        out.write(serialized)
# ============ 认证相关 API ============
# Request body for POST /api/auth/login.
class LoginRequest(BaseModel):
    username: str
    password: str
@app.post("/api/auth/login")
def login(req: LoginRequest, db = Depends(get_session)):
"""用户登录"""
user = auth.authenticate_user(db, req.username, req.password)
if not user:
raise exceptions.AuthException("用户名或密码错误")
access_token = auth.create_access_token(data={"sub": user.username})
return {
"ok": True,
"access_token": access_token,
"token_type": "bearer",
"username": user.username,
"is_admin": user.is_admin
}
@app.get("/api/auth/me")
def get_me(current_user = Depends(require_auth)):
"""获取当前用户信息"""
return {
"ok": True,
"username": current_user.username,
"is_admin": current_user.is_admin
}
# Request body for password changes. Both fields are required.
class ChangePasswordRequest(BaseModel):
    old_password: str
    new_password: str
@app.post("/api/auth/change-password")
def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)):
"""修改密码"""
if not auth.verify_password(req.old_password, current_user.hashed_password):
raise exceptions.AuthException("原密码错误")
current_user.hashed_password = auth.get_password_hash(req.new_password)
db.commit()
return {"ok": True, "msg": "密码修改成功"}
# ============ 用户管理 ============
# Request body for POST /api/users; new users are non-admin unless stated.
class CreateUserRequest(BaseModel):
    username: str
    password: str
    is_admin: bool = False
from models import User as UserModel
@app.get("/api/users")
def list_users(current_user = Depends(require_admin)):
with get_session() as s:
rows = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all()
return {"ok": True, "users": [{"id": r.id, "username": r.username,
"is_admin": r.is_admin, "is_active": r.is_active,
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]}
@app.post("/api/users")
def create_user(req: CreateUserRequest, current_user = Depends(require_admin)):
with get_session() as s:
if s.execute(select(UserModel).where(UserModel.username == req.username)).scalar_one_or_none():
return {"ok": False, "msg": "用户名已存在"}
user = UserModel(username=req.username,
hashed_password=auth.get_password_hash(req.password), is_admin=req.is_admin)
s.add(user); s.commit()
return {"ok": True, "id": user.id}
@app.delete("/api/users/{uid}")
def delete_user(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能删除自己"}
with get_session() as s:
u = s.get(UserModel, uid)
if u: s.delete(u); s.commit()
return {"ok": True}
@app.put("/api/users/{uid}/toggle_admin")
def toggle_admin(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能修改自己的权限"}
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.is_admin = not u.is_admin; s.commit()
return {"ok": True, "is_admin": u.is_admin}
@app.put("/api/users/{uid}/reset_password")
def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)):
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.hashed_password = auth.get_password_hash(req.new_password); s.commit()
return {"ok": True}
# ============ API ============
@app.get("/api/health")
def health():
return {
"ok": True,
"akshare": svc.AK_OK,
"redis": cache.enabled,
"auth": True
}
@app.get("/api/indices")
def indices():
return svc.get_indices()
@app.get("/api/kline")
def kline(symbol: str = Query("600519"), days: int = Query(120, ge=20, le=500)):
return svc.get_kline(symbol, days)
@app.get("/api/sentiment")
def sentiment():
return svc.get_sentiment()
@app.get("/api/treemap")
def treemap(mode: str = Query("sector"), date: str = Query(None)):
if mode == "sector" and date:
# 从数据库读历史板块数据
try:
target = dt.date.fromisoformat(date)
except Exception:
return svc.get_treemap(mode)
with get_session() as s:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == target)
.order_by(SectorDaily.pct.desc())
).scalars().all()
if not rows:
# 找最近有数据的日期
latest = s.execute(select(func.max(SectorDaily.date))).scalar()
if latest:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == latest)
.order_by(SectorDaily.pct.desc())
).scalars().all()
target = latest
if rows:
items = [{"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)} for r in rows]
return {"source": "db", "mode": "sector", "date": target.isoformat(), "items": items}
return svc.get_treemap(mode)
@app.get("/api/treemap/us")
def treemap_us():
return svc.get_us_treemap()
@app.get("/api/treemap/hk")
def treemap_hk():
return svc.get_hk_treemap()
@app.get("/api/treemap/sector_stocks")
def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)):
return svc.get_sector_stocks(name, limit)
@app.get("/api/treemap/all_leaders")
def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)):
from models import SectorLeader
# 优先从数据库读
with get_session() as s:
target = None
if date:
try: target = dt.date.fromisoformat(date)
except Exception: pass
if not target:
target = s.execute(select(func.max(SectorLeader.date))).scalar()
if target:
rows = s.execute(
select(SectorLeader).where(SectorLeader.date == target)
.order_by(SectorLeader.sector, SectorLeader.rank)
).scalars().all()
if rows:
sectors = {}
for r in rows:
sectors.setdefault(r.sector, []).append({
"code": r.code, "name": r.name, "pct": r.pct,
"price": r.price, "amount": r.amount
})
return {"source": "db", "date": target.isoformat(), "sectors": sectors}
# 降级到实时
return svc.get_all_sector_leaders(top_n)
@app.get("/api/fundflow")
def fundflow():
return svc.get_fund_flow()
@app.get("/api/hot/stocks")
def hot_stocks():
return svc.get_hot_stocks()
@app.get("/api/hot/sectors")
def hot_sectors():
return svc.get_industry_boards()
@app.get("/api/dragon")
def dragon():
return svc.get_dragon_tiger()
@app.get("/api/watchlist")
def watchlist():
return svc.get_watchlist(load_watch())
@app.post("/api/watchlist/{code}")
def watch_add(code: str):
w = load_watch()
if code not in w:
w.append(code)
save_watch(w)
return {"ok": True, "list": w}
@app.delete("/api/watchlist/{code}")
def watch_del(code: str):
w = [c for c in load_watch() if c != code]
save_watch(w)
return {"ok": True, "list": w}
# ============ 自选股分组管理 ============
class CreateGroupRequest(BaseModel):
name: str
description: str = ""
color: str = "blue"
@app.get("/api/watchlist/groups")
def list_groups():
"""获取所有分组"""
return {"ok": True, "groups": wl.get_all_groups()}
@app.post("/api/watchlist/groups")
def create_group(req: CreateGroupRequest):
"""创建新分组"""
return wl.create_group(req.name, req.description, req.color)
# Request body for updating a watchlist group; every field is optional.
class UpdateGroupRequest(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    color: Optional[str] = None
@app.put("/api/watchlist/groups/{group_id}")
def update_group(group_id: int, req: UpdateGroupRequest):
"""更新分组信息"""
return wl.update_group(group_id, req.name, req.description, req.color)
@app.delete("/api/watchlist/groups/{group_id}")
def delete_group(group_id: int):
"""删除分组"""
return wl.delete_group(group_id)
# Request body for reordering groups: the group ids, passed to wl.reorder_groups.
class ReorderGroupsRequest(BaseModel):
    group_ids: List[int]
@app.post("/api/watchlist/groups/reorder")
def reorder_groups(req: ReorderGroupsRequest):
"""重新排序分组"""
return wl.reorder_groups(req.group_ids)
@app.get("/api/watchlist/groups/{group_id}/stocks")
def get_group_stocks(group_id: int, with_quotes: bool = Query(True)):
"""获取分组内的股票"""
return wl.get_group_stocks(group_id, with_quotes)
# Request body for adding one stock to a group; the note defaults to empty.
class AddStockRequest(BaseModel):
    code: str
    note: str = ""
@app.post("/api/watchlist/groups/{group_id}/stocks")
def add_stock_to_group(group_id: int, req: AddStockRequest):
"""添加股票到分组"""
return wl.add_stock_to_group(group_id, req.code, req.note)
@app.delete("/api/watchlist/stocks/{item_id}")
def remove_stock(item_id: int):
"""从分组中移除股票"""
return wl.remove_stock_from_group(item_id)
# Request body for moving a stock: the id of the destination group.
class MoveStockRequest(BaseModel):
    target_group_id: int
@app.post("/api/watchlist/stocks/{item_id}/move")
def move_stock(item_id: int, req: MoveStockRequest):
"""移动股票到另一个分组"""
return wl.move_stock_to_group(item_id, req.target_group_id)
# Request body for adding several stock codes to a group at once.
class BatchAddRequest(BaseModel):
    codes: List[str]
@app.post("/api/watchlist/groups/{group_id}/stocks/batch")
def batch_add_stocks(group_id: int, req: BatchAddRequest):
"""批量添加股票"""
return wl.batch_add_stocks(group_id, req.codes)
# Request body for replacing the note attached to a watchlist item.
class UpdateNoteRequest(BaseModel):
    note: str
@app.put("/api/watchlist/stocks/{item_id}/note")
def update_stock_note(item_id: int, req: UpdateNoteRequest):
"""更新股票备注"""
return wl.update_stock_note(item_id, req.note)
# Request body for reordering stocks: the item ids, passed to wl.reorder_stocks.
class ReorderStocksRequest(BaseModel):
    item_ids: List[int]
@app.post("/api/watchlist/stocks/reorder")
def reorder_stocks(req: ReorderStocksRequest):
"""重新排序股票"""
return wl.reorder_stocks(req.item_ids)
@app.get("/api/watchlist/search")
def search_stocks(keyword: str = Query(..., min_length=1)):
"""跨分组搜索股票"""
return {"ok": True, "results": wl.search_stocks_across_groups(keyword)}
# ============ 数据中台 ============
@app.get("/api/admin/status")
def admin_status(current_user = Depends(require_admin)):
counts, last_dates = {}, {}
with get_session() as s:
for label, model in [("securities", Security), ("quotes_daily", DailyQuote),
("index_daily", IndexDaily), ("sector_daily", SectorDaily),
("fund_flow_daily", FundFlowDaily), ("sentiment_daily", SentimentDaily),
("dragon_tiger", DragonTiger)]:
counts[label] = s.execute(select(func.count()).select_from(model)).scalar() or 0
if hasattr(model, "date"):
d = s.execute(select(func.max(model.date))).scalar()
last_dates[label] = d.isoformat() if d else None
jobs = s.execute(select(JobRun).order_by(desc(JobRun.id)).limit(8)).scalars().all()
job_list = [{"id": j.id, "job": j.job, "status": j.status,
"started": j.started_at.strftime("%m-%d %H:%M:%S") if j.started_at else "",
"finished": j.finished_at.strftime("%H:%M:%S") if j.finished_at else "",
"message": j.message[:200]} for j in jobs]
return {"counts": counts, "last_dates": last_dates, "jobs": job_list,
"running": scheduler.is_running(), "universe": config.DEFAULT_UNIVERSE,
"schedule": f"周一至周五 {config.INGEST_HOUR:02d}:{config.INGEST_MINUTE:02d}"}
@app.post("/api/admin/ingest")
def admin_ingest(current_user = Depends(require_admin)):
if scheduler.is_running():
return {"started": False, "msg": "已有入库任务在执行"}
return scheduler.trigger_async()
@app.post("/api/admin/ingest_all")
def admin_ingest_all(current_user = Depends(require_admin)):
return scheduler.trigger_all_async()
@app.get("/api/db/kline")
def db_kline(symbol: str = Query("600519"), days: int = Query(250, ge=20, le=1000)):
with get_session() as s:
rows = s.execute(
select(DailyQuote).where(DailyQuote.code == symbol)
.order_by(DailyQuote.date.desc()).limit(days)
).scalars().all()
rows = list(reversed(rows))
if not rows:
return {"source": "db", "empty": True, "symbol": symbol, "dates": [], "ohlc": [], "vols": []}
return {"source": "db", "symbol": symbol,
"dates": [r.date.strftime("%m/%d") for r in rows],
"ohlc": [[r.open, r.close, r.low, r.high] for r in rows],
"vols": [r.volume for r in rows]}
@app.get("/api/db/sentiment_history")
def db_sentiment_history(days: int = Query(60, ge=5, le=365)):
with get_session() as s:
rows = s.execute(select(SentimentDaily).order_by(SentimentDaily.date.desc()).limit(days)).scalars().all()
rows = list(reversed(rows))
return {"dates": [r.date.isoformat() for r in rows],
"up": [r.up for r in rows], "down": [r.down for r in rows],
"limit_up": [r.limit_up for r in rows]}
@app.get("/api/review/daily")
def review_daily(date: str = Query(None)):
with get_session() as s:
if date:
d = dt.date.fromisoformat(date)
else:
d = s.execute(select(func.max(SectorDaily.date))).scalar()
if not d:
return {"ok": False, "msg": "暂无入库数据,请先在数据中台执行入库"}
sectors = s.execute(select(SectorDaily).where(SectorDaily.date == d).order_by(SectorDaily.pct.desc())).scalars().all()
flows = s.execute(select(FundFlowDaily).where(FundFlowDaily.date == d).order_by(FundFlowDaily.net.desc())).scalars().all()
senti = s.execute(select(SentimentDaily).where(SentimentDaily.date == d)).scalar_one_or_none()
lhb = s.execute(select(DragonTiger).where(DragonTiger.date == d).order_by(DragonTiger.net.desc()).limit(10)).scalars().all()
top_sec = [{"name": x.name, "pct": x.pct} for x in sectors[:8]]
bot_sec = [{"name": x.name, "pct": x.pct} for x in sectors[-5:]]
inflow = [{"name": x.name, "net": x.net} for x in flows[:8]]
outflow = [{"name": x.name, "net": x.net} for x in flows[-5:][::-1]]
senti_d = ({"up": senti.up, "down": senti.down, "limit_up": senti.limit_up,
"limit_down": senti.limit_down} if senti else None)
summary = _gen_review_text(d, senti_d, top_sec, inflow)
return {"ok": True, "date": d.isoformat(), "sentiment": senti_d,
"top_sectors": top_sec, "weak_sectors": bot_sec,
"inflow": inflow, "outflow": outflow,
"dragon": [{"name": x.name, "code": x.code, "net": x.net, "pct": x.pct} for x in lhb],
"summary": summary}
def _gen_review_text(d, senti, top_sec, inflow):
parts = [f"{d.isoformat()} 复盘】"]
if senti:
tone = "情绪偏暖" if senti["up"] > senti["down"] else "情绪偏弱"
parts.append(f"全市场上涨 {senti['up']} 家、下跌 {senti['down']} 家,涨停 {senti['limit_up']} 家、跌停 {senti['limit_down']} 家,{tone}")
if top_sec:
names = "".join(x["name"] for x in top_sec[:3])
parts.append(f"领涨板块:{names}")
if inflow:
names = "".join(x["name"] for x in inflow[:3])
parts.append(f"主力净流入居前:{names}")
parts.append("以上为基于入库数据的自动统计AI 智能点评将在 AI 分析模块接入大模型后生成。")
return " ".join(parts)
@app.get("/api/backtest")
def backtest_api(symbol: str = Query("600519"), fast: int = Query(5, ge=2, le=60),
slow: int = Query(20, ge=5, le=250)):
if fast >= slow:
return {"ok": False, "msg": "快线周期需小于慢线周期"}
return bt.run_backtest(symbol, fast, slow)
# ============ 增强回测 ============
# Request body for POST /api/backtest/advanced.
class BacktestParams(BaseModel):
    symbol: str
    strategy: str = "ma"  # ma, multi_factor
    fast: int = 5  # passed to MAStrategy only
    slow: int = 20  # passed to MAStrategy only
    position_size: float = 1.0
    stop_loss: float = 0.0  # passed to MAStrategy only
    take_profit: float = 0.0  # passed to MAStrategy only
    initial_capital: float = 100000.0
    commission: float = 0.0005
@app.post("/api/backtest/advanced")
def backtest_advanced(params: BacktestParams):
"""增强回测"""
if params.strategy == "ma":
strategy = bta.MAStrategy(
fast=params.fast,
slow=params.slow,
position_size=params.position_size,
stop_loss=params.stop_loss,
take_profit=params.take_profit
)
elif params.strategy == "multi_factor":
strategy = bta.MultiFactorStrategy(position_size=params.position_size)
else:
return {"ok": False, "msg": "不支持的策略类型"}
return bta.run_advanced_backtest(
symbol=params.symbol,
strategy=strategy,
initial_capital=params.initial_capital,
commission=params.commission
)
# Request body for POST /api/backtest/optimize.
class OptimizeParams(BaseModel):
    symbol: str
    strategy: str = "ma"  # NOTE(review): not read by backtest_optimize, which always uses MAStrategy
    fast_range: List[int] = [3, 5, 10, 15]  # candidate values for "fast"
    slow_range: List[int] = [20, 30, 60]  # candidate values for "slow"
    metric: str = "sharpe_ratio"  # passed to bta.optimize_parameters as ``metric``
@app.post("/api/backtest/optimize")
def backtest_optimize(params: OptimizeParams):
"""参数优化"""
param_grid = {
"fast": params.fast_range,
"slow": params.slow_range
}
results = bta.optimize_parameters(
symbol=params.symbol,
param_grid=param_grid,
strategy_class=bta.MAStrategy,
metric=params.metric
)
return {
"ok": True,
"symbol": params.symbol,
"metric": params.metric,
"results": results[:20] # 返回前20个最优结果
}
# Request body for POST /api/backtest/compare. Each strategy dict carries a
# "type" ("ma" or "multi_factor") plus optional fast/slow/stop_loss/take_profit.
class CompareParams(BaseModel):
    symbol: str
    strategies: List[Dict[str, Any]]
@app.post("/api/backtest/compare")
def backtest_compare(params: CompareParams):
"""策略对比"""
strategies = []
for s in params.strategies:
if s["type"] == "ma":
strategies.append(bta.MAStrategy(
fast=s.get("fast", 5),
slow=s.get("slow", 20),
stop_loss=s.get("stop_loss", 0),
take_profit=s.get("take_profit", 0)
))
elif s["type"] == "multi_factor":
strategies.append(bta.MultiFactorStrategy())
return bta.compare_strategies(params.symbol, strategies)
# ============ 全市场选股 ============
# Screening strategy id -> display label shown to the user.
STRATEGIES = {
    "surge": "最近暴涨5日涨幅≥20%",
    "plunge": "最近暴跌5日跌幅≥15%",
    # Opening parenthesis restored; the label had only the closing one.
    "dip": "超跌抄底(60日分位≤20%且当日企稳)",
    "breakout": "突破走强逼近60日新高",
    "ma_bull": "均线多头MA5>10>20",
    "volume": "放量上攻量比≥2且上涨",
    "macd_gold": "MACD金叉",
    "strong": "强势连涨≥3日连阳",
}
@app.get("/api/screen/strategies")
def screen_strategies():
return {"list": [{"id": k, "name": v} for k, v in STRATEGIES.items()]}
@app.get("/api/screen")
def screen(strategy: str = Query("surge"), limit: int = Query(60, ge=10, le=300),
min_amount: float = Query(0.0)):
M = StockMetric
q = select(M)
order = M.ret5.desc()
if strategy == "surge":
q = q.where(M.ret5 >= 20)
elif strategy == "plunge":
q = q.where(M.ret5 <= -15); order = M.ret5.asc()
elif strategy == "dip":
q = q.where(M.pos60 <= 0.2, M.pct > 0); order = M.pos60.asc()
elif strategy == "breakout":
q = q.where(M.pos60 >= 0.95, M.pct > 0); order = M.ret20.desc()
elif strategy == "ma_bull":
q = q.where(M.ma_bull.is_(True)); order = M.ret20.desc()
elif strategy == "volume":
q = q.where(M.vol_ratio >= 2, M.pct > 0); order = M.vol_ratio.desc()
elif strategy == "macd_gold":
q = q.where(M.macd_gold.is_(True)); order = M.ret5.desc()
elif strategy == "strong":
q = q.where(M.up_streak >= 3); order = M.up_streak.desc()
if min_amount > 0:
q = q.where(M.amount >= min_amount)
q = q.order_by(order).limit(limit)
with get_session() as s:
rows = s.execute(q).scalars().all()
total = s.execute(select(func.count()).select_from(M)).scalar() or 0
return {"strategy": strategy, "name": STRATEGIES.get(strategy, strategy), "pool_size": total,
"count": len(rows), "list": [{
"code": r.code, "name": r.name, "close": r.close, "pct": r.pct,
"ret5": r.ret5, "ret20": r.ret20, "vol_ratio": r.vol_ratio,
"rsi14": r.rsi14, "pos60": round(r.pos60 * 100, 1), "amount": r.amount,
"up_streak": r.up_streak} for r in rows]}
@app.get("/api/securities/search")
def securities_search(q: str = Query("", min_length=0), limit: int = Query(15, le=50)):
with get_session() as s:
stmt = select(Security)
if q:
stmt = stmt.where((Security.code.like(f"{q}%")) | (Security.name.like(f"%{q}%")))
rows = s.execute(stmt.limit(limit)).scalars().all()
return {"list": [{"code": r.code, "name": r.name} for r in rows]}
# ============ 个股复盘K线 + 买卖点 + 回放) ============
def _ma_list(close, n):
out = [None] * len(close)
for i in range(len(close)):
if i >= n - 1:
out[i] = round(sum(close[i - n + 1:i + 1]) / n, 3)
return out
@app.get("/api/review/stock")
def review_stock(symbol: str = Query("600519"), days: int = Query(250, ge=40, le=1000),
fast: int = Query(5), slow: int = Query(20)):
with get_session() as s:
rows = s.execute(
select(DailyQuote).where(DailyQuote.code == symbol)
.order_by(DailyQuote.date.desc()).limit(days)
).scalars().all()
sec = s.get(Security, symbol)
rows = list(reversed(rows))
if not rows:
return {"ok": False, "msg": "该股票库内无日线,请先在数据中台入库该股或执行全市场回填", "symbol": symbol}
dates = [r.date.strftime("%y/%m/%d") for r in rows]
ohlc = [[r.open, r.close, r.low, r.high] for r in rows]
vols = [r.volume for r in rows]
close = [r.close for r in rows]
maf, mas = _ma_list(close, fast), _ma_list(close, slow)
signals = []
for i in range(1, len(close)):
if maf[i] is None or mas[i] is None or maf[i - 1] is None or mas[i - 1] is None:
continue
if maf[i - 1] <= mas[i - 1] and maf[i] > mas[i]:
signals.append({"idx": i, "date": dates[i], "price": close[i], "type": "buy"})
elif maf[i - 1] >= mas[i - 1] and maf[i] < mas[i]:
signals.append({"idx": i, "date": dates[i], "price": close[i], "type": "sell"})
# 区间统计
hi = max(r.high for r in rows); lo = min(r.low for r in rows)
period_ret = round((close[-1] / close[0] - 1) * 100, 2)
return {"ok": True, "symbol": symbol, "name": sec.name if sec else symbol,
"dates": dates, "ohlc": ohlc, "vols": vols,
"ma_fast": maf, "ma_slow": mas, "fast": fast, "slow": slow,
"signals": signals,
"stats": {"period_return": period_ret, "high": hi, "low": lo,
"start": dates[0], "end": dates[-1], "bars": len(rows)}}
# ============ AI 分析 ============
@app.get("/api/ai/status")
def ai_status():
return {"enabled": llm.enabled(), "model": config.LLM_MODEL if llm.enabled() else None}
@app.get("/api/ai/review_daily")
def ai_review_daily(date: str = Query(None)):
return ai.review_daily_comment(date)
@app.get("/api/ai/diagnose")
def ai_diagnose(symbol: str = Query("600519")):
return ai.diagnose(symbol)
@app.get("/api/ai/today")
def ai_today():
return ai.today_strategy()
@app.get("/api/ai/trend_analysis")
def ai_trend_analysis(
symbol: str = Query(...),
date: str = Query(""),
period: str = Query("daily")
):
"""走势分析右键K线条形时调用分析暴涨/暴跌原因"""
return ai.trend_analysis(symbol, date, period)
# ============ 可回溯:信号历史胜率 + 实测准确率 ============
@app.get("/api/ai/signal_stats")
def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)):
return {"horizon": horizon, "stats": sig.get_stats(horizon)}
@app.post("/api/ai/signal_stats/compute")
def ai_signal_stats_compute(sample: int = Query(500, ge=50, le=4000), horizon: int = Query(5, ge=1, le=20)):
return scheduler.trigger_signal_stats_async(sample, horizon)
@app.get("/api/ai/accuracy")
def ai_accuracy():
return sig.accuracy()
@app.post("/api/ai/accuracy/verify")
def ai_accuracy_verify():
return sig.verify_predictions()
# ============ AI 自动复盘日报 ============
@app.get("/api/report/daily")
def report_daily(date: str = Query(None)):
return rpt.get_by_date(date) if date else rpt.latest()
@app.get("/api/report/history")
def report_history(limit: int = Query(30, ge=1, le=120)):
return rpt.history(limit)
@app.post("/api/report/generate")
def report_generate(date: str = Query(None), push: bool = Query(False)):
return rpt.generate(date, push=push)
# ============ 交易日志 & 组合 ============
# Request body for recording a trade (POST /api/trades).
class TradeIn(BaseModel):
    code: str
    name: str = ""  # empty: looked up from the securities table by add_trade
    side: str = "buy"
    price: float
    qty: int
    fee: float = 0.0
    date: str = ""  # ISO date; empty means today
    reason: str = ""
    emotion: str = ""
@app.get("/api/trades")
def list_trades():
with get_session() as s:
rows = s.execute(select(Trade).order_by(Trade.date.desc(), Trade.id.desc())).scalars().all()
names = {}
return {"list": [{"id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name,
"side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee,
"reason": t.reason, "emotion": t.emotion} for t in rows]}
@app.post("/api/trades")
def add_trade(t: TradeIn):
d = dt.date.fromisoformat(t.date) if t.date else dt.date.today()
name = t.name
if not name:
with get_session() as s:
sec = s.get(Security, t.code)
name = sec.name if sec else t.code
with get_session() as s:
row = Trade(date=d, code=t.code, name=name, side=t.side, price=t.price,
qty=t.qty, fee=t.fee, reason=t.reason, emotion=t.emotion)
s.add(row); s.commit()
return {"ok": True, "id": row.id}
@app.delete("/api/trades/{tid}")
def del_trade(tid: int):
with get_session() as s:
row = s.get(Trade, tid)
if row:
s.delete(row); s.commit()
return {"ok": True}
@app.get("/api/portfolio")
def get_portfolio():
return pf.compute()
@app.get("/api/portfolio/equity")
def portfolio_equity():
return pf.equity_curve()
# ============ 持仓成本可视化增强 ============
@app.get("/api/portfolio/cost_line/{code}")
def get_cost_line(code: str):
"""获取个股持仓成本线用于K线图标注"""
return pc.get_position_cost_lines(code)
@app.get("/api/portfolio/cost_distribution")
def get_cost_distribution():
"""获取持仓成本分布(盈亏区间图)"""
return pc.get_position_cost_distribution()
# Request body for estimating the cost of a trade before placing it.
class EstimateCostRequest(BaseModel):
    code: str
    price: float
    qty: int
    side: str = "buy"
@app.post("/api/portfolio/estimate_cost")
def estimate_cost(req: EstimateCostRequest):
"""估算交易成本(下单前预估)"""
return pc.estimate_trade_cost(req.code, req.price, req.qty, req.side)
@app.get("/api/portfolio/cost_breakdown/{code}")
def get_cost_breakdown(code: str):
"""获取持仓的详细成本拆解"""
return pc.get_cost_breakdown_for_position(code)
# ============ 交易日历与关键事件 ============
@app.get("/api/calendar/events")
def calendar_events(days: int = Query(30, ge=7, le=90)):
"""获取所有即将到来的关键事件(综合视图)"""
return cal.get_all_upcoming_events(days)
@app.get("/api/calendar/dividends")
def calendar_dividends(days: int = Query(30, ge=7, le=90)):
"""除权除息日历(持仓股优先)"""
return cal.get_upcoming_dividends(days)
@app.get("/api/calendar/unlock")
def calendar_unlock(days: int = Query(90, ge=7, le=180)):
"""限售解禁日历"""
return cal.get_unlock_calendar(days)
@app.get("/api/calendar/earnings")
def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)):
"""财报披露日历"""
return cal.get_earnings_calendar(days, holding_only)
@app.post("/api/calendar/check_alerts")
def calendar_check_alerts(current_user = Depends(require_admin)):
"""手动触发日历事件预警推送"""
return cal.check_and_push_calendar_alerts()
# ============ 数据修正与回填增强 ============
# Request body for correcting a stored daily bar. Only the fields that are
# set (not None) are forwarded by update_quote.
class UpdateQuoteRequest(BaseModel):
    open: Optional[float] = None
    high: Optional[float] = None
    low: Optional[float] = None
    close: Optional[float] = None
    volume: Optional[int] = None
    amount: Optional[float] = None
@app.delete("/api/data/quote/{code}/{date}")
def delete_quote(code: str, date: str, current_user = Depends(require_admin)):
"""删除指定股票指定日期的日线"""
return dm.delete_quote(code, date)
@app.put("/api/data/quote/{code}/{date}")
def update_quote(code: str, date: str, req: UpdateQuoteRequest,
current_user = Depends(require_admin)):
"""修正指定日线数据"""
return dm.update_quote(code, date, req.model_dump(exclude_none=True))
# Request body for deleting a date range of daily bars; start and end are
# passed unchanged to dm.delete_quotes_range.
class DeleteRangeRequest(BaseModel):
    start: str
    end: str
@app.delete("/api/data/quotes/{code}/range")
def delete_quotes_range(code: str, req: DeleteRangeRequest,
current_user = Depends(require_admin)):
"""删除指定股票日期范围内的日线数据"""
return dm.delete_quotes_range(code, req.start, req.end)
@app.post("/api/data/refetch/{code}")
def refetch_quote(code: str, days: int = Query(60, ge=5, le=500),
current_user = Depends(require_admin)):
"""重新抓取指定股票日线(覆盖更新)"""
return dm.refetch_quote(code, days)
@app.get("/api/data/integrity")
def check_integrity(days: int = Query(30, ge=7, le=90),
current_user = Depends(require_admin)):
"""数据完整性检查"""
return dm.check_data_integrity(days=days)
@app.post("/api/data/auto_fix")
def auto_fix_missing(limit: int = Query(50, ge=10, le=200),
current_user = Depends(require_admin)):
"""自动补齐缺失数据"""
t = __import__("threading").Thread(
target=dm.auto_fix_missing, kwargs={"limit": limit}, daemon=True
)
t.start()
return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"}
@app.get("/api/data/refill_progress")
def refill_progress(task_id: str = Query("default")):
"""获取回填进度"""
return dm.get_refill_progress(task_id)
@app.post("/api/data/refill_resume")
def refill_resume(days: int = Query(250, ge=30, le=1000),
task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""带断点续传的全市场回填(后台执行)"""
import threading
t = threading.Thread(
target=dm.start_refill_with_resume,
kwargs={"days": days, "task_id": task_id},
daemon=True
)
t.start()
return {"ok": True, "msg": f"已启动断点续传回填,天数={days}任务ID={task_id}"}
@app.delete("/api/data/refill_progress")
def clear_refill_progress(task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""清除回填进度(从头开始)"""
return dm.clear_refill_progress(task_id)
@app.get("/api/data/quality_report")
def data_quality_report(current_user = Depends(require_admin)):
"""数据质量报告"""
return dm.get_data_quality_report()
@app.get("/api/portfolio/attribution")
def portfolio_attribution():
"""持仓归因分析"""
return attrib.analyze_attribution()
# ============ AI 对话式分析 ============
# Request body for POST /api/chat: the conversation id and the user message.
class ChatRequest(BaseModel):
    session_id: str
    message: str
@app.post("/api/chat")
def chat(req: ChatRequest):
"""AI对话"""
return ai_chat.chat(req.session_id, req.message)
@app.delete("/api/chat/{session_id}")
def clear_chat(session_id: str):
"""清空会话"""
ai_chat.clear_session(session_id)
return {"ok": True}
@app.get("/api/chat/{session_id}/history")
def chat_history(session_id: str):
"""获取会话历史"""
return {"ok": True, "messages": ai_chat.get_session_history(session_id)}
# ============ 社区情绪监控 ============
@app.post("/api/sentiment/collect")
def sentiment_collect(limit: int = Query(50, ge=10, le=200)):
"""采集社区帖子"""
return sentiment.collect_posts(limit)
@app.get("/api/sentiment/index")
def sentiment_index(date: Optional[str] = None):
"""获取情绪指数"""
d = dt.date.fromisoformat(date) if date else None
return sentiment.calculate_sentiment_index(d)
@app.get("/api/sentiment/hot_stocks")
def sentiment_hot_stocks(days: int = Query(1, ge=1, le=7), limit: int = Query(20, le=50)):
"""热议股票排行"""
return sentiment.get_hot_stocks(days, limit)
@app.get("/api/sentiment/history")
def sentiment_history(days: int = Query(30, ge=7, le=90)):
"""情绪指数历史"""
return sentiment.get_sentiment_history(days)
@app.get("/api/sentiment/correlation")
def sentiment_correlation(code: str = Query(...), days: int = Query(60, ge=20, le=180)):
"""情绪与股价相关性"""
return sentiment.analyze_sentiment_correlation(code, days)
@app.get("/api/sentiment/wordcloud")
def sentiment_wordcloud(days: int = Query(7, ge=1, le=30), top_n: int = Query(50, le=100)):
"""关键词云"""
return sentiment.get_keyword_cloud(days, top_n)
# ============ 事件驱动策略 ============
@app.post("/api/events/seed")
def events_seed():
"""生成示例事件数据"""
return events.seed_sample_events()
@app.get("/api/events/earnings/pattern")
def earnings_pattern(days_before: int = Query(5, ge=1, le=10), days_after: int = Query(10, ge=5, le=30)):
"""财报发布前后统计规律"""
return events.analyze_earnings_pattern(days_before, days_after)
@app.get("/api/events/insider")
def insider_trading(
    code: Optional[str] = None,
    days: int = Query(180, ge=30, le=365),
):
    """Track executive share increases and reductions, optionally for ``code``."""
    tracked = events.track_insider_trading(code, days)
    return tracked
@app.get("/api/events/unlock")
def unlock_impact(days: int = Query(90, ge=30, le=180)):
    """Analyze the impact of restricted-share unlocks."""
    impact = events.analyze_unlock_impact(days)
    return impact
@app.get("/api/events/policy")
def policy_events(
    sector: Optional[str] = None,
    days: int = Query(180, ge=30, le=365),
):
    """Return industry policy events, optionally filtered by ``sector``."""
    # The ``sector`` parameter shadows the module alias of the same name
    # inside this function; only ``events`` is used here.
    found = events.get_policy_events(sector, days)
    return found
class EventSelectorRequest(BaseModel):
    """Request body for event-driven stock selection."""

    event_types: List[str]
    days: int = 30
@app.post("/api/events/selector")
def event_selector(req: EventSelectorRequest):
    """Run event-driven stock selection for the requested event types."""
    picks = events.event_driven_selector(req.event_types, req.days)
    return picks
# ============ 财报深度解读 ============
@app.post("/api/financial/seed")
def financial_seed():
    """Generate sample financial-report data."""
    seeded = fin.seed_sample_reports()
    return seeded
@app.get("/api/financial/trend")
def financial_trend(
    code: str = Query(...),
    periods: int = Query(8, ge=4, le=16),
):
    """Return the trend of key financial-report metrics for ``code``."""
    trend = fin.get_report_trend(code, periods)
    return trend
@app.get("/api/financial/summary")
def financial_summary(code: str = Query(...)):
    """Return the AI-generated financial-report summary for ``code``."""
    summary = fin.generate_ai_summary(code)
    return summary
@app.get("/api/financial/compare")
def financial_compare(code: str = Query(...), sector: Optional[str] = None):
    """Compare ``code`` with its peers, optionally within ``sector``."""
    comparison = fin.compare_with_peers(code, sector)
    return comparison
@app.get("/api/financial/warnings")
def financial_warnings(code: str = Query(...)):
    """Return financial-report anomaly warnings for ``code``."""
    warnings_found = fin.detect_abnormalities(code)
    return warnings_found
@app.get("/api/financial/calendar")
def financial_calendar(days: int = Query(30, ge=7, le=90)):
    """Return the financial-report release calendar."""
    calendar = fin.get_report_calendar(days)
    return calendar
@app.get("/api/financial/rankings")
def financial_rankings(
    metric: str = Query("roe"),
    limit: int = Query(20, le=50),
):
    """Return the financial-report leaderboard for ``metric``."""
    board = fin.get_top_reports(metric, limit)
    return board
# ============ 涨跌停分析 ============
@app.get("/api/limit/stocks")
def limit_stocks(date: Optional[str] = None, limit_type: str = Query("up")):
    """Return limit-up / limit-down stocks, optionally for an ISO ``date``.

    An omitted or empty ``date`` is forwarded as ``None``.
    """
    if date:
        target_day = dt.date.fromisoformat(date)
    else:
        target_day = None
    return limit_up.get_limit_stocks(target_day, limit_type)
@app.get("/api/limit/consecutive")
def consecutive_limits(days: int = Query(10, ge=5, le=30)):
    """Track stocks with consecutive limit-up days."""
    streaks = limit_up.track_consecutive_limits(days)
    return streaks
@app.get("/api/limit/break_rate")
def limit_break_rate(days: int = Query(60, ge=30, le=180)):
    """Return statistics on the limit-up break rate."""
    stats = limit_up.analyze_limit_break_rate(days)
    return stats
@app.get("/api/limit/squad")
def limit_squad(
    days: int = Query(30, ge=10, le=90),
    min_limits: int = Query(5, ge=3, le=10),
):
    """Return the limit-up "daredevil squad" ranking."""
    ranking = limit_up.get_limit_squad_rankings(days, min_limits)
    return ranking
@app.get("/api/limit/consecutive_calendar")
def consecutive_calendar(days: int = Query(60, ge=20, le=120)):
    """Consecutive-limit calendar: streak history and advance/drop-out patterns."""
    calendar = limit_up.get_consecutive_calendar(days)
    return calendar
@app.get("/api/limit/post_break")
def post_break_performance(days: int = Query(90, ge=30, le=180)):
    """Return 1/3/5-day performance statistics after a broken limit-up."""
    performance = limit_up.analyze_post_break_performance(days)
    return performance
@app.get("/api/limit/reasons")
def limit_reasons(date: Optional[str] = None):
    """Classify limit-up reasons (sentiment / theme / earnings / policy, etc.).

    An omitted or empty ``date`` is forwarded as ``None``.
    """
    if date:
        target_day = dt.date.fromisoformat(date)
    else:
        target_day = None
    return limit_up.classify_limit_reasons(target_day)
# ============ 推送通知 ============
@app.get("/api/notify/status")
def notify_status():
    """Report per-channel status and whether any push channel is enabled."""
    channels = notifier.channels_status()
    enabled = notifier.any_enabled()
    return {"channels": channels, "enabled": enabled}
@app.post("/api/notify/test")
def notify_test():
    """Send a test notification, or report that no channel is configured."""
    if notifier.any_enabled():
        outcome = notifier.notify(
            "【Blackdata】推送测试",
            "这是一条来自Blackdata股票终端的测试通知收到即表示推送通道正常。",
        )
        return {"ok": True, "result": outcome}
    return {"ok": False, "msg": "未配置任何推送渠道,请在 backend/.env 配置后重启"}
# ============ 智能预警 ============
class AlertIn(BaseModel):
    """Request body for creating an alert rule."""

    code: str
    kind: str = "price_above"
    threshold: float
    note: str = ""
@app.get("/api/alerts")
def list_alerts():
    """List all alert rules ordered by descending id."""

    def _as_dict(rule):
        # An untriggered rule reports an empty timestamp string.
        fired_at = rule.triggered_at
        return {
            "id": rule.id,
            "code": rule.code,
            "name": rule.name,
            "kind": rule.kind,
            "threshold": rule.threshold,
            "status": rule.status,
            "note": rule.note,
            "last_value": rule.last_value,
            "triggered_at": fired_at.strftime("%m-%d %H:%M") if fired_at else "",
        }

    with get_session() as s:
        query = select(AlertRule).order_by(AlertRule.id.desc())
        rules = s.execute(query).scalars().all()
        return {"list": [_as_dict(rule) for rule in rules]}
@app.post("/api/alerts")
def add_alert(a: AlertIn):
    """Create an alert rule and return its new id.

    The rule name is the matching ``Security`` name, or the raw code when
    no such security exists.
    """
    with get_session() as s:
        security = s.get(Security, a.code)
        display_name = security.name if security else a.code
        rule = AlertRule(
            code=a.code,
            name=display_name,
            kind=a.kind,
            threshold=a.threshold,
            note=a.note,
        )
        s.add(rule)
        s.commit()
        return {"ok": True, "id": rule.id}
@app.delete("/api/alerts/{aid}")
def del_alert(aid: int):
    """Delete alert rule ``aid`` if it exists; always answers ``ok``."""
    with get_session() as s:
        rule = s.get(AlertRule, aid)
        if rule:
            s.delete(rule)
            s.commit()
    return {"ok": True}
@app.post("/api/alerts/{aid}/reactivate")
def reactivate_alert(aid: int):
    """Reset alert rule ``aid`` to active and clear its trigger time.

    A missing rule is ignored; the response is ``ok`` either way.
    """
    with get_session() as s:
        rule = s.get(AlertRule, aid)
        if rule:
            rule.status = "active"
            rule.triggered_at = None
            s.commit()
    return {"ok": True}
@app.post("/api/alerts/check")
def manual_check():
    """Run the alert check on demand and return its result."""
    outcome = al.check_alerts()
    return outcome
@app.get("/api/alerts/events")
def alert_events(
    unread_only: bool = Query(False),
    limit: int = Query(30, ge=0, le=100),
):
    """Return recent alert events plus the total unread count.

    Args:
        unread_only: When true, list only events whose ``read`` flag is false.
        limit: Maximum number of events to list (0-100). The lower bound
            keeps a negative value from reaching the SQL ``LIMIT`` clause,
            where it would either lift the cap or fail depending on the
            database. ``0`` stays allowed so a client can fetch the unread
            count alone.

    Returns:
        ``{"unread": <int>, "list": [...]}`` with events newest first. The
        unread count covers all unread events regardless of ``unread_only``
        and ``limit``.
    """
    with get_session() as s:
        stmt = select(AlertEvent).order_by(AlertEvent.id.desc())
        if unread_only:
            stmt = stmt.where(AlertEvent.read.is_(False))
        rows = s.execute(stmt.limit(limit)).scalars().all()
        unread = s.execute(
            select(func.count())
            .select_from(AlertEvent)
            .where(AlertEvent.read.is_(False))
        ).scalar() or 0
        return {
            "unread": unread,
            "list": [
                {
                    "id": e.id,
                    "code": e.code,
                    "name": e.name,
                    "message": e.message,
                    "time": e.created_at.strftime("%m-%d %H:%M:%S") if e.created_at else "",
                }
                for e in rows
            ],
        }
@app.post("/api/alerts/events/read")
def mark_events_read():
    """Mark every unread alert event as read."""
    with get_session() as s:
        unread_query = select(AlertEvent).where(AlertEvent.read.is_(False))
        for event in s.execute(unread_query).scalars():
            event.read = True
        s.commit()
    return {"ok": True}
# ============ 资讯中心 ============
@app.get("/api/news")
def news(limit: int = Query(40, le=100)):
    """Return the general news feed from the data service."""
    feed = svc.get_news(limit)
    return feed
@app.get("/api/news/stock")
def news_stock(code: str = Query(...)):
    """Return news for the single stock ``code``."""
    feed = svc.get_stock_news(code)
    return feed
@app.get("/api/news/watch")
def news_watch():
    """Aggregate recent news for the first six watch-list stocks.

    Fetches up to four items per stock, tags each with its stock code,
    sorts by ``time`` descending and returns at most 40 items.

    Returns:
        ``{"list": [...]}`` where every item carries an extra ``"code"`` key.
    """
    codes = load_watch()[:6]
    out = []
    for c in codes:
        r = svc.get_stock_news(c, limit=4)
        # A degraded response may lack "list"; treat it as no news.
        for x in (r.get("list") or []):
            # Copy rather than mutate: the service may hand back cached
            # objects that other requests share.
            out.append({**x, "code": c})
    # Items without a timestamp sort last instead of raising KeyError.
    out.sort(key=lambda x: x.get("time") or "", reverse=True)
    return {"list": out[:40]}
class NewsAI(BaseModel):
    """Request body for AI analysis of one news item."""

    title: str
    content: str = ""
@app.post("/api/news/ai")
def news_ai(n: NewsAI):
    """Analyze a news item with the LLM, falling back to a rule-based verdict.

    The rule-based sentiment is always computed. When the LLM is enabled its
    text is returned with ``source="llm"``; if the LLM is disabled or the
    call fails, a rule-based summary is returned with ``source="rule"``.

    Args:
        n: News title and optional content.

    Returns:
        Dict with ``ok``, ``source``, ``sentiment`` and ``text``.
    """
    import logging

    text_in = (n.title + "" + n.content).strip()
    senti, kw = svc.judge_sentiment(text_in)
    if llm.enabled():
        try:
            # Only the first 1200 characters are sent to the model.
            prompt = ("请分析下面这条财经资讯:\n"
                      "1) 一句话摘要2) 利好/利空/中性判断及理由3) 可能受影响的板块或个股方向。120字内。\n\n"
                      + text_in[:1200])
            text = llm.ask(prompt, temperature=0.3, max_tokens=400)
            return {"ok": True, "source": "llm", "sentiment": senti, "text": text}
        except Exception:
            # Best-effort: keep the rule-based fallback, but record the
            # failure instead of discarding it silently.
            logging.getLogger(__name__).warning(
                "LLM news analysis failed; using rule-based fallback",
                exc_info=True,
            )
    return {"ok": True, "source": "rule", "sentiment": senti,
            "text": f"判断:{senti}(关键词:{''.join(kw) or ''})。摘要:{text_in[:80]}\n(配置大模型后可获得更深入的关联分析)"}
# ============ 盘中实时监控雷达 ============
@app.get("/api/radar/status")
def radar_status():
    """Report radar status: whether it is currently trading time."""
    is_trading = radar._is_trading_time()
    return {"trading_time": is_trading}
@app.post("/api/radar/scan")
def radar_scan():
    """Manually trigger an anomaly scan."""
    scan_result = radar.scan_all()
    return scan_result
@app.get("/api/radar/events")
def radar_events(
    hours: int = Query(2, ge=1, le=24),
    limit: int = Query(50, le=200),
):
    """Return recent anomaly events from the last ``hours`` hours."""
    recent = radar.get_recent_events(hours, limit)
    return {"list": recent}
@app.post("/api/radar/notify")
def radar_notify():
    """Push anomaly events that have not been notified yet."""
    pushed = radar.notify_events()
    return pushed
@app.get("/api/radar/stats")
def radar_stats(date: Optional[str] = Query(None)):
    """Return anomaly statistics, optionally for an ISO-format ``date``.

    Args:
        date: ``YYYY-MM-DD`` string. Annotated ``Optional`` because the
            default is ``None``, matching the other date-filtered endpoints
            in this file. An omitted or empty value is forwarded as ``None``.

    Returns:
        Whatever ``radar.get_statistics`` produces for the resolved date.
    """
    d = dt.date.fromisoformat(date) if date else None
    return radar.get_statistics(d)
# ============ 板块轮动分析 ============
@app.get("/api/sector/trend")
def sector_trend(
    days: int = Query(20, ge=5, le=60),
    top_n: int = Query(15, le=30),
):
    """Return the sector strength trend."""
    trend = sector.get_sector_trend(days, top_n)
    return trend
@app.get("/api/sector/flow")
def sector_flow(days: int = Query(5, ge=1, le=20)):
    """Return the fund-flow analysis."""
    flow = sector.analyze_fund_flow(days)
    return flow
@app.get("/api/sector/lifecycle")
def sector_lifecycle(
    name: str = Query(...),
    days: int = Query(60, ge=20, le=120),
):
    """Return the lifecycle analysis for the sector ``name``."""
    lifecycle = sector.analyze_lifecycle(name, days)
    return lifecycle
@app.get("/api/sector/leaders")
def sector_leaders(
    name: str = Query(...),
    days: int = Query(20, ge=5, le=60),
    limit: int = Query(10, le=30),
):
    """Identify the leading stocks of the sector ``name``."""
    leaders = sector.identify_leaders(name, days, limit)
    return leaders
@app.get("/api/sector/correlation")
def sector_correlation(
    days: int = Query(60, ge=20, le=120),
    top_n: int = Query(20, le=30),
):
    """Return the sector co-movement (correlation) analysis."""
    correlation = sector.analyze_correlation(days, top_n)
    return correlation
@app.get("/api/sector/summary")
def sector_summary():
    """Return the sector-rotation summary."""
    summary = sector.get_rotation_summary()
    return summary
# ============ 智能选股增强 ============
@app.get("/api/selector/fields")
def selector_fields():
    """Return the fields available to the stock selector."""
    fields = selector.get_available_fields()
    return {"ok": True, "fields": fields}
@app.get("/api/selector/presets")
def selector_presets():
    """Return the preset selector strategies."""
    presets = selector.get_preset_strategies()
    return {"ok": True, "presets": presets}
class SelectorRequest(BaseModel):
    """Request body for running or backtesting a selector strategy."""

    strategy: Dict[str, Any]
    date: Optional[str] = None
@app.post("/api/selector/run")
def selector_run(req: SelectorRequest):
    """Run stock selection with the submitted strategy.

    Any exception is reported as ``{"ok": False, "msg": <error text>}``.
    """
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        if req.date:
            run_day = dt.date.fromisoformat(req.date)
        else:
            run_day = None
        return selector.run_selector(parsed, run_day)
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.post("/api/selector/backtest")
def selector_backtest(
    req: SelectorRequest,
    days: int = Query(60, ge=20, le=250),
):
    """Backtest the submitted selector strategy over ``days`` days.

    Any exception is reported as ``{"ok": False, "msg": <error text>}``.
    """
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        outcome = selector.backtest_selector(parsed, days)
        return outcome
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
class CompareRequest(BaseModel):
    """Request body for comparing selector results on two dates."""

    strategy: Dict[str, Any]
    date1: str
    date2: str
@app.post("/api/selector/compare")
def selector_compare(req: CompareRequest):
    """Compare selector results between two ISO-format dates.

    Any exception is reported as ``{"ok": False, "msg": <error text>}``.
    """
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        first_day = dt.date.fromisoformat(req.date1)
        second_day = dt.date.fromisoformat(req.date2)
        comparison = selector.compare_results(first_day, second_day, parsed)
        return comparison
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.get("/api/selector/strategies")
def list_strategies():
    """List saved selector strategies, most recently updated first."""
    stamp = "%Y-%m-%d %H:%M:%S"

    def _as_dict(record):
        return {
            "id": record.id,
            "name": record.name,
            "description": record.description,
            "is_preset": record.is_preset,
            "created_at": record.created_at.strftime(stamp),
            "updated_at": record.updated_at.strftime(stamp),
        }

    with get_session() as s:
        query = select(SelectorStrategy).order_by(SelectorStrategy.updated_at.desc())
        records = s.execute(query).scalars().all()
        return {"ok": True, "strategies": [_as_dict(rec) for rec in records]}
class SaveStrategyRequest(BaseModel):
    """Request body for saving a selector strategy."""

    name: str
    description: str = ""
    strategy: Dict[str, Any]
@app.post("/api/selector/strategies")
def save_strategy(req: SaveStrategyRequest):
    """Save a selector strategy and return the id of the stored row.

    The strategy dict is parsed first and stored as JSON. Any exception is
    reported as ``{"ok": False, "msg": <error text>}``.
    """
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        with get_session() as s:
            row = SelectorStrategy(
                name=req.name,
                description=req.description,
                strategy_json=parsed.to_json(),
            )
            s.add(row)
            s.commit()
            return {"ok": True, "id": row.id}
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.get("/api/selector/strategies/{sid}")
def get_strategy(sid: int):
    """Return the details of saved strategy ``sid``, or a not-found message."""
    with get_session() as s:
        row = s.get(SelectorStrategy, sid)
        if row:
            return {
                "ok": True,
                "id": row.id,
                "name": row.name,
                "description": row.description,
                # Stored as a JSON string; decoded for the response.
                "strategy": json.loads(row.strategy_json),
            }
        return {"ok": False, "msg": "策略不存在"}
@app.delete("/api/selector/strategies/{sid}")
def delete_strategy(sid: int):
    """Delete saved strategy ``sid`` if it exists; always answers ``ok``."""
    with get_session() as s:
        row = s.get(SelectorStrategy, sid)
        if row:
            s.delete(row)
            s.commit()
    return {"ok": True}
@app.get("/api/selector/alerts")
def list_selector_alerts():
    """List selector alerts ordered by descending id."""

    def _as_dict(alert):
        # A never-checked alert reports an empty timestamp string.
        checked = alert.last_checked
        return {
            "id": alert.id,
            "strategy_id": alert.strategy_id,
            "strategy_name": alert.strategy_name,
            "status": alert.status,
            "last_checked": checked.strftime("%m-%d %H:%M") if checked else "",
            "last_count": alert.last_count,
        }

    with get_session() as s:
        query = select(SelectorAlert).order_by(SelectorAlert.id.desc())
        found = s.execute(query).scalars().all()
        return {"ok": True, "alerts": [_as_dict(alert) for alert in found]}
class CreateAlertRequest(BaseModel):
    """Request body for creating a selector alert."""

    strategy_id: int
    strategy_name: str
@app.post("/api/selector/alerts")
def create_selector_alert(req: CreateAlertRequest):
    """Create a selector alert and return the id of the stored row."""
    with get_session() as s:
        row = SelectorAlert(
            strategy_id=req.strategy_id,
            strategy_name=req.strategy_name,
        )
        s.add(row)
        s.commit()
        return {"ok": True, "id": row.id}
@app.delete("/api/selector/alerts/{aid}")
def delete_selector_alert(aid: int):
    """Delete selector alert ``aid`` if it exists; always answers ``ok``."""
    with get_session() as s:
        row = s.get(SelectorAlert, aid)
        if row:
            s.delete(row)
            s.commit()
    return {"ok": True}
# ============ 模拟盘 ============
class PaperAccountIn(BaseModel):
    """Request body for creating a paper-trading account."""

    name: str
    initial_cash: float = 1_000_000.0
@app.get("/api/paper/accounts")
def paper_list_accounts():
    """List the paper-trading accounts."""
    accounts = paper.list_accounts()
    return {"ok": True, "accounts": accounts}
@app.post("/api/paper/accounts")
def paper_create_account(req: PaperAccountIn):
    """Create a paper-trading account from the request's name and cash."""
    created = paper.create_account(req.name, req.initial_cash)
    return created
@app.post("/api/paper/accounts/{account_id}/reset")
def paper_reset_account(account_id: int, initial_cash: Optional[float] = None):
    """Reset paper account ``account_id``, forwarding ``initial_cash`` as given."""
    outcome = paper.reset_account(account_id, initial_cash)
    return outcome
class PaperOrderIn(BaseModel):
    """Request body for placing a paper-trading order."""

    code: str
    side: str  # "buy" or "sell"
    qty: int
    price: Optional[float] = None
    reason: str = ""
@app.post("/api/paper/accounts/{account_id}/order")
def paper_place_order(account_id: int, req: PaperOrderIn):
    """Place a paper-trading order on account ``account_id``."""
    outcome = paper.place_order(
        account_id,
        req.code,
        req.side,
        req.qty,
        req.price,
        req.reason,
    )
    return outcome
@app.get("/api/paper/accounts/{account_id}/portfolio")
def paper_get_portfolio(account_id: int):
    """Return the portfolio of paper account ``account_id``."""
    portfolio = paper.get_portfolio(account_id)
    return portfolio
@app.get("/api/paper/accounts/{account_id}/trades")
def paper_get_trades(account_id: int, limit: int = Query(100, le=500)):
    """Return the trades of paper account ``account_id``."""
    trades = paper.get_trades(account_id, limit)
    return {"ok": True, "trades": trades}
# ============ 静态前端 ============
# Static front end: the "prototype" directory next to BASE_DIR's parent is
# served at "/" when it exists. NOTE(review): the mount sits after the API
# route definitions, presumably so those routes match first — confirm.
_project_root = os.path.dirname(BASE_DIR)
FRONTEND_DIR = os.path.join(_project_root, "prototype")
if os.path.isdir(FRONTEND_DIR):
    _frontend_app = StaticFiles(directory=FRONTEND_DIR, html=True)
    app.mount("/", _frontend_app, name="frontend")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000, with auto-reload disabled.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
    )