1658 lines
54 KiB
Python
1658 lines
54 KiB
Python
"""Blackdata股票终端 — FastAPI 后端入口。

- /api/* : 数据接口(基于 AkShare,带缓存与降级)
- / : 托管前端原型(prototype 目录)
"""
|
||
import os
|
||
import json
|
||
import datetime as dt
|
||
from contextlib import asynccontextmanager
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
from fastapi import FastAPI, Query, Depends
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.exceptions import RequestValidationError
|
||
from sqlalchemy.exc import SQLAlchemyError
|
||
from fastapi.staticfiles import StaticFiles
|
||
from sqlalchemy import select, func, desc
|
||
|
||
from pydantic import BaseModel
|
||
|
||
import akshare_service as svc
|
||
import redis_cache
|
||
from redis_cache import cache
|
||
import auth
|
||
from auth import get_current_user, require_auth, require_admin
|
||
import init_auth
|
||
import exceptions
|
||
from exceptions import (
|
||
BusinessException,
|
||
DataSourceException,
|
||
business_exception_handler,
|
||
validation_exception_handler,
|
||
sqlalchemy_exception_handler,
|
||
general_exception_handler
|
||
)
|
||
import config
|
||
import scheduler
|
||
import backtest as bt
|
||
import backtest_advanced as bta
|
||
import ai
|
||
import signals as sig
|
||
import report as rpt
|
||
import portfolio as pf
|
||
import llm
|
||
import alerts as al
|
||
import notifier
|
||
import intraday_radar as radar
|
||
import sector_rotation as sector
|
||
import smart_selector as selector
|
||
import attribution_analysis as attrib
|
||
import ai_chat
|
||
import sentiment_monitor as sentiment
|
||
import event_driven as events
|
||
import financial_analysis as fin
|
||
import limit_analysis as limit_up
|
||
import watchlist_manager as wl
|
||
import position_cost as pc
|
||
import trade_calendar as cal
|
||
import data_manager as dm
|
||
import paper_trading as paper
|
||
from db import init_db, get_session
|
||
from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
|
||
SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade,
|
||
AlertRule, AlertEvent, SelectorStrategy, SelectorAlert)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run the one-time init calls, then serve.

    All init calls share a single ``try``: the first exception aborts the
    remaining calls and is reported as a ``[startup] WARN`` line (repr
    truncated to 160 characters), after which the application still starts.
    Nothing is done on shutdown.
    """
    try:
        init_db()
        init_auth.init_default_admin()
        wl.init_default_groups()
        paper.ensure_default_account()
        scheduler.start_scheduler()
        print("[startup] db + scheduler + auth ready")
    except Exception as startup_error:
        # Deliberately best-effort: log and keep booting.
        print("[startup] WARN:", repr(startup_error)[:160])
    yield
|
||
|
||
|
||
app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan)

# Exception handlers, registered in a fixed order: domain errors, request
# validation, database errors, then the catch-all.
for _exc_type, _handler in (
    (BusinessException, business_exception_handler),
    (RequestValidationError, validation_exception_handler),
    (SQLAlchemyError, sqlalchemy_exception_handler),
    (Exception, general_exception_handler),
):
    app.add_exception_handler(_exc_type, _handler)
del _exc_type, _handler

# CORS is wide open: every origin, method and header is accepted
# (allow_credentials is not set here).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# 自选股本地存储
|
||
# Directory that contains this module; the flat watchlist file lives beside it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Flat watchlist used by /api/watchlist: a JSON array of stock codes.
WATCH_FILE = os.path.join(BASE_DIR, "watchlist.json")

# Codes served when WATCH_FILE is missing or cannot be read.
DEFAULT_WATCH = [
    "600519",
    "300750",
    "002594",
    "688981",
    "300059",
    "601012",
]
|
||
|
||
|
||
def load_watch(path=None):
    """Load the flat watchlist from disk.

    Args:
        path: JSON file to read; defaults to ``WATCH_FILE``.

    Returns:
        The stored list of stock codes.  When the file is missing, unreadable,
        not valid JSON, or does not hold a JSON array, a *fresh copy* of
        ``DEFAULT_WATCH`` is returned.  The copy matters: callers such as
        ``watch_add`` append to the result, which previously mutated the
        module-level default in place.
    """
    path = WATCH_FILE if path is None else path
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt file (JSONDecodeError/UnicodeDecodeError
            # are ValueError subclasses): fall back to the defaults.
            data = None
        if isinstance(data, list):
            return data
    return list(DEFAULT_WATCH)
|
||
|
||
|
||
def save_watch(symbols, path=None):
    """Persist the flat watchlist as a JSON array (UTF-8, non-ASCII kept as-is).

    Args:
        symbols: list of stock codes to store.
        path: destination file; defaults to ``WATCH_FILE``.

    The data is first written to ``<path>.tmp`` and then swapped in with
    ``os.replace``.  An interrupted write can therefore no longer leave a
    truncated file behind, which ``load_watch`` would otherwise discard in
    favour of the defaults, losing the user's list.

    Raises:
        Whatever the underlying file operations raise; the temp file is
        removed first.
    """
    path = WATCH_FILE if path is None else path
    tmp_path = os.fspath(path) + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(symbols, f, ensure_ascii=False)
        os.replace(tmp_path, path)
    except Exception:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
||
|
||
|
||
# ============ 认证相关 API ============
|
||
class LoginRequest(BaseModel):
    """Body of ``POST /api/auth/login``."""

    username: str
    password: str


@app.post("/api/auth/login")
def login(req: LoginRequest, db = Depends(get_session)):
    """Check credentials and issue a bearer access token.

    Raises:
        exceptions.AuthException: when ``auth.authenticate_user`` returns a
            falsy value.
    """
    # NOTE(review): elsewhere get_session() is used as a context manager;
    # confirm it also works (and is closed) when injected via Depends.
    user = auth.authenticate_user(db, req.username, req.password)
    if not user:
        raise exceptions.AuthException("用户名或密码错误")

    token = auth.create_access_token(data={"sub": user.username})
    response = {"ok": True, "access_token": token, "token_type": "bearer"}
    response["username"] = user.username
    response["is_admin"] = user.is_admin
    return response
|
||
|
||
@app.get("/api/auth/me")
def get_me(current_user = Depends(require_auth)):
    """Return the authenticated user's name and admin flag."""
    profile = {"ok": True}
    profile["username"] = current_user.username
    profile["is_admin"] = current_user.is_admin
    return profile
|
||
|
||
class ChangePasswordRequest(BaseModel):
    """Password-change body.

    Also reused by ``reset_password``, which reads only ``new_password``.
    """

    old_password: str
    new_password: str


@app.post("/api/auth/change-password")
def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)):
    """Change the logged-in user's password.

    Raises:
        exceptions.AuthException: when ``old_password`` does not verify.

    Returns:
        ``{"ok": False, "msg": ...}`` for an empty new password or a user that
        no longer exists; otherwise a success message.
    """
    if not auth.verify_password(req.old_password, current_user.hashed_password):
        raise exceptions.AuthException("原密码错误")
    if not req.new_password:
        return {"ok": False, "msg": "新密码不能为空"}

    # current_user is produced by the require_auth dependency and may belong
    # to a different (or already closed) session than ``db``.  If so,
    # mutating it and committing ``db`` persists nothing while still
    # reporting success.  Re-load the row through ``db``; this is an
    # identity-map hit when it is already attached.
    user = db.get(type(current_user), current_user.id)
    if user is None:
        return {"ok": False, "msg": "用户不存在"}
    user.hashed_password = auth.get_password_hash(req.new_password)
    db.commit()
    return {"ok": True, "msg": "密码修改成功"}
|
||
|
||
# ============ 用户管理 ============
|
||
class CreateUserRequest(BaseModel):
    """Body of ``POST /api/users`` (admin-only user creation)."""

    username: str
    password: str  # plain text; hashed by the endpoint before storage
    is_admin: bool = False
|
||
|
||
from models import User as UserModel
|
||
|
||
@app.get("/api/users")
def list_users(current_user = Depends(require_admin)):
    """List every user ordered by id (admin only); ``created_at`` as YYYY-MM-DD."""
    with get_session() as s:
        users = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all()
        listing = [
            {
                "id": u.id,
                "username": u.username,
                "is_admin": u.is_admin,
                "is_active": u.is_active,
                "created_at": u.created_at.strftime("%Y-%m-%d"),
            }
            for u in users
        ]
    return {"ok": True, "users": listing}
|
||
|
||
@app.post("/api/users")
def create_user(req: CreateUserRequest, current_user = Depends(require_admin)):
    """Create a user (admin only).

    Returns:
        ``{"ok": False, "msg": ...}`` when the username is blank, the
        password is empty, or the username is taken; otherwise
        ``{"ok": True, "id": <new user id>}``.
    """
    # Reject blank credentials: an empty password would be hashed and stored,
    # making the account loginable with an empty string.
    if not req.username.strip() or not req.password:
        return {"ok": False, "msg": "用户名和密码不能为空"}
    with get_session() as s:
        taken = s.execute(
            select(UserModel).where(UserModel.username == req.username)
        ).scalar_one_or_none()
        if taken:
            return {"ok": False, "msg": "用户名已存在"}
        user = UserModel(
            username=req.username,
            hashed_password=auth.get_password_hash(req.password),
            is_admin=req.is_admin,
        )
        s.add(user)
        s.commit()
        return {"ok": True, "id": user.id}
|
||
|
||
@app.delete("/api/users/{uid}")
def delete_user(uid: int, current_user = Depends(require_admin)):
    """Delete user ``uid`` (admin only).

    Self-deletion is refused; a missing user still yields ``{"ok": True}``.
    """
    if uid == current_user.id:
        return {"ok": False, "msg": "不能删除自己"}
    with get_session() as s:
        target = s.get(UserModel, uid)
        if target:
            s.delete(target)
            s.commit()
    return {"ok": True}
|
||
|
||
@app.put("/api/users/{uid}/toggle_admin")
def toggle_admin(uid: int, current_user = Depends(require_admin)):
    """Flip the admin flag of user ``uid`` (admin only, never on yourself)."""
    if uid == current_user.id:
        return {"ok": False, "msg": "不能修改自己的权限"}
    with get_session() as s:
        target = s.get(UserModel, uid)
        if not target:
            return {"ok": False, "msg": "用户不存在"}
        target.is_admin = not target.is_admin
        s.commit()
        return {"ok": True, "is_admin": target.is_admin}
