功能细节优化

This commit is contained in:
2026-06-15 01:26:39 +08:00
parent e524a3589a
commit 964c17c200
33 changed files with 6990 additions and 210 deletions

222
backend/paper_trading.py Normal file
View File

@@ -0,0 +1,222 @@
"""模拟盘核心逻辑。
- 多账户支持(默认账户 id=1
- 买卖按实时价(或收盘价)撮合,自动扣减/增加现金
- 持仓计算使用移动加权平均成本法
"""
from __future__ import annotations
import datetime as dt
from collections import defaultdict
from sqlalchemy import select
from db import get_session
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote
DEFAULT_FEE_RATE = 0.0003
def _get_price(code: str) -> float | None:
with get_session() as s:
m = s.execute(select(StockMetric.close).where(StockMetric.code == code)).scalar_one_or_none()
if m:
return float(m)
row = s.execute(
select(DailyQuote.close).where(DailyQuote.code == code)
.order_by(DailyQuote.date.desc()).limit(1)
).scalar_one_or_none()
return float(row) if row else None
# ── 账户管理 ──────────────────────────────────────────────
def ensure_default_account():
"""确保默认账户id=1存在启动时调用。"""
with get_session() as s:
if not s.get(PaperAccount, 1):
s.add(PaperAccount(name="默认模拟盘", initial_cash=1_000_000.0, cash=1_000_000.0))
s.commit()
def list_accounts() -> list[dict]:
with get_session() as s:
rows = s.execute(select(PaperAccount).order_by(PaperAccount.id)).scalars().all()
return [{"id": r.id, "name": r.name, "initial_cash": r.initial_cash,
"cash": round(r.cash, 2), "is_active": r.is_active,
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]
def create_account(name: str, initial_cash: float) -> dict:
with get_session() as s:
acc = PaperAccount(name=name, initial_cash=initial_cash, cash=initial_cash)
s.add(acc)
s.commit()
return {"ok": True, "id": acc.id}
def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
if initial_cash is not None:
acc.initial_cash = initial_cash
acc.cash = acc.initial_cash
for t in s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
).scalars():
s.delete(t)
s.commit()
return {"ok": True, "msg": "账户已重置"}
# ── 持仓计算(内部)────────────────────────────────────────
def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
trades = s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
.order_by(PaperTrade.date, PaperTrade.id)
).scalars().all()
pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
for t in trades:
p = pos[t.code]
p["name"] = t.name or p["name"]
if t.side == "buy":
p["cost"] += t.price * t.qty + t.fee
p["qty"] += t.qty
else:
if p["qty"] > 0:
avg = p["cost"] / p["qty"]
qty = min(t.qty, p["qty"])
p["cost"] -= avg * qty
p["qty"] -= qty
return [{"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]}
for c, v in pos.items() if v["qty"] > 0]
# ── 下单 ──────────────────────────────────────────────────
def place_order(account_id: int, code: str, side: str, qty: int,
price: float | None = None, reason: str = "") -> dict:
if qty <= 0:
return {"ok": False, "msg": "数量必须大于 0"}
if side not in ("buy", "sell"):
return {"ok": False, "msg": "side 只能是 buy 或 sell"}
exec_price = price or _get_price(code)
if not exec_price:
return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
sec = s.get(Security, code)
name = sec.name if sec else code
if side == "buy":
cost = exec_price * qty + fee
if acc.cash < cost:
return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
cash_before = acc.cash
acc.cash -= cost
else:
holdings = _calc_holdings_in_session(account_id, s)
pos = next((h for h in holdings if h["code"] == code), None)
avail = pos["qty"] if pos else 0
if avail < qty:
return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty}"}
cash_before = acc.cash
acc.cash += exec_price * qty - fee
trade = PaperTrade(
account_id=account_id,
date=dt.date.today(),
code=code, name=name, side=side,
price=exec_price, qty=qty, fee=fee,
cash_before=cash_before, cash_after=acc.cash,
reason=reason,
)
s.add(trade)
s.commit()
return {"ok": True, "id": trade.id, "price": exec_price,
"fee": fee, "cash_after": round(acc.cash, 2)}
# ── 查询接口 ──────────────────────────────────────────────
def get_portfolio(account_id: int) -> dict:
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
cash = acc.cash
initial = acc.initial_cash
holdings_raw = _calc_holdings_in_session(account_id, s)
codes = [h["code"] for h in holdings_raw]
px: dict[str, float] = {}
if codes:
with get_session() as s:
for m in s.execute(
select(StockMetric).where(StockMetric.code.in_(codes))
).scalars():
px[m.code] = m.close
for c in [c for c in codes if c not in px]:
row = s.execute(
select(DailyQuote.close).where(DailyQuote.code == c)
.order_by(DailyQuote.date.desc()).limit(1)
).scalar_one_or_none()
if row:
px[c] = float(row)
holdings, mkt_val = [], 0.0
for h in holdings_raw:
avg = h["cost"] / h["qty"] if h["qty"] else 0.0
cur = px.get(h["code"], avg)
mv = cur * h["qty"]
unreal = (cur - avg) * h["qty"]
mkt_val += mv
holdings.append({
"code": h["code"], "name": h["name"], "qty": h["qty"],
"avg_cost": round(avg, 3), "cur": round(cur, 3),
"market_value": round(mv, 2),
"unrealized": round(unreal, 2),
"unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
})
holdings.sort(key=lambda x: x["unrealized"], reverse=True)
total_assets = cash + mkt_val
total_pnl = total_assets - initial
return {
"ok": True,
"account_id": account_id,
"summary": {
"initial_cash": round(initial, 2),
"cash": round(cash, 2),
"market_value": round(mkt_val, 2),
"total_assets": round(total_assets, 2),
"total_pnl": round(total_pnl, 2),
"total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
"positions": len(holdings),
},
"holdings": holdings,
}
def get_trades(account_id: int, limit: int = 100) -> list[dict]:
with get_session() as s:
rows = s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
.order_by(PaperTrade.id.desc()).limit(limit)
).scalars().all()
return [{
"id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name,
"side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee,
"cash_before": round(t.cash_before, 2), "cash_after": round(t.cash_after, 2),
"reason": t.reason,
} for t in rows]