Files
stock_cursor_v0/backend/paper_trading.py
2026-06-15 01:26:39 +08:00

222 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""模拟盘核心逻辑。
- 多账户支持(默认账户 id=1)
- 买卖按实时价(或收盘价)撮合,自动扣减/增加现金
- 持仓计算使用移动加权平均成本法
"""
from __future__ import annotations
import datetime as dt
from collections import defaultdict
from sqlalchemy import select
from db import get_session
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote
# Fee rate used by place_order: fee = round(price * qty * DEFAULT_FEE_RATE, 2),
# computed the same way for buys and sells (0.0003 == 0.03% of the notional).
DEFAULT_FEE_RATE = 0.0003
def _get_price(code: str) -> float | None:
    """Look up the latest known price for ``code``.

    ``StockMetric.close`` is preferred.  When that value is missing or falsy
    (``None`` or ``0``), the close of the newest ``DailyQuote`` row (by date)
    is used instead.  Returns ``None`` when neither source yields a truthy
    price.
    """
    metric_stmt = select(StockMetric.close).where(StockMetric.code == code)
    quote_stmt = (
        select(DailyQuote.close)
        .where(DailyQuote.code == code)
        .order_by(DailyQuote.date.desc())
        .limit(1)
    )
    with get_session() as session:
        metric_close = session.execute(metric_stmt).scalar_one_or_none()
        if metric_close:
            return float(metric_close)
        quote_close = session.execute(quote_stmt).scalar_one_or_none()
    if not quote_close:
        return None
    return float(quote_close)
# ── 账户管理 ──────────────────────────────────────────────
def ensure_default_account():
    """Create the default paper account when no account with id=1 exists.

    Meant to be called at startup.
    """
    with get_session() as session:
        if session.get(PaperAccount, 1):
            return
        # NOTE(review): the row is inserted without an explicit id, so it only
        # becomes id=1 if the database happens to assign 1 — confirm this holds
        # when account 1 is missing but other accounts already exist.
        seed_cash = 1_000_000.0
        default_account = PaperAccount(
            name="默认模拟盘", initial_cash=seed_cash, cash=seed_cash
        )
        session.add(default_account)
        session.commit()
def list_accounts() -> list[dict]:
    """Return all paper accounts, ordered by id, as plain dicts.

    Each dict has ``id``, ``name``, ``initial_cash``, ``cash`` (rounded to two
    decimals), ``is_active`` and ``created_at`` (formatted as ``YYYY-MM-DD``).
    """
    stmt = select(PaperAccount).order_by(PaperAccount.id)
    accounts: list[dict] = []
    with get_session() as session:
        for account in session.execute(stmt).scalars().all():
            accounts.append({
                "id": account.id,
                "name": account.name,
                "initial_cash": account.initial_cash,
                "cash": round(account.cash, 2),
                "is_active": account.is_active,
                "created_at": account.created_at.strftime("%Y-%m-%d"),
            })
    return accounts
def create_account(name: str, initial_cash: float) -> dict:
    """Create a paper account whose cash balance starts at ``initial_cash``.

    Returns ``{"ok": True, "id": <id of the new account>}``.
    """
    with get_session() as session:
        account = PaperAccount(
            name=name,
            initial_cash=initial_cash,
            cash=initial_cash,
        )
        session.add(account)
        session.commit()
        # Read the generated key while the session is still open.
        new_id = account.id
        return {"ok": True, "id": new_id}
def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
    """Reset an account: restore its cash and delete all of its trades.

    When ``initial_cash`` is given it first replaces the account's stored
    initial cash; the cash balance is then set to the (possibly updated)
    initial cash.

    Returns ``{"ok": True, "msg": ...}`` on success, or
    ``{"ok": False, "msg": ...}`` when the account does not exist.
    """
    with get_session() as session:
        account = session.get(PaperAccount, account_id)
        if not account:
            return {"ok": False, "msg": "账户不存在"}

        if initial_cash is not None:
            account.initial_cash = initial_cash
        account.cash = account.initial_cash

        own_trades = select(PaperTrade).where(PaperTrade.account_id == account_id)
        for trade in session.execute(own_trades).scalars():
            session.delete(trade)

        session.commit()
        return {"ok": True, "msg": "账户已重置"}
# ── 持仓计算(内部)────────────────────────────────────────
def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
    """Replay an account's trades into its currently open positions.

    Trades are read through the caller's session ``s`` in (date, id) order.
    A buy adds ``price * qty + fee`` to the position's cost; any other side is
    treated as a sell that removes shares at the running average cost, capped
    at the quantity held.

    Returns one ``{"code", "name", "qty", "cost"}`` dict per code that still
    has a positive quantity, where ``cost`` is the total remaining cost.
    """
    stmt = (
        select(PaperTrade)
        .where(PaperTrade.account_id == account_id)
        .order_by(PaperTrade.date, PaperTrade.id)
    )
    book: dict[str, dict] = {}
    for trade in s.execute(stmt).scalars().all():
        entry = book.setdefault(trade.code, {"qty": 0, "cost": 0.0, "name": ""})
        # Keep the most recent non-empty name seen for this code.
        entry["name"] = trade.name or entry["name"]

        if trade.side == "buy":
            entry["cost"] += trade.price * trade.qty + trade.fee
            entry["qty"] += trade.qty
            continue

        held = entry["qty"]
        if held <= 0:
            # Nothing to sell against; the trade leaves the position untouched.
            continue
        avg_cost = entry["cost"] / held
        sold = min(trade.qty, held)
        entry["cost"] -= avg_cost * sold
        entry["qty"] -= sold

    open_positions = []
    for code, entry in book.items():
        if entry["qty"] > 0:
            open_positions.append({
                "code": code,
                "name": entry["name"],
                "qty": entry["qty"],
                "cost": entry["cost"],
            })
    return open_positions
# ── 下单 ──────────────────────────────────────────────────
def place_order(account_id: int, code: str, side: str, qty: int,
                price: float | None = None, reason: str = "") -> dict:
    """Execute a simulated buy or sell and record it as a ``PaperTrade``.

    Args:
        account_id: Primary key of the ``PaperAccount`` to trade in.
        code: Security code. The trade's display name is taken from the
            matching ``Security`` row, falling back to the code itself.
        side: ``"buy"`` or ``"sell"``.
        qty: Number of shares; must be greater than 0.
        price: Execution price. A falsy value (``None`` or ``0``) means
            "not supplied" and the price is looked up via ``_get_price``.
        reason: Free text stored on the trade record.

    Returns:
        On success ``{"ok": True, "id", "price", "fee", "cash_after"}``.
        On any validation failure ``{"ok": False, "msg": <explanation>}``;
        nothing is written in that case.
    """
    if qty <= 0:
        return {"ok": False, "msg": "数量必须大于 0"}
    if side not in ("buy", "sell"):
        return {"ok": False, "msg": "side 只能是 buy 或 sell"}

    exec_price = price or _get_price(code)
    if not exec_price:
        return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
    # Reject negative, NaN and infinite prices.  A negative price would make a
    # buy *increase* cash, and NaN slips through the `acc.cash < cost` check
    # below (every comparison with NaN is False) and corrupts the balance.
    if not 0 < exec_price < float("inf"):
        return {"ok": False, "msg": f"价格无效: {exec_price}"}

    fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)
    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        sec = s.get(Security, code)
        name = sec.name if sec else code

        cash_before = acc.cash
        if side == "buy":
            cost = exec_price * qty + fee
            if acc.cash < cost:
                return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
            acc.cash -= cost
        else:
            # Sells are limited to the quantity currently held for this code.
            holdings = _calc_holdings_in_session(account_id, s)
            pos = next((h for h in holdings if h["code"] == code), None)
            avail = pos["qty"] if pos else 0
            if avail < qty:
                return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty}"}
            acc.cash += exec_price * qty - fee

        trade = PaperTrade(
            account_id=account_id,
            date=dt.date.today(),
            code=code, name=name, side=side,
            price=exec_price, qty=qty, fee=fee,
            cash_before=cash_before, cash_after=acc.cash,
            reason=reason,
        )
        s.add(trade)
        s.commit()
        return {"ok": True, "id": trade.id, "price": exec_price,
                "fee": fee, "cash_after": round(acc.cash, 2)}
# ── 查询接口 ──────────────────────────────────────────────
def _latest_prices_in_session(codes: list[str], s) -> dict[str, float]:
    """Map each code to its latest usable price; codes with none are omitted.

    Uses the same rule as ``_get_price``: a truthy ``StockMetric.close`` wins,
    otherwise the close of the newest ``DailyQuote`` row is used.
    """
    prices: dict[str, float] = {}
    if not codes:
        return prices

    for metric in s.execute(
        select(StockMetric).where(StockMetric.code.in_(codes))
    ).scalars():
        # Skip missing/zero closes so they fall through to DailyQuote instead
        # of crashing the valuation (None) or pricing the position at zero.
        if metric.close:
            prices[metric.code] = float(metric.close)

    for code in codes:
        if code in prices:
            continue
        close = s.execute(
            select(DailyQuote.close).where(DailyQuote.code == code)
            .order_by(DailyQuote.date.desc()).limit(1)
        ).scalar_one_or_none()
        if close:
            prices[code] = float(close)
    return prices


def get_portfolio(account_id: int) -> dict:
    """Return the account's cash, open positions and profit/loss summary.

    Positions come from replaying the account's trades. Each one is valued at
    its latest known price, or at its own average cost when no price is
    available (so it shows zero unrealized P&L).

    Returns:
        ``{"ok": False, "msg": ...}`` when the account does not exist.
        Otherwise ``{"ok": True, "account_id", "summary", "holdings"}`` where
        ``holdings`` is sorted by unrealized P&L, largest first, and
        ``summary`` holds the rounded totals.
    """
    # All reads share one session so balances, trades and prices are fetched
    # together.
    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        cash = acc.cash
        initial = acc.initial_cash
        holdings_raw = _calc_holdings_in_session(account_id, s)
        px = _latest_prices_in_session([h["code"] for h in holdings_raw], s)

    holdings, mkt_val = [], 0.0
    for h in holdings_raw:
        avg = h["cost"] / h["qty"] if h["qty"] else 0.0
        cur = px.get(h["code"], avg)
        mv = cur * h["qty"]
        unreal = (cur - avg) * h["qty"]
        mkt_val += mv
        holdings.append({
            "code": h["code"], "name": h["name"], "qty": h["qty"],
            "avg_cost": round(avg, 3), "cur": round(cur, 3),
            "market_value": round(mv, 2),
            "unrealized": round(unreal, 2),
            "unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
        })
    holdings.sort(key=lambda x: x["unrealized"], reverse=True)

    total_assets = cash + mkt_val
    total_pnl = total_assets - initial
    return {
        "ok": True,
        "account_id": account_id,
        "summary": {
            "initial_cash": round(initial, 2),
            "cash": round(cash, 2),
            "market_value": round(mkt_val, 2),
            "total_assets": round(total_assets, 2),
            "total_pnl": round(total_pnl, 2),
            "total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
            "positions": len(holdings),
        },
        "holdings": holdings,
    }
def get_trades(account_id: int, limit: int = 100) -> list[dict]:
    """Return up to ``limit`` trades of the account, highest id first.

    Each trade is a plain dict; ``date`` is an ISO-format string and the
    ``cash_before`` / ``cash_after`` values are rounded to two decimals.
    """
    stmt = (
        select(PaperTrade)
        .where(PaperTrade.account_id == account_id)
        .order_by(PaperTrade.id.desc())
        .limit(limit)
    )
    history: list[dict] = []
    with get_session() as session:
        for trade in session.execute(stmt).scalars().all():
            history.append({
                "id": trade.id,
                "date": trade.date.isoformat(),
                "code": trade.code,
                "name": trade.name,
                "side": trade.side,
                "price": trade.price,
                "qty": trade.qty,
                "fee": trade.fee,
                "cash_before": round(trade.cash_before, 2),
                "cash_after": round(trade.cash_after, 2),
                "reason": trade.reason,
            })
    return history