|
||
|
||
@app.put("/api/users/{uid}/reset_password")
def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)):
    """Admin reset of another user's password.

    Only ``req.new_password`` is used. ``old_password`` is required by the
    shared ``ChangePasswordRequest`` body but ignored here.

    Returns:
        ``{"ok": False, "msg": ...}`` for an empty new password or an unknown
        user; otherwise ``{"ok": True}``.
    """
    if not req.new_password:
        return {"ok": False, "msg": "新密码不能为空"}
    with get_session() as s:
        u = s.get(UserModel, uid)
        if not u:
            return {"ok": False, "msg": "用户不存在"}
        u.hashed_password = auth.get_password_hash(req.new_password)
        s.commit()
    return {"ok": True}
|
||
|
||
|
||
# ============ API ============
|
||
@app.get("/api/health")
def health():
    """Liveness probe that also reports the AkShare and Redis availability flags."""
    status = {"ok": True}
    status["akshare"] = svc.AK_OK
    status["redis"] = cache.enabled
    status["auth"] = True
    return status
|
||
|
||
|
||
@app.get("/api/indices")
def indices():
    """Index overview, delegated to ``svc.get_indices``."""
    overview = svc.get_indices()
    return overview


@app.get("/api/kline")
def kline(symbol: str = Query("600519"), days: int = Query(120, ge=20, le=500)):
    """Live K-line data for ``symbol``; ``days`` (20..500) is forwarded to ``svc.get_kline``."""
    bars = svc.get_kline(symbol, days)
    return bars
|
||
|
||
|
||
@app.get("/api/sentiment")
def sentiment_snapshot():
    """Market sentiment snapshot from ``svc.get_sentiment``.

    This handler used to be called ``sentiment``.  Defining it under that name
    rebound the module-level ``sentiment`` and shadowed
    ``import sentiment_monitor as sentiment`` for all later code in this
    module.  The HTTP route is unchanged.
    """
    return svc.get_sentiment()
|
||
|
||
|
||
@app.get("/api/treemap")
def treemap(mode: str = Query("sector"), date: str = Query(None)):
    """Treemap data, from the DB for historical sector views, else live.

    With ``mode == "sector"`` and a parsable ISO ``date``:
    - ``SectorDaily`` rows for that day are read, sorted by ``pct`` descending.
    - If that day has no rows, the latest stored day is used instead.
    - Each item is ``{"name", "value": amount or 1, "pct": rounded to 2}``.

    Every other case returns ``svc.get_treemap(mode)``: other modes, no date,
    an unparsable date, or no stored rows at all.
    """
    if mode != "sector" or not date:
        return svc.get_treemap(mode)
    try:
        day = dt.date.fromisoformat(date)
    except Exception:
        return svc.get_treemap(mode)

    def _sector_rows(session, when):
        stmt = (
            select(SectorDaily)
            .where(SectorDaily.date == when)
            .order_by(SectorDaily.pct.desc())
        )
        return session.execute(stmt).scalars().all()

    with get_session() as s:
        rows = _sector_rows(s, day)
        if not rows:
            # Requested day not ingested: fall back to the most recent one.
            latest = s.execute(select(func.max(SectorDaily.date))).scalar()
            if latest:
                rows = _sector_rows(s, latest)
                day = latest
        if rows:
            items = [
                {"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)}
                for r in rows
            ]
            return {"source": "db", "mode": "sector", "date": day.isoformat(), "items": items}
    return svc.get_treemap(mode)
|
||
|
||
|
||
@app.get("/api/treemap/us")
def treemap_us():
    """US-market treemap, delegated to ``svc.get_us_treemap``."""
    data = svc.get_us_treemap()
    return data


@app.get("/api/treemap/hk")
def treemap_hk():
    """Hong Kong-market treemap, delegated to ``svc.get_hk_treemap``."""
    data = svc.get_hk_treemap()
    return data


@app.get("/api/treemap/sector_stocks")
def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)):
    """Constituents of sector ``name`` (up to ``limit``, 5..100) via ``svc.get_sector_stocks``."""
    data = svc.get_sector_stocks(name, limit)
    return data
|
||
|
||
|
||
@app.get("/api/treemap/all_leaders")
def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)):
    """Leading stocks of every sector.

    Reads ``SectorLeader`` rows for ``date`` (ISO ``YYYY-MM-DD``).  When the
    date is absent or unparsable, the latest stored day is used.  Rows are
    grouped by sector, ordered by ``rank``, and each sector list is truncated
    to ``top_n`` entries; previously the DB path ignored ``top_n`` entirely.
    Falls back to the live ``svc.get_all_sector_leaders(top_n)`` when no rows
    exist for the chosen day.
    """
    from models import SectorLeader

    target = None
    if date:
        try:
            target = dt.date.fromisoformat(date)
        except ValueError:
            target = None

    with get_session() as s:
        if not target:
            target = s.execute(select(func.max(SectorLeader.date))).scalar()
        if target:
            rows = s.execute(
                select(SectorLeader)
                .where(SectorLeader.date == target)
                .order_by(SectorLeader.sector, SectorLeader.rank)
            ).scalars().all()
            if rows:
                sectors = {}
                for r in rows:
                    bucket = sectors.setdefault(r.sector, [])
                    # Rows arrive rank-ordered within a sector, so the first
                    # top_n seen are the top_n leaders.
                    if len(bucket) < top_n:
                        bucket.append({
                            "code": r.code, "name": r.name, "pct": r.pct,
                            "price": r.price, "amount": r.amount,
                        })
                return {"source": "db", "date": target.isoformat(), "sectors": sectors}

    # Nothing stored for that day: fall back to live data.
    return svc.get_all_sector_leaders(top_n)
|
||
|
||
|
||
@app.get("/api/fundflow")
def fundflow():
    """Fund-flow data, delegated to ``svc.get_fund_flow``."""
    payload = svc.get_fund_flow()
    return payload


@app.get("/api/hot/stocks")
def hot_stocks():
    """Hot-stock list, delegated to ``svc.get_hot_stocks``."""
    payload = svc.get_hot_stocks()
    return payload


@app.get("/api/hot/sectors")
def hot_sectors():
    """Hot sectors, served from ``svc.get_industry_boards``."""
    payload = svc.get_industry_boards()
    return payload


@app.get("/api/dragon")
def dragon():
    """Dragon-tiger list, delegated to ``svc.get_dragon_tiger``."""
    payload = svc.get_dragon_tiger()
    return payload


@app.get("/api/watchlist")
def watchlist():
    """Quotes for the flat watchlist (codes from ``load_watch``)."""
    codes = load_watch()
    return svc.get_watchlist(codes)
|
||
|
||
|
||
# ============ Watchlist groups ============
# These routes MUST be registered before the flat-watchlist routes below.
# Starlette dispatches to the first route whose path *and* method match.  With
# the old order, ``POST /api/watchlist/groups`` was captured by
# ``POST /api/watchlist/{code}`` (code="groups"): a stock named "groups" was
# added to watchlist.json and create_group was unreachable.

class CreateGroupRequest(BaseModel):
    """Body of ``POST /api/watchlist/groups``."""

    name: str
    description: str = ""
    color: str = "blue"


@app.get("/api/watchlist/groups")
def list_groups():
    """Return every watchlist group from ``wl.get_all_groups``."""
    return {"ok": True, "groups": wl.get_all_groups()}


@app.post("/api/watchlist/groups")
def create_group(req: CreateGroupRequest):
    """Create a new watchlist group via ``wl.create_group``."""
    return wl.create_group(req.name, req.description, req.color)


# ============ Flat watchlist (watchlist.json) ============

@app.post("/api/watchlist/{code}")
def watch_add(code: str):
    """Add ``code`` to the flat watchlist if it is not already there.

    Works on a copy of ``load_watch()``'s result.  When the file does not exist
    yet, ``load_watch`` may hand back the module-level default list, which must
    not be mutated.

    Returns:
        ``{"ok": True, "list": <current list>}``.
    """
    codes = list(load_watch())
    if code not in codes:
        codes.append(code)
        save_watch(codes)
    return {"ok": True, "list": codes}


@app.delete("/api/watchlist/{code}")
def watch_del(code: str):
    """Remove every occurrence of ``code`` from the flat watchlist and persist it."""
    remaining = [c for c in load_watch() if c != code]
    save_watch(remaining)
    return {"ok": True, "list": remaining}
|
||
|
||
class UpdateGroupRequest(BaseModel):
    """Partial group update.

    Every field is optional.  Values, including ``None``, are forwarded
    as-is to ``wl.update_group``.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    color: Optional[str] = None


@app.put("/api/watchlist/groups/{group_id}")
def update_group(group_id: int, req: UpdateGroupRequest):
    """Update name/description/color of group ``group_id``."""
    result = wl.update_group(group_id, req.name, req.description, req.color)
    return result


@app.delete("/api/watchlist/groups/{group_id}")
def delete_group(group_id: int):
    """Delete group ``group_id`` via ``wl.delete_group``."""
    result = wl.delete_group(group_id)
    return result
|
||
|
||
class ReorderGroupsRequest(BaseModel):
    """Body for group reordering: group ids in their new display order."""

    group_ids: List[int]


@app.post("/api/watchlist/groups/reorder")
def reorder_groups(req: ReorderGroupsRequest):
    """Apply the group order given by ``req.group_ids``."""
    result = wl.reorder_groups(req.group_ids)
    return result


@app.get("/api/watchlist/groups/{group_id}/stocks")
def get_group_stocks(group_id: int, with_quotes: bool = Query(True)):
    """Stocks in group ``group_id``; ``with_quotes`` is forwarded to ``wl.get_group_stocks``."""
    result = wl.get_group_stocks(group_id, with_quotes)
    return result
|
||
|
||
class AddStockRequest(BaseModel):
    """Body for adding one stock (with an optional note) to a group."""

    code: str
    note: str = ""


@app.post("/api/watchlist/groups/{group_id}/stocks")
def add_stock_to_group(group_id: int, req: AddStockRequest):
    """Add ``req.code`` to group ``group_id``."""
    result = wl.add_stock_to_group(group_id, req.code, req.note)
    return result


@app.delete("/api/watchlist/stocks/{item_id}")
def remove_stock(item_id: int):
    """Remove the group entry ``item_id``."""
    result = wl.remove_stock_from_group(item_id)
    return result


class MoveStockRequest(BaseModel):
    """Body naming the destination group of a move."""

    target_group_id: int


@app.post("/api/watchlist/stocks/{item_id}/move")
def move_stock(item_id: int, req: MoveStockRequest):
    """Move entry ``item_id`` into ``req.target_group_id``."""
    result = wl.move_stock_to_group(item_id, req.target_group_id)
    return result
|
||
|
||
class BatchAddRequest(BaseModel):
    """Body for adding several stock codes at once."""

    codes: List[str]


@app.post("/api/watchlist/groups/{group_id}/stocks/batch")
def batch_add_stocks(group_id: int, req: BatchAddRequest):
    """Add every code in ``req.codes`` to group ``group_id``."""
    result = wl.batch_add_stocks(group_id, req.codes)
    return result


class UpdateNoteRequest(BaseModel):
    """Body carrying the new note text for a group entry."""

    note: str


@app.put("/api/watchlist/stocks/{item_id}/note")
def update_stock_note(item_id: int, req: UpdateNoteRequest):
    """Replace the note on entry ``item_id``."""
    result = wl.update_stock_note(item_id, req.note)
    return result


class ReorderStocksRequest(BaseModel):
    """Body for entry reordering: entry ids in their new order."""

    item_ids: List[int]


@app.post("/api/watchlist/stocks/reorder")
def reorder_stocks(req: ReorderStocksRequest):
    """Apply the entry order given by ``req.item_ids``."""
    result = wl.reorder_stocks(req.item_ids)
    return result


@app.get("/api/watchlist/search")
def search_stocks(keyword: str = Query(..., min_length=1)):
    """Search stocks across all groups by ``keyword``."""
    hits = wl.search_stocks_across_groups(keyword)
    return {"ok": True, "results": hits}
|
||
|
||
|
||
# ============ 数据中台 ============
|
||
@app.get("/api/admin/status")
def admin_status(current_user = Depends(require_admin)):
    """Data-hub dashboard (admin only).

    Returns:
        - ``counts``: row count per table.
        - ``last_dates``: latest ``date`` per table, when the model has one.
        - ``jobs``: the 8 most recent ``JobRun`` rows.
        - ``running``, ``universe``, ``schedule``: scheduler state and config.

    A job whose ``message`` is NULL no longer crashes the endpoint; it is
    reported as ``""``.
    """
    tables = [
        ("securities", Security), ("quotes_daily", DailyQuote),
        ("index_daily", IndexDaily), ("sector_daily", SectorDaily),
        ("fund_flow_daily", FundFlowDaily), ("sentiment_daily", SentimentDaily),
        ("dragon_tiger", DragonTiger),
    ]
    counts, last_dates = {}, {}
    with get_session() as s:
        for label, model in tables:
            counts[label] = s.execute(select(func.count()).select_from(model)).scalar() or 0
            if hasattr(model, "date"):
                d = s.execute(select(func.max(model.date))).scalar()
                last_dates[label] = d.isoformat() if d else None
        jobs = s.execute(select(JobRun).order_by(desc(JobRun.id)).limit(8)).scalars().all()
        job_list = [{
            "id": j.id, "job": j.job, "status": j.status,
            "started": j.started_at.strftime("%m-%d %H:%M:%S") if j.started_at else "",
            "finished": j.finished_at.strftime("%H:%M:%S") if j.finished_at else "",
            "message": (j.message or "")[:200],
        } for j in jobs]
    return {"counts": counts, "last_dates": last_dates, "jobs": job_list,
            "running": scheduler.is_running(), "universe": config.DEFAULT_UNIVERSE,
            "schedule": f"周一至周五 {config.INGEST_HOUR:02d}:{config.INGEST_MINUTE:02d}"}
|
||
|
||
|
||
@app.post("/api/admin/ingest")
def admin_ingest(current_user = Depends(require_admin)):
    """Start ``scheduler.trigger_async()`` unless a job is already running (admin only)."""
    if scheduler.is_running():
        return {"started": False, "msg": "已有入库任务在执行"}
    return scheduler.trigger_async()


@app.post("/api/admin/ingest_all")
def admin_ingest_all(current_user = Depends(require_admin)):
    """Start ``scheduler.trigger_all_async()`` (admin only).

    Applies the same ``scheduler.is_running()`` guard as
    ``/api/admin/ingest``.  Previously this endpoint could start a second
    ingest while one was already writing.
    """
    if scheduler.is_running():
        return {"started": False, "msg": "已有入库任务在执行"}
    return scheduler.trigger_all_async()
|
||
|
||
|
||
@app.get("/api/db/kline")
def db_kline(symbol: str = Query("600519"), days: int = Query(250, ge=20, le=1000)):
    """K-line for ``symbol`` from the local ``DailyQuote`` table.

    Takes the newest ``days`` rows and returns them oldest-first.  ``ohlc``
    entries are ``[open, close, low, high]``.  ``empty: True`` is set when
    nothing is stored.
    """
    newest_first_stmt = (
        select(DailyQuote)
        .where(DailyQuote.code == symbol)
        .order_by(DailyQuote.date.desc())
        .limit(days)
    )
    with get_session() as s:
        newest_first = s.execute(newest_first_stmt).scalars().all()
    rows = newest_first[::-1]
    if not rows:
        return {"source": "db", "empty": True, "symbol": symbol, "dates": [], "ohlc": [], "vols": []}
    return {
        "source": "db",
        "symbol": symbol,
        "dates": [r.date.strftime("%m/%d") for r in rows],
        "ohlc": [[r.open, r.close, r.low, r.high] for r in rows],
        "vols": [r.volume for r in rows],
    }
|
||
|
||
|
||
@app.get("/api/db/sentiment_history")
def db_sentiment_history(days: int = Query(60, ge=5, le=365)):
    """Up/down/limit-up counts for the newest ``days`` stored days, oldest-first."""
    stmt = select(SentimentDaily).order_by(SentimentDaily.date.desc()).limit(days)
    with get_session() as s:
        newest_first = s.execute(stmt).scalars().all()
    rows = newest_first[::-1]
    return {
        "dates": [r.date.isoformat() for r in rows],
        "up": [r.up for r in rows],
        "down": [r.down for r in rows],
        "limit_up": [r.limit_up for r in rows],
    }
|
||
|
||
|
||
@app.get("/api/review/daily")
def review_daily(date: str = Query(None)):
    """Rule-based daily market review built from the stored tables.

    Uses ``date`` (ISO ``YYYY-MM-DD``) or, when omitted, the latest
    ``SectorDaily`` date.

    Returns:
        ``{"ok": False, "msg": ...}`` when the date cannot be parsed
        (previously an unhandled ``ValueError``, i.e. HTTP 500) or when
        nothing has been ingested.  Otherwise:
        - top 8 and bottom 5 sectors;
        - top 8 inflows and bottom 5 outflows (worst first);
        - the sentiment row;
        - up to 10 dragon-tiger entries;
        - a generated summary.
    """
    d = None
    if date:
        try:
            d = dt.date.fromisoformat(date)
        except ValueError:
            return {"ok": False, "msg": "日期格式错误,应为 YYYY-MM-DD"}

    with get_session() as s:
        if d is None:
            d = s.execute(select(func.max(SectorDaily.date))).scalar()
        if not d:
            return {"ok": False, "msg": "暂无入库数据,请先在数据中台执行入库"}
        sectors = s.execute(
            select(SectorDaily).where(SectorDaily.date == d).order_by(SectorDaily.pct.desc())
        ).scalars().all()
        flows = s.execute(
            select(FundFlowDaily).where(FundFlowDaily.date == d).order_by(FundFlowDaily.net.desc())
        ).scalars().all()
        senti = s.execute(
            select(SentimentDaily).where(SentimentDaily.date == d)
        ).scalar_one_or_none()
        lhb = s.execute(
            select(DragonTiger).where(DragonTiger.date == d).order_by(DragonTiger.net.desc()).limit(10)
        ).scalars().all()

        top_sec = [{"name": x.name, "pct": x.pct} for x in sectors[:8]]
        bot_sec = [{"name": x.name, "pct": x.pct} for x in sectors[-5:]]
        inflow = [{"name": x.name, "net": x.net} for x in flows[:8]]
        outflow = [{"name": x.name, "net": x.net} for x in flows[-5:][::-1]]
        senti_d = ({"up": senti.up, "down": senti.down, "limit_up": senti.limit_up,
                    "limit_down": senti.limit_down} if senti else None)
        dragon_list = [{"name": x.name, "code": x.code, "net": x.net, "pct": x.pct} for x in lhb]

    summary = _gen_review_text(d, senti_d, top_sec, inflow)
    return {"ok": True, "date": d.isoformat(), "sentiment": senti_d,
            "top_sectors": top_sec, "weak_sectors": bot_sec,
            "inflow": inflow, "outflow": outflow,
            "dragon": dragon_list,
            "summary": summary}
|
||
|
||
|
||
def _gen_review_text(d, senti, top_sec, inflow):
|
||
parts = [f"【{d.isoformat()} 复盘】"]
|
||
if senti:
|
||
tone = "情绪偏暖" if senti["up"] > senti["down"] else "情绪偏弱"
|
||
parts.append(f"全市场上涨 {senti['up']} 家、下跌 {senti['down']} 家,涨停 {senti['limit_up']} 家、跌停 {senti['limit_down']} 家,{tone}。")
|
||
if top_sec:
|
||
names = "、".join(x["name"] for x in top_sec[:3])
|
||
parts.append(f"领涨板块:{names}。")
|
||
if inflow:
|
||
names = "、".join(x["name"] for x in inflow[:3])
|
||
parts.append(f"主力净流入居前:{names}。")
|
||
parts.append("注:以上为基于入库数据的自动统计,AI 智能点评将在 AI 分析模块接入大模型后生成。")
|
||
return " ".join(parts)
|
||
|
||
|
||
@app.get("/api/backtest")
def backtest_api(symbol: str = Query("600519"), fast: int = Query(5, ge=2, le=60),
                 slow: int = Query(20, ge=5, le=250)):
    """Simple MA-crossover backtest via ``bt.run_backtest``; requires ``fast < slow``."""
    periods_valid = fast < slow
    if not periods_valid:
        return {"ok": False, "msg": "快线周期需小于慢线周期"}
    return bt.run_backtest(symbol, fast, slow)
|
||
|
||
|
||
# ============ 增强回测 ============
|
||
class BacktestParams(BaseModel):
    """Body of ``POST /api/backtest/advanced``."""

    symbol: str
    strategy: str = "ma"  # "ma" or "multi_factor"
    fast: int = 5  # used by the "ma" strategy only
    slow: int = 20  # used by the "ma" strategy only
    position_size: float = 1.0
    stop_loss: float = 0.0  # used by the "ma" strategy only
    take_profit: float = 0.0  # used by the "ma" strategy only
    initial_capital: float = 100000.0
    commission: float = 0.0005


@app.post("/api/backtest/advanced")
def backtest_advanced(params: BacktestParams):
    """Run ``bta.run_advanced_backtest`` with the requested strategy.

    Returns:
        ``{"ok": False, "msg": ...}`` for an unknown strategy or, for the MA
        strategy, when ``fast >= slow``.  That is the same rule
        ``/api/backtest`` enforces; it was previously missing here.
    """
    if params.strategy == "ma":
        if params.fast >= params.slow:
            return {"ok": False, "msg": "快线周期需小于慢线周期"}
        strategy = bta.MAStrategy(
            fast=params.fast,
            slow=params.slow,
            position_size=params.position_size,
            stop_loss=params.stop_loss,
            take_profit=params.take_profit,
        )
    elif params.strategy == "multi_factor":
        strategy = bta.MultiFactorStrategy(position_size=params.position_size)
    else:
        return {"ok": False, "msg": "不支持的策略类型"}

    return bta.run_advanced_backtest(
        symbol=params.symbol,
        strategy=strategy,
        initial_capital=params.initial_capital,
        commission=params.commission,
    )
|
||
|
||
|
||
class OptimizeParams(BaseModel):
    """Body of ``POST /api/backtest/optimize`` (grid search over MA periods)."""

    symbol: str
    strategy: str = "ma"  # only "ma" is supported
    fast_range: List[int] = [3, 5, 10, 15]
    slow_range: List[int] = [20, 30, 60]
    metric: str = "sharpe_ratio"


@app.post("/api/backtest/optimize")
def backtest_optimize(params: OptimizeParams):
    """Grid-search MA fast/slow periods with ``bta.optimize_parameters``.

    Only ``strategy == "ma"`` is accepted: the grid is fast/slow periods and
    ``bta.MAStrategy`` is always used.  Other values used to be silently
    ignored.  Empty ranges, or ranges without a single ``fast < slow`` pair,
    are rejected up front.

    Returns:
        At most the first 20 entries of the optimiser's result list.
    """
    if params.strategy != "ma":
        return {"ok": False, "msg": "参数优化仅支持均线策略"}
    if not params.fast_range or not params.slow_range:
        return {"ok": False, "msg": "参数范围不能为空"}
    if not any(f < s for f in params.fast_range for s in params.slow_range):
        return {"ok": False, "msg": "快线周期需小于慢线周期"}

    param_grid = {"fast": params.fast_range, "slow": params.slow_range}
    results = bta.optimize_parameters(
        symbol=params.symbol,
        param_grid=param_grid,
        strategy_class=bta.MAStrategy,
        metric=params.metric,
    )
    return {
        "ok": True,
        "symbol": params.symbol,
        "metric": params.metric,
        "results": results[:20],
    }
|
||
|
||
|
||
class CompareParams(BaseModel):
    """Body of ``POST /api/backtest/compare``."""

    symbol: str
    # Each item: {"type": "ma", "fast"?, "slow"?, "stop_loss"?, "take_profit"?}
    # or {"type": "multi_factor"}.
    strategies: List[Dict[str, Any]]


@app.post("/api/backtest/compare")
def backtest_compare(params: CompareParams):
    """Compare several strategies on one symbol via ``bta.compare_strategies``.

    Returns ``{"ok": False, "msg": ...}`` for the following.  A missing
    ``type`` used to raise ``KeyError`` (HTTP 500); an unknown type was
    silently dropped from the comparison.
    - an empty strategy list;
    - a missing or unknown ``type``;
    - an MA spec with ``fast >= slow``.
    """
    if not params.strategies:
        return {"ok": False, "msg": "请至少选择一个策略"}

    strategies = []
    for spec in params.strategies:
        kind = spec.get("type")
        if kind == "ma":
            fast, slow = spec.get("fast", 5), spec.get("slow", 20)
            if fast >= slow:
                return {"ok": False, "msg": "快线周期需小于慢线周期"}
            strategies.append(bta.MAStrategy(
                fast=fast,
                slow=slow,
                stop_loss=spec.get("stop_loss", 0),
                take_profit=spec.get("take_profit", 0),
            ))
        elif kind == "multi_factor":
            strategies.append(bta.MultiFactorStrategy())
        else:
            return {"ok": False, "msg": f"不支持的策略类型: {kind}"}

    return bta.compare_strategies(params.symbol, strategies)
|
||
|
||
|
||
# ============ 全市场选股 ============
|
||
# Screener strategy id -> display name.  /api/screen implements the filters;
# insertion order is the order shown by /api/screen/strategies.
STRATEGIES = {
    "surge": "最近暴涨(5日涨幅≥20%)",
    "plunge": "最近暴跌(5日跌幅≥15%)",
    "dip": "超跌抄底(60日分位≤20%且当日企稳)",
    "breakout": "突破走强(逼近60日新高)",
    "ma_bull": "均线多头(MA5>10>20)",
    "volume": "放量上攻(量比≥2且上涨)",
    "macd_gold": "MACD金叉",
    "strong": "强势连涨(≥3日连阳)",
}
|
||
|
||
|
||
@app.get("/api/screen/strategies")
def screen_strategies():
    """List the screener strategies as ``[{"id", "name"}]`` in ``STRATEGIES`` order."""
    entries = [{"id": sid, "name": label} for sid, label in STRATEGIES.items()]
    return {"list": entries}
|
||
|
||
|
||
@app.get("/api/screen")
def screen(strategy: str = Query("surge"), limit: int = Query(60, ge=10, le=300),
           min_amount: float = Query(0.0)):
    """Whole-market screener over the precomputed ``StockMetric`` table.

    ``strategy`` must be one of the ids in ``STRATEGIES``.  Unknown ids used
    to fall through with no filter at all, returning the pool sorted by
    ``ret5`` under the raw id as its name.  They now return an empty list with
    ``ok: False``.  ``min_amount > 0`` adds an ``amount >= min_amount`` filter.
    ``pos60`` is reported as a percentage with one decimal, or ``None`` when
    not stored.
    """
    M = StockMetric
    # strategy id -> (WHERE clauses, ORDER BY)
    rules = {
        "surge": ((M.ret5 >= 20,), M.ret5.desc()),
        "plunge": ((M.ret5 <= -15,), M.ret5.asc()),
        "dip": ((M.pos60 <= 0.2, M.pct > 0), M.pos60.asc()),
        "breakout": ((M.pos60 >= 0.95, M.pct > 0), M.ret20.desc()),
        "ma_bull": ((M.ma_bull.is_(True),), M.ret20.desc()),
        "volume": ((M.vol_ratio >= 2, M.pct > 0), M.vol_ratio.desc()),
        "macd_gold": ((M.macd_gold.is_(True),), M.ret5.desc()),
        "strong": ((M.up_streak >= 3,), M.up_streak.desc()),
    }
    if strategy not in rules:
        return {"ok": False, "msg": f"未知的选股策略: {strategy}", "strategy": strategy,
                "name": strategy, "pool_size": 0, "count": 0, "list": []}

    conditions, order = rules[strategy]
    q = select(M).where(*conditions)
    if min_amount > 0:
        q = q.where(M.amount >= min_amount)
    q = q.order_by(order).limit(limit)

    with get_session() as s:
        rows = s.execute(q).scalars().all()
        total = s.execute(select(func.count()).select_from(M)).scalar() or 0
    return {"strategy": strategy, "name": STRATEGIES.get(strategy, strategy), "pool_size": total,
            "count": len(rows), "list": [{
                "code": r.code, "name": r.name, "close": r.close, "pct": r.pct,
                "ret5": r.ret5, "ret20": r.ret20, "vol_ratio": r.vol_ratio,
                "rsi14": r.rsi14,
                "pos60": round(r.pos60 * 100, 1) if r.pos60 is not None else None,
                "amount": r.amount,
                "up_streak": r.up_streak} for r in rows]}
|
||
|
||
|
||
@app.get("/api/securities/search")
def securities_search(q: str = Query("", min_length=0), limit: int = Query(15, ge=1, le=50)):
    """Search securities by code prefix or name substring (empty ``q`` lists any).

    ``limit`` is now bounded below by 1.  Previously a negative value was
    accepted, and on SQLite a negative LIMIT means "no upper bound", so the
    whole table was returned.  ``%``, ``_`` and ``\\`` in ``q`` are escaped so
    they match literally instead of acting as LIKE wildcards.
    """
    with get_session() as s:
        stmt = select(Security)
        if q:
            escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            stmt = stmt.where(
                Security.code.like(f"{escaped}%", escape="\\")
                | Security.name.like(f"%{escaped}%", escape="\\")
            )
        rows = s.execute(stmt.limit(limit)).scalars().all()
        return {"list": [{"code": r.code, "name": r.name} for r in rows]}
|
||
|
||
|
||
# ============ 个股复盘(K线 + 买卖点 + 回放) ============
|
||
def _ma_list(close, n):
|
||
out = [None] * len(close)
|
||
for i in range(len(close)):
|
||
if i >= n - 1:
|
||
out[i] = round(sum(close[i - n + 1:i + 1]) / n, 3)
|
||
return out
|
||
|
||
|
||
@app.get("/api/review/stock")
def review_stock(symbol: str = Query("600519"), days: int = Query(250, ge=40, le=1000),
                 fast: int = Query(5, ge=1, le=250), slow: int = Query(20, ge=1, le=250)):
    """Single-stock review: stored K-line, two MAs and their crossover signals.

    Uses the newest ``days`` ``DailyQuote`` rows, oldest-first.  A ``buy``
    signal marks the fast MA crossing above the slow MA; ``sell`` marks the
    reverse.  ``stats`` covers the shown window.

    ``fast``/``slow`` are now bounded to 1..250; ``fast=0`` used to raise
    ZeroDivisionError in ``_ma_list``.  ``period_return`` is ``None`` instead
    of crashing when the first close is 0.
    """
    with get_session() as s:
        rows = s.execute(
            select(DailyQuote).where(DailyQuote.code == symbol)
            .order_by(DailyQuote.date.desc()).limit(days)
        ).scalars().all()
        sec = s.get(Security, symbol)
        name = sec.name if sec else symbol
    rows = list(reversed(rows))
    if not rows:
        return {"ok": False, "msg": "该股票库内无日线,请先在数据中台入库该股或执行全市场回填", "symbol": symbol}

    dates = [r.date.strftime("%y/%m/%d") for r in rows]
    ohlc = [[r.open, r.close, r.low, r.high] for r in rows]
    vols = [r.volume for r in rows]
    close = [r.close for r in rows]
    maf, mas = _ma_list(close, fast), _ma_list(close, slow)

    signals = []
    for i in range(1, len(close)):
        if maf[i] is None or mas[i] is None or maf[i - 1] is None or mas[i - 1] is None:
            continue
        if maf[i - 1] <= mas[i - 1] and maf[i] > mas[i]:
            signals.append({"idx": i, "date": dates[i], "price": close[i], "type": "buy"})
        elif maf[i - 1] >= mas[i - 1] and maf[i] < mas[i]:
            signals.append({"idx": i, "date": dates[i], "price": close[i], "type": "sell"})

    hi = max(r.high for r in rows)
    lo = min(r.low for r in rows)
    period_ret = round((close[-1] / close[0] - 1) * 100, 2) if close[0] else None
    return {"ok": True, "symbol": symbol, "name": name,
            "dates": dates, "ohlc": ohlc, "vols": vols,
            "ma_fast": maf, "ma_slow": mas, "fast": fast, "slow": slow,
            "signals": signals,
            "stats": {"period_return": period_ret, "high": hi, "low": lo,
                      "start": dates[0], "end": dates[-1], "bars": len(rows)}}
|
||
|
||
|
||
# ============ AI 分析 ============
|
||
@app.get("/api/ai/status")
def ai_status():
    """Whether the LLM is enabled, plus ``config.LLM_MODEL`` when it is."""
    payload = {"enabled": llm.enabled()}
    payload["model"] = config.LLM_MODEL if llm.enabled() else None
    return payload


@app.get("/api/ai/review_daily")
def ai_review_daily(date: str = Query(None)):
    """AI commentary on the daily review for ``date`` (or the module's default)."""
    comment = ai.review_daily_comment(date)
    return comment


@app.get("/api/ai/diagnose")
def ai_diagnose(symbol: str = Query("600519")):
    """AI diagnosis of ``symbol``."""
    report = ai.diagnose(symbol)
    return report


@app.get("/api/ai/today")
def ai_today():
    """Today's AI strategy note."""
    note = ai.today_strategy()
    return note


@app.get("/api/ai/trend_analysis")
def ai_trend_analysis(
    symbol: str = Query(...),
    date: str = Query(""),
    period: str = Query("daily"),
):
    """Trend analysis, called when a K-line bar is right-clicked.

    Explains the reason behind a sharp rise or fall.
    """
    analysis = ai.trend_analysis(symbol, date, period)
    return analysis
|
||
|
||
|
||
# ============ 可回溯:信号历史胜率 + 实测准确率 ============
|
||
@app.get("/api/ai/signal_stats")
def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)):
    """Historical signal win-rate statistics for a ``horizon`` of 1..20."""
    stats = sig.get_stats(horizon)
    return {"horizon": horizon, "stats": stats}


@app.post("/api/ai/signal_stats/compute")
def ai_signal_stats_compute(sample: int = Query(500, ge=50, le=4000), horizon: int = Query(5, ge=1, le=20)):
    """Schedule an asynchronous recomputation of the signal statistics.

    NOTE(review): unlike the /api/admin/* triggers this POST has no auth
    dependency; confirm that anonymous callers may start it.
    """
    job = scheduler.trigger_signal_stats_async(sample, horizon)
    return job


@app.get("/api/ai/accuracy")
def ai_accuracy():
    """Measured prediction accuracy from ``sig.accuracy``."""
    result = sig.accuracy()
    return result


@app.post("/api/ai/accuracy/verify")
def ai_accuracy_verify():
    """Verify stored predictions via ``sig.verify_predictions``.

    NOTE(review): no auth dependency; confirm this is intended.
    """
    result = sig.verify_predictions()
    return result
|
||
|
||
|
||
# ============ AI 自动复盘日报 ============
|
||
@app.get("/api/report/daily")
def report_daily(date: str = Query(None)):
    """Stored daily report for ``date``; the latest one when ``date`` is empty."""
    if not date:
        return rpt.latest()
    return rpt.get_by_date(date)


@app.get("/api/report/history")
def report_history(limit: int = Query(30, ge=1, le=120)):
    """Up to ``limit`` (1..120) past reports."""
    reports = rpt.history(limit)
    return reports


@app.post("/api/report/generate")
def report_generate(date: str = Query(None), push: bool = Query(False)):
    """Generate the report for ``date``; ``push`` is forwarded to ``rpt.generate``.

    NOTE(review): this POST has no auth dependency, unlike the /api/admin/*
    triggers; confirm anonymous generation/push is intended.
    """
    generated = rpt.generate(date, push=push)
    return generated
|
||
|
||
|
||
# ============ 交易日志 & 组合 ============
|
||
class TradeIn(BaseModel):
    """Body of ``POST /api/trades``.

    ``date`` is ISO ``YYYY-MM-DD`` or empty for today.  An empty ``name`` is
    looked up from ``Security``.
    """

    code: str
    name: str = ""
    side: str = "buy"
    price: float
    qty: int
    fee: float = 0.0
    date: str = ""
    reason: str = ""
    emotion: str = ""


@app.get("/api/trades")
def list_trades():
    """All journal trades, newest date first (ties broken by newest id)."""
    with get_session() as s:
        rows = s.execute(select(Trade).order_by(Trade.date.desc(), Trade.id.desc())).scalars().all()
        return {"list": [{"id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name,
                          "side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee,
                          "reason": t.reason, "emotion": t.emotion} for t in rows]}


@app.post("/api/trades")
def add_trade(t: TradeIn):
    """Record a trade in the journal.

    Returns:
        ``{"ok": False, "msg": ...}`` for an unparsable ``date`` (previously
        HTTP 500) or a non-positive ``qty``; otherwise
        ``{"ok": True, "id": <new id>}``.

    The name lookup and the insert now share one session.
    """
    if t.date:
        try:
            d = dt.date.fromisoformat(t.date)
        except ValueError:
            return {"ok": False, "msg": "日期格式错误,应为 YYYY-MM-DD"}
    else:
        d = dt.date.today()
    if t.qty <= 0:
        return {"ok": False, "msg": "数量必须大于0"}

    with get_session() as s:
        name = t.name
        if not name:
            sec = s.get(Security, t.code)
            name = sec.name if sec else t.code
        row = Trade(date=d, code=t.code, name=name, side=t.side, price=t.price,
                    qty=t.qty, fee=t.fee, reason=t.reason, emotion=t.emotion)
        s.add(row)
        s.commit()
        return {"ok": True, "id": row.id}


@app.delete("/api/trades/{tid}")
def del_trade(tid: int):
    """Delete trade ``tid``; a missing id still returns ``{"ok": True}``."""
    with get_session() as s:
        row = s.get(Trade, tid)
        if row:
            s.delete(row)
            s.commit()
    return {"ok": True}
|
||
|
||
|
||
@app.get("/api/portfolio")
def get_portfolio():
    """Current portfolio as computed by ``pf.compute``."""
    snapshot = pf.compute()
    return snapshot


@app.get("/api/portfolio/equity")
def portfolio_equity():
    """Portfolio equity curve from ``pf.equity_curve``."""
    curve = pf.equity_curve()
    return curve
|
||
|
||
|
||
# ============ 持仓成本可视化增强 ============
|
||
@app.get("/api/portfolio/cost_line/{code}")
def get_cost_line(code: str):
    """Position cost line(s) for ``code``, used to annotate the K-line chart."""
    lines = pc.get_position_cost_lines(code)
    return lines


@app.get("/api/portfolio/cost_distribution")
def get_cost_distribution():
    """Cost distribution of current holdings (profit/loss band chart)."""
    distribution = pc.get_position_cost_distribution()
    return distribution


class EstimateCostRequest(BaseModel):
    """Body for a pre-order trading-cost estimate."""

    code: str
    price: float
    qty: int
    side: str = "buy"


@app.post("/api/portfolio/estimate_cost")
def estimate_cost(req: EstimateCostRequest):
    """Estimate trading costs before placing an order."""
    estimate = pc.estimate_trade_cost(req.code, req.price, req.qty, req.side)
    return estimate


@app.get("/api/portfolio/cost_breakdown/{code}")
def get_cost_breakdown(code: str):
    """Detailed cost breakdown of the position in ``code``."""
    breakdown = pc.get_cost_breakdown_for_position(code)
    return breakdown
|
||
|
||
|
||
# ============ 交易日历与关键事件 ============
|
||
@app.get("/api/calendar/events")
def calendar_events(days: int = Query(30, ge=7, le=90)):
    """All upcoming key events in the next ``days`` days (combined view)."""
    events_view = cal.get_all_upcoming_events(days)
    return events_view


@app.get("/api/calendar/dividends")
def calendar_dividends(days: int = Query(30, ge=7, le=90)):
    """Ex-dividend / ex-rights calendar, with held stocks prioritised."""
    dividends = cal.get_upcoming_dividends(days)
    return dividends


@app.get("/api/calendar/unlock")
def calendar_unlock(days: int = Query(90, ge=7, le=180)):
    """Restricted-share unlock calendar."""
    unlocks = cal.get_unlock_calendar(days)
    return unlocks


@app.get("/api/calendar/earnings")
def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)):
    """Earnings-disclosure calendar; ``holding_only`` is forwarded to ``cal``."""
    earnings = cal.get_earnings_calendar(days, holding_only)
    return earnings


@app.post("/api/calendar/check_alerts")
def calendar_check_alerts(current_user = Depends(require_admin)):
    """Manually trigger the calendar-event alert push (admin only)."""
    outcome = cal.check_and_push_calendar_alerts()
    return outcome
|
||
|
||
|
||
# ============ Data correction & backfill ============
class UpdateQuoteRequest(BaseModel):
    """Partial correction of one daily bar.

    Every field is optional; a field left as None is meant to stay unchanged.
    """

    open: Optional[float] = None
    high: Optional[float] = None
    low: Optional[float] = None
    close: Optional[float] = None
    volume: Optional[int] = None
    amount: Optional[float] = None
@app.delete("/api/data/quote/{code}/{date}")
def delete_quote(code: str, date: str, current_user=Depends(require_admin)):
    """Delete the daily bar of ``code`` on ``date`` (admin only)."""
    result = dm.delete_quote(code, date)
    return result
@app.put("/api/data/quote/{code}/{date}")
def update_quote(code: str, date: str, req: UpdateQuoteRequest,
                 current_user=Depends(require_admin)):
    """Correct a daily bar (admin only); only fields that were supplied are sent on."""
    changes = req.model_dump(exclude_none=True)
    return dm.update_quote(code, date, changes)
class DeleteRangeRequest(BaseModel):
    """Date range (start/end strings) of daily bars to delete."""

    start: str  # range start, passed through to data_manager as-is
    end: str  # range end, passed through to data_manager as-is
@app.delete("/api/data/quotes/{code}/range")
def delete_quotes_range(code: str, req: DeleteRangeRequest,
                        current_user=Depends(require_admin)):
    """Delete the daily bars of ``code`` between ``req.start`` and ``req.end`` (admin only)."""
    # NOTE(review): this DELETE carries a JSON body; some HTTP clients/proxies
    # drop bodies on DELETE requests — confirm the frontend sends it reliably.
    start, end = req.start, req.end
    return dm.delete_quotes_range(code, start, end)
@app.post("/api/data/refetch/{code}")
def refetch_quote(code: str, days: int = Query(60, ge=5, le=500),
                  current_user=Depends(require_admin)):
    """Re-download the last ``days`` daily bars of ``code``, overwriting stored rows (admin only)."""
    outcome = dm.refetch_quote(code, days)
    return outcome
@app.get("/api/data/integrity")
def check_integrity(days: int = Query(30, ge=7, le=90),
                    current_user=Depends(require_admin)):
    """Run the data-integrity check over the last ``days`` days (admin only)."""
    report = dm.check_data_integrity(days=days)
    return report
@app.post("/api/data/auto_fix")
def auto_fix_missing(limit: int = Query(50, ge=10, le=200),
                     current_user=Depends(require_admin)):
    """Start a background job that fills in missing data (admin only).

    ``dm.auto_fix_missing(limit=limit)`` runs in a daemon thread, so the
    request returns immediately with an acknowledgement instead of waiting
    for the repair to finish.
    """
    # A plain (function-scope) import replaces the original
    # ``__import__("threading")`` call, which was harder to read and lint.
    import threading

    worker = threading.Thread(
        target=dm.auto_fix_missing,
        kwargs={"limit": limit},
        name="data-auto-fix",  # named so the job is identifiable in thread dumps
        daemon=True,
    )
    worker.start()
    return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"}
@app.get("/api/data/refill_progress")
def refill_progress(task_id: str = Query("default")):
    """Return the progress of the backfill task identified by ``task_id``."""
    progress = dm.get_refill_progress(task_id)
    return progress
@app.post("/api/data/refill_resume")
def refill_resume(days: int = Query(250, ge=30, le=1000),
                  task_id: str = Query("default"),
                  current_user=Depends(require_admin)):
    """Start a resumable full-market backfill in a background daemon thread (admin only)."""
    import threading

    job_kwargs = {"days": days, "task_id": task_id}
    threading.Thread(
        target=dm.start_refill_with_resume,
        kwargs=job_kwargs,
        daemon=True,
    ).start()
    return {"ok": True, "msg": f"已启动断点续传回填,天数={days},任务ID={task_id}"}
@app.delete("/api/data/refill_progress")
def clear_refill_progress(task_id: str = Query("default"),
                          current_user=Depends(require_admin)):
    """Clear saved backfill progress so the next run starts from scratch (admin only)."""
    outcome = dm.clear_refill_progress(task_id)
    return outcome
@app.get("/api/data/quality_report")
def data_quality_report(current_user=Depends(require_admin)):
    """Return the data-quality report (admin only)."""
    report = dm.get_data_quality_report()
    return report
@app.get("/api/portfolio/attribution")
def portfolio_attribution():
    """Return the portfolio attribution analysis."""
    analysis = attrib.analyze_attribution()
    return analysis
# ============ AI conversational analysis ============
class ChatRequest(BaseModel):
    """One user turn in an AI chat session."""

    session_id: str  # identifies the conversation
    message: str  # the user's message text
@app.post("/api/chat")
def chat(req: ChatRequest):
    """Send one message to the AI chat session and return its reply."""
    session_id, message = req.session_id, req.message
    return ai_chat.chat(session_id, message)
@app.delete("/api/chat/{session_id}")
def clear_chat(session_id: str):
    """Clear the history of a chat session."""
    ai_chat.clear_session(session_id)
    response = {"ok": True}
    return response
@app.get("/api/chat/{session_id}/history")
def chat_history(session_id: str):
    """Return the message history of a chat session."""
    history = ai_chat.get_session_history(session_id)
    return {"ok": True, "messages": history}
# ============ Community sentiment monitoring ============
@app.post("/api/sentiment/collect")
def sentiment_collect(limit: int = Query(50, ge=10, le=200)):
    """Collect up to ``limit`` community posts."""
    collected = sentiment.collect_posts(limit)
    return collected
@app.get("/api/sentiment/index")
def sentiment_index(date: Optional[dt.date] = Query(None)):
    """Return the sentiment index for ``date`` (ISO ``YYYY-MM-DD``).

    When ``date`` is omitted, ``None`` is passed through and the sentiment
    module chooses the day. The parameter is typed as a date so FastAPI
    validates it: a malformed value now yields a request-validation error
    (422) instead of an unhandled ``ValueError`` from ``fromisoformat`` (500).
    """
    return sentiment.calculate_sentiment_index(date)
@app.get("/api/sentiment/hot_stocks")
def sentiment_hot_stocks(days: int = Query(1, ge=1, le=7),
                         limit: int = Query(20, ge=1, le=50)):
    """Rank the most-discussed stocks over the last ``days`` days.

    ``limit`` is bounded on both sides. Without ``ge=1``, zero or negative
    values would slip past the ``le`` cap (e.g. a negative slice or SQL LIMIT
    returns nearly everything).
    """
    return sentiment.get_hot_stocks(days, limit)
@app.get("/api/sentiment/history")
def sentiment_history(days: int = Query(30, ge=7, le=90)):
    """Return the sentiment-index history for the last ``days`` days."""
    series = sentiment.get_sentiment_history(days)
    return series
@app.get("/api/sentiment/correlation")
def sentiment_correlation(code: str = Query(...),
                          days: int = Query(60, ge=20, le=180)):
    """Return the correlation between community sentiment and the price of ``code``."""
    analysis = sentiment.analyze_sentiment_correlation(code, days)
    return analysis
@app.get("/api/sentiment/wordcloud")
def sentiment_wordcloud(days: int = Query(7, ge=1, le=30),
                        top_n: int = Query(50, ge=1, le=100)):
    """Return the ``top_n`` keywords of the last ``days`` days (word cloud).

    ``top_n`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return sentiment.get_keyword_cloud(days, top_n)
# ============ Event-driven strategies ============
@app.post("/api/events/seed")
def events_seed():
    """Generate sample event data."""
    # NOTE(review): this endpoint writes data yet, unlike the /api/data/*
    # mutators, has no require_admin dependency — confirm that is intended.
    seeded = events.seed_sample_events()
    return seeded
@app.get("/api/events/earnings/pattern")
def earnings_pattern(days_before: int = Query(5, ge=1, le=10),
                     days_after: int = Query(10, ge=5, le=30)):
    """Return statistics of price behaviour around earnings releases."""
    pattern = events.analyze_earnings_pattern(days_before, days_after)
    return pattern
@app.get("/api/events/insider")
def insider_trading(code: Optional[str] = None,
                    days: int = Query(180, ge=30, le=365)):
    """Track executive share increases/decreases, optionally for one ``code``."""
    records = events.track_insider_trading(code, days)
    return records
@app.get("/api/events/unlock")
def unlock_impact(days: int = Query(90, ge=30, le=180)):
    """Return the impact analysis of upcoming lock-up expiries."""
    impact = events.analyze_unlock_impact(days)
    return impact
@app.get("/api/events/policy")
def policy_events(sector: Optional[str] = None,
                  days: int = Query(180, ge=30, le=365)):
    """Return industry policy events, optionally filtered by ``sector``."""
    # The ``sector`` parameter shadows the sector_rotation module alias only
    # inside this function; the module is not used here.
    found = events.get_policy_events(sector, days)
    return found
class EventSelectorRequest(BaseModel):
    """Parameters for event-driven stock selection."""

    event_types: List[str]  # event categories to screen on
    days: int = 30  # look-back window in days
@app.post("/api/events/selector")
def event_selector(req: EventSelectorRequest):
    """Run event-driven stock selection."""
    event_types, window = req.event_types, req.days
    return events.event_driven_selector(event_types, window)
# ============ Financial report deep-dive ============
@app.post("/api/financial/seed")
def financial_seed():
    """Generate sample financial-report data."""
    # NOTE(review): writes data without a require_admin dependency — confirm intended.
    seeded = fin.seed_sample_reports()
    return seeded
@app.get("/api/financial/trend")
def financial_trend(code: str = Query(...),
                    periods: int = Query(8, ge=4, le=16)):
    """Return the trend of key report metrics over the last ``periods`` reports."""
    trend = fin.get_report_trend(code, periods)
    return trend
@app.get("/api/financial/summary")
def financial_summary(code: str = Query(...)):
    """Return an AI-generated summary of the financial reports of ``code``."""
    summary = fin.generate_ai_summary(code)
    return summary
@app.get("/api/financial/compare")
def financial_compare(code: str = Query(...), sector: Optional[str] = None):
    """Compare ``code`` with its peers, optionally within an explicit ``sector``."""
    comparison = fin.compare_with_peers(code, sector)
    return comparison
@app.get("/api/financial/warnings")
def financial_warnings(code: str = Query(...)):
    """Return anomaly warnings detected in the financial reports of ``code``."""
    anomalies = fin.detect_abnormalities(code)
    return anomalies
@app.get("/api/financial/calendar")
def financial_calendar(days: int = Query(30, ge=7, le=90)):
    """Return the financial-report release calendar for the next ``days`` days."""
    schedule = fin.get_report_calendar(days)
    return schedule
@app.get("/api/financial/rankings")
def financial_rankings(metric: str = Query("roe"),
                       limit: int = Query(20, ge=1, le=50)):
    """Return the top ``limit`` reports ranked by ``metric``.

    ``limit`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return fin.get_top_reports(metric, limit)
# ============ Limit-up / limit-down analysis ============
@app.get("/api/limit/stocks")
def limit_stocks(date: Optional[dt.date] = Query(None),
                 limit_type: str = Query("up")):
    """Return the limit-up / limit-down stocks of ``date`` (ISO ``YYYY-MM-DD``).

    ``limit_type`` is passed through unchanged (default ``"up"``). ``date`` is
    typed so FastAPI rejects malformed values with a validation error (422)
    instead of an unhandled ``ValueError`` (500); when omitted, ``None`` is
    passed through.
    """
    return limit_up.get_limit_stocks(date, limit_type)
@app.get("/api/limit/consecutive")
def consecutive_limits(days: int = Query(10, ge=5, le=30)):
    """Track stocks with consecutive limit-up days over the last ``days`` days."""
    tracked = limit_up.track_consecutive_limits(days)
    return tracked
@app.get("/api/limit/break_rate")
def limit_break_rate(days: int = Query(60, ge=30, le=180)):
    """Return statistics on how often limit-ups are broken (炸板率)."""
    stats = limit_up.analyze_limit_break_rate(days)
    return stats
@app.get("/api/limit/squad")
def limit_squad(days: int = Query(30, ge=10, le=90),
                min_limits: int = Query(5, ge=3, le=10)):
    """Return the limit-up "squad" rankings (涨停敢死队) over ``days`` days."""
    rankings = limit_up.get_limit_squad_rankings(days, min_limits)
    return rankings
@app.get("/api/limit/consecutive_calendar")
def consecutive_calendar(days: int = Query(60, ge=20, le=120)):
    """Return the consecutive-limit-up calendar used to study promotion patterns."""
    history = limit_up.get_consecutive_calendar(days)
    return history
@app.get("/api/limit/post_break")
def post_break_performance(days: int = Query(90, ge=30, le=180)):
    """Return 1/3/5-day performance statistics after a broken limit-up."""
    performance = limit_up.analyze_post_break_performance(days)
    return performance
@app.get("/api/limit/reasons")
def limit_reasons(date: Optional[dt.date] = Query(None)):
    """Classify limit-up reasons (sentiment / theme / earnings / policy, ...).

    ``date`` is an ISO ``YYYY-MM-DD`` date; typing it lets FastAPI reject
    malformed input with a validation error (422) instead of an unhandled
    ``ValueError`` (500). When omitted, ``None`` is passed through.
    """
    return limit_up.classify_limit_reasons(date)
# ============ Push notifications ============
@app.get("/api/notify/status")
def notify_status():
    """Report the configured push channels and whether any of them is enabled."""
    channels = notifier.channels_status()
    enabled = notifier.any_enabled()
    return {"channels": channels, "enabled": enabled}
@app.post("/api/notify/test")
def notify_test():
    """Send a test notification through every enabled push channel."""
    if notifier.any_enabled():
        delivery = notifier.notify("【Blackdata】推送测试", "这是一条来自Blackdata股票终端的测试通知,收到即表示推送通道正常。")
        return {"ok": True, "result": delivery}
    return {"ok": False, "msg": "未配置任何推送渠道,请在 backend/.env 配置后重启"}
# ============ Smart alerts ============
class AlertIn(BaseModel):
    """Request body for creating a price/indicator alert rule."""

    code: str  # security code to watch
    kind: str = "price_above"  # rule type; defaults to "price_above"
    threshold: float  # value the rule compares against
    note: str = ""  # free-form user note
@app.get("/api/alerts")
def list_alerts():
    """List all alert rules, newest first."""
    def as_dict(rule):
        fired = rule.triggered_at
        return {
            "id": rule.id,
            "code": rule.code,
            "name": rule.name,
            "kind": rule.kind,
            "threshold": rule.threshold,
            "status": rule.status,
            "note": rule.note,
            "last_value": rule.last_value,
            "triggered_at": fired.strftime("%m-%d %H:%M") if fired else "",
        }

    with get_session() as s:
        newest_first = select(AlertRule).order_by(AlertRule.id.desc())
        rules = s.execute(newest_first).scalars().all()
        return {"list": [as_dict(rule) for rule in rules]}
@app.post("/api/alerts")
def add_alert(a: AlertIn):
    """Create an alert rule, labelling it with the security name when known."""
    with get_session() as s:
        security = s.get(Security, a.code)
        display_name = a.code if security is None else security.name
        rule = AlertRule(code=a.code, name=display_name, kind=a.kind,
                         threshold=a.threshold, note=a.note)
        s.add(rule)
        s.commit()
        # Read the id while the session is still open.
        return {"ok": True, "id": rule.id}
@app.delete("/api/alerts/{aid}")
def del_alert(aid: int):
    """Delete alert rule ``aid``; succeeds even if it does not exist."""
    with get_session() as s:
        target = s.get(AlertRule, aid)
        if target is not None:
            s.delete(target)
            s.commit()
    return {"ok": True}
@app.post("/api/alerts/{aid}/reactivate")
def reactivate_alert(aid: int):
    """Re-arm alert rule ``aid``: mark it active and clear its trigger time."""
    with get_session() as s:
        target = s.get(AlertRule, aid)
        if target is not None:
            target.status = "active"
            target.triggered_at = None
            s.commit()
    return {"ok": True}
@app.post("/api/alerts/check")
def manual_check():
    """Evaluate all alert rules immediately."""
    outcome = al.check_alerts()
    return outcome
@app.get("/api/alerts/events")
def alert_events(unread_only: bool = Query(False),
                 limit: int = Query(30, ge=1, le=100)):
    """List recent alert events, newest first, plus the total unread count.

    Args:
        unread_only: when true, only events with ``read`` False are listed.
        limit: maximum number of events returned. It has a lower bound
            because the value goes straight into ``stmt.limit()``. A negative
            LIMIT means "no limit" on SQLite (bypassing the cap) and is an
            error on other databases.
    """
    unread_filter = AlertEvent.read.is_(False)
    with get_session() as s:
        stmt = select(AlertEvent).order_by(AlertEvent.id.desc())
        if unread_only:
            stmt = stmt.where(unread_filter)
        rows = s.execute(stmt.limit(limit)).scalars().all()
        unread = s.execute(
            select(func.count()).select_from(AlertEvent).where(unread_filter)
        ).scalar() or 0
        return {"unread": unread, "list": [
            {"id": e.id, "code": e.code, "name": e.name, "message": e.message,
             "time": e.created_at.strftime("%m-%d %H:%M:%S") if e.created_at else ""}
            for e in rows
        ]}
@app.post("/api/alerts/events/read")
def mark_events_read():
    """Mark every unread alert event as read.

    Issues one ``UPDATE ... WHERE read IS false`` statement instead of loading
    each unread row into the session and flipping it in Python.
    """
    from sqlalchemy import update  # only this endpoint needs bulk UPDATE

    with get_session() as s:
        s.execute(
            update(AlertEvent)
            .where(AlertEvent.read.is_(False))
            .values(read=True)
            # Fresh session with no loaded AlertEvent objects: nothing to sync.
            .execution_options(synchronize_session=False)
        )
        s.commit()
    return {"ok": True}
# ============ News center ============
@app.get("/api/news")
def news(limit: int = Query(40, ge=1, le=100)):
    """Return up to ``limit`` market news items.

    ``limit`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return svc.get_news(limit)
@app.get("/api/news/stock")
def news_stock(code: str = Query(...)):
    """Return news for a single stock ``code``."""
    items = svc.get_stock_news(code)
    return items
@app.get("/api/news/watch")
def news_watch():
    """Aggregate recent news for the first six watch-list stocks.

    Fetches up to 4 items per stock and tags each with its stock ``code``.
    Returns the 40 newest items, sorted by their ``time`` field (descending).
    """
    merged = []
    for code in load_watch()[:6]:
        resp = svc.get_stock_news(code, limit=4) or {}
        for item in resp.get("list") or []:
            # Copy instead of mutating: the service layer caches results, and
            # tagging a cached dict in place would leak "code" into the cache.
            merged.append({**item, "code": code})
    # Missing/None timestamps sort last rather than raising KeyError/TypeError.
    # NOTE(review): assumes "time" values order correctly as strings
    # (e.g. "YYYY-MM-DD HH:MM") — confirm against svc.get_stock_news.
    merged.sort(key=lambda item: str(item.get("time") or ""), reverse=True)
    return {"list": merged[:40]}
class NewsAI(BaseModel):
    """A news item to be analysed by the AI endpoint."""

    title: str  # headline
    content: str = ""  # optional body text
@app.post("/api/news/ai")
def news_ai(n: NewsAI):
    """Analyse a news item: summary, bullish/bearish call, affected sectors.

    Uses the LLM when one is enabled. Falls back to the rule-based sentiment
    from ``svc.judge_sentiment`` when no LLM is configured or the LLM call
    fails. Failures are logged instead of being silently swallowed.
    """
    text_in = (n.title + "。" + n.content).strip()
    senti, kw = svc.judge_sentiment(text_in)
    if llm.enabled():
        try:
            prompt = ("请分析下面这条财经资讯:\n"
                      "1) 一句话摘要;2) 利好/利空/中性判断及理由;3) 可能受影响的板块或个股方向。120字内。\n\n"
                      + text_in[:1200])
            text = llm.ask(prompt, temperature=0.3, max_tokens=400)
            return {"ok": True, "source": "llm", "sentiment": senti, "text": text}
        except Exception:
            # Deliberate best-effort: keep serving the rule-based answer, but
            # leave a trace so LLM outages/misconfiguration are visible.
            import logging
            logging.getLogger(__name__).warning(
                "LLM news analysis failed; falling back to rule-based result",
                exc_info=True,
            )
    return {"ok": True, "source": "rule", "sentiment": senti,
            "text": f"判断:{senti}(关键词:{'、'.join(kw) or '无'})。摘要:{text_in[:80]}…\n(配置大模型后可获得更深入的关联分析)"}
# ============ Intraday real-time monitoring radar ============
@app.get("/api/radar/status")
def radar_status():
    """Report radar status: whether the market is currently in trading hours."""
    trading = radar._is_trading_time()
    return {"trading_time": trading}
@app.post("/api/radar/scan")
def radar_scan():
    """Manually trigger an anomaly scan."""
    findings = radar.scan_all()
    return findings
@app.get("/api/radar/events")
def radar_events(hours: int = Query(2, ge=1, le=24),
                 limit: int = Query(50, ge=1, le=200)):
    """Return up to ``limit`` anomaly events from the last ``hours`` hours.

    ``limit`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return {"list": radar.get_recent_events(hours, limit)}
@app.post("/api/radar/notify")
def radar_notify():
    """Push anomaly events that have not been notified yet."""
    pushed = radar.notify_events()
    return pushed
@app.get("/api/radar/stats")
def radar_stats(date: Optional[dt.date] = Query(None)):
    """Return anomaly statistics for ``date`` (ISO ``YYYY-MM-DD``).

    The original ``date: str = Query(None)`` declared a non-optional string
    with a None default and parsed it by hand. A malformed value then raised
    an unhandled ``ValueError`` (HTTP 500). Typing it as an optional date lets
    FastAPI validate it (422 on bad input). When omitted, ``None`` is passed
    through.
    """
    return radar.get_statistics(date)
# ============ Sector rotation analysis ============
@app.get("/api/sector/trend")
def sector_trend(days: int = Query(20, ge=5, le=60),
                 top_n: int = Query(15, ge=1, le=30)):
    """Return strength trends of the ``top_n`` sectors over ``days`` days.

    ``top_n`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return sector.get_sector_trend(days, top_n)
@app.get("/api/sector/flow")
def sector_flow(days: int = Query(5, ge=1, le=20)):
    """Return the sector fund-flow analysis for the last ``days`` days."""
    flow = sector.analyze_fund_flow(days)
    return flow
@app.get("/api/sector/lifecycle")
def sector_lifecycle(name: str = Query(...),
                     days: int = Query(60, ge=20, le=120)):
    """Return the lifecycle analysis of sector ``name``."""
    lifecycle = sector.analyze_lifecycle(name, days)
    return lifecycle
@app.get("/api/sector/leaders")
def sector_leaders(name: str = Query(...),
                   days: int = Query(20, ge=5, le=60),
                   limit: int = Query(10, ge=1, le=30)):
    """Identify up to ``limit`` leading stocks of sector ``name``.

    ``limit`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return sector.identify_leaders(name, days, limit)
@app.get("/api/sector/correlation")
def sector_correlation(days: int = Query(60, ge=20, le=120),
                       top_n: int = Query(20, ge=1, le=30)):
    """Return co-movement (correlation) analysis for the ``top_n`` sectors.

    ``top_n`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return sector.analyze_correlation(days, top_n)
@app.get("/api/sector/summary")
def sector_summary():
    """Return the sector-rotation summary."""
    summary = sector.get_rotation_summary()
    return summary
# ============ Smart stock selector ============
@app.get("/api/selector/fields")
def selector_fields():
    """List the fields that screening strategies may reference."""
    fields = selector.get_available_fields()
    return {"ok": True, "fields": fields}
@app.get("/api/selector/presets")
def selector_presets():
    """List the built-in preset screening strategies."""
    presets = selector.get_preset_strategies()
    return {"ok": True, "presets": presets}
class SelectorRequest(BaseModel):
    """A screening strategy plus an optional ISO date to run it on."""

    strategy: Dict[str, Any]  # serialized selector.Strategy
    date: Optional[str] = None  # "YYYY-MM-DD"; None lets the selector choose
@app.post("/api/selector/run")
def selector_run(req: SelectorRequest):
    """Run a screening strategy; any failure is reported as ``{"ok": False, "msg": ...}``."""
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        as_of = None if not req.date else dt.date.fromisoformat(req.date)
        return selector.run_selector(parsed, as_of)
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.post("/api/selector/backtest")
def selector_backtest(req: SelectorRequest, days: int = Query(60, ge=20, le=250)):
    """Backtest a screening strategy over ``days`` days; errors become ``ok: False``."""
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        return selector.backtest_selector(parsed, days)
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
class CompareRequest(BaseModel):
    """A screening strategy and two ISO dates whose results should be compared."""

    strategy: Dict[str, Any]  # serialized selector.Strategy
    date1: str  # first date, "YYYY-MM-DD"
    date2: str  # second date, "YYYY-MM-DD"
@app.post("/api/selector/compare")
def selector_compare(req: CompareRequest):
    """Compare a strategy's picks on two dates; errors become ``ok: False``."""
    try:
        parsed = selector.Strategy.from_dict(req.strategy)
        first, second = (dt.date.fromisoformat(d) for d in (req.date1, req.date2))
        return selector.compare_results(first, second, parsed)
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.get("/api/selector/strategies")
def list_strategies():
    """List saved screening strategies, most recently updated first.

    Missing timestamps are rendered as ``""``, the same convention the alert
    endpoints use, instead of crashing on ``None.strftime``. ``id`` is used
    as a tie-breaker so rows with equal ``updated_at`` come back in a stable
    order.
    """
    def fmt(ts):
        return ts.strftime("%Y-%m-%d %H:%M:%S") if ts else ""

    with get_session() as s:
        rows = s.execute(
            select(SelectorStrategy).order_by(
                SelectorStrategy.updated_at.desc(), SelectorStrategy.id.desc()
            )
        ).scalars().all()
        return {
            "ok": True,
            "strategies": [{
                "id": r.id,
                "name": r.name,
                "description": r.description,
                "is_preset": r.is_preset,
                "created_at": fmt(r.created_at),
                "updated_at": fmt(r.updated_at),
            } for r in rows],
        }
class SaveStrategyRequest(BaseModel):
    """Request body for saving a screening strategy."""

    name: str  # display name
    description: str = ""  # optional description
    strategy: Dict[str, Any]  # serialized selector.Strategy
@app.post("/api/selector/strategies")
def save_strategy(req: SaveStrategyRequest):
    """Validate and persist a screening strategy; errors become ``ok: False``."""
    try:
        # Round-trip through Strategy so only well-formed strategies are stored.
        payload = selector.Strategy.from_dict(req.strategy).to_json()
        with get_session() as s:
            saved = SelectorStrategy(
                name=req.name,
                description=req.description,
                strategy_json=payload,
            )
            s.add(saved)
            s.commit()
            return {"ok": True, "id": saved.id}
    except Exception as exc:
        return {"ok": False, "msg": str(exc)}
@app.get("/api/selector/strategies/{sid}")
def get_strategy(sid: int):
    """Return one saved strategy with its definition decoded from JSON."""
    with get_session() as s:
        record = s.get(SelectorStrategy, sid)
        if record is None:
            return {"ok": False, "msg": "策略不存在"}
        body = {
            "ok": True,
            "id": record.id,
            "name": record.name,
            "description": record.description,
            "strategy": json.loads(record.strategy_json),
        }
        return body
@app.delete("/api/selector/strategies/{sid}")
def delete_strategy(sid: int):
    """Delete saved strategy ``sid``; succeeds even if it does not exist."""
    with get_session() as s:
        target = s.get(SelectorStrategy, sid)
        if target is not None:
            s.delete(target)
            s.commit()
    return {"ok": True}
@app.get("/api/selector/alerts")
def list_selector_alerts():
    """List strategy-based selector alerts, newest first."""
    def as_dict(alert):
        checked = alert.last_checked
        return {
            "id": alert.id,
            "strategy_id": alert.strategy_id,
            "strategy_name": alert.strategy_name,
            "status": alert.status,
            "last_checked": checked.strftime("%m-%d %H:%M") if checked else "",
            "last_count": alert.last_count,
        }

    with get_session() as s:
        newest_first = select(SelectorAlert).order_by(SelectorAlert.id.desc())
        alerts = s.execute(newest_first).scalars().all()
        return {"ok": True, "alerts": [as_dict(a) for a in alerts]}
class CreateAlertRequest(BaseModel):
    """Request body for attaching an alert to a saved screening strategy."""

    strategy_id: int  # id of the saved strategy
    strategy_name: str  # name stored alongside the alert for display
@app.post("/api/selector/alerts")
def create_selector_alert(req: CreateAlertRequest):
    """Create a selector alert for a saved strategy."""
    # NOTE(review): strategy_id is not checked against SelectorStrategy here.
    with get_session() as s:
        alert = SelectorAlert(strategy_id=req.strategy_id,
                              strategy_name=req.strategy_name)
        s.add(alert)
        s.commit()
        return {"ok": True, "id": alert.id}
@app.delete("/api/selector/alerts/{aid}")
def delete_selector_alert(aid: int):
    """Delete selector alert ``aid``; succeeds even if it does not exist."""
    with get_session() as s:
        target = s.get(SelectorAlert, aid)
        if target is not None:
            s.delete(target)
            s.commit()
    return {"ok": True}
# ============ Paper trading ============
class PaperAccountIn(BaseModel):
    """Request body for creating a paper-trading account."""

    name: str  # account display name
    initial_cash: float = 1_000_000.0  # starting cash balance
@app.get("/api/paper/accounts")
def paper_list_accounts():
    """List all paper-trading accounts."""
    accounts = paper.list_accounts()
    return {"ok": True, "accounts": accounts}
@app.post("/api/paper/accounts")
def paper_create_account(req: PaperAccountIn):
    """Create a paper-trading account with the given starting cash."""
    name, cash = req.name, req.initial_cash
    return paper.create_account(name, cash)
@app.post("/api/paper/accounts/{account_id}/reset")
def paper_reset_account(account_id: int, initial_cash: Optional[float] = None):
    """Reset a paper account, optionally with a new starting cash amount."""
    outcome = paper.reset_account(account_id, initial_cash)
    return outcome
class PaperOrderIn(BaseModel):
    """Request body for placing a paper-trading order."""

    code: str  # security code
    side: str  # "buy" or "sell"
    qty: int  # order quantity
    price: Optional[float] = None  # limit price; None leaves pricing to paper_trading
    reason: str = ""  # free-form rationale recorded with the order
@app.post("/api/paper/accounts/{account_id}/order")
def paper_place_order(account_id: int, req: PaperOrderIn):
    """Place a simulated order on paper account ``account_id``."""
    order_args = (req.code, req.side, req.qty, req.price, req.reason)
    return paper.place_order(account_id, *order_args)
@app.get("/api/paper/accounts/{account_id}/portfolio")
def paper_get_portfolio(account_id: int):
    """Return holdings and balances of paper account ``account_id``."""
    portfolio = paper.get_portfolio(account_id)
    return portfolio
@app.get("/api/paper/accounts/{account_id}/trades")
def paper_get_trades(account_id: int, limit: int = Query(100, ge=1, le=500)):
    """Return up to ``limit`` trades of paper account ``account_id``.

    ``limit`` now has a lower bound so zero or negative values cannot bypass
    the upper cap.
    """
    return {"ok": True, "trades": paper.get_trades(account_id, limit)}
# ============ Static frontend ============
# The prototype UI lives in a "prototype" directory next to BASE_DIR.
FRONTEND_DIR = os.path.join(os.path.dirname(BASE_DIR), "prototype")
if os.path.isdir(FRONTEND_DIR):
    # Mounted after every /api route: a mount at "/" matches all paths, so any
    # route registered after it would never be reached.
    app.mount("/", StaticFiles(directory=FRONTEND_DIR, html=True), name="frontend")
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly.
    from uvicorn import run as _serve

    _serve("main:app", host="0.0.0.0", port=8000, reload=False)