"""Paper-trading (simulated portfolio) core logic.

- Multiple accounts are supported (the default account has id=1).
- Orders fill at the latest known price (StockMetric.close, falling back to
  the most recent DailyQuote.close) unless an explicit price is supplied;
  cash is debited on buys and credited on sells.
- Positions are costed with the moving weighted-average cost method.
"""
from __future__ import annotations

import datetime as dt
from collections import defaultdict

from sqlalchemy import select

from db import get_session
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote

# Commission rate applied to the notional value (price * qty) of every fill.
DEFAULT_FEE_RATE = 0.0003


def _latest_quote_close(s, code: str) -> float | None:
    """Return the most recent DailyQuote close for *code* using session *s*.

    A missing row or a falsy close (NULL / 0) yields None, i.e. "no price".
    """
    row = s.execute(
        select(DailyQuote.close)
        .where(DailyQuote.code == code)
        .order_by(DailyQuote.date.desc())
        .limit(1)
    ).scalar_one_or_none()
    return float(row) if row else None


def _get_price(code: str) -> float | None:
    """Return the latest known price for *code*, or None if none is recorded.

    StockMetric.close is preferred; when it is missing or falsy the newest
    DailyQuote close is used instead.
    """
    with get_session() as s:
        metric_close = s.execute(
            select(StockMetric.close).where(StockMetric.code == code)
        ).scalar_one_or_none()
        if metric_close:
            return float(metric_close)
        return _latest_quote_close(s, code)


# ── Account management ────────────────────────────────────


def ensure_default_account():
    """Create the default account (1,000,000 cash) if no account with id=1 exists.

    Meant to be called at application startup.
    """
    with get_session() as s:
        if not s.get(PaperAccount, 1):
            # NOTE(review): the id is not set explicitly, so it is assigned by
            # the database. On an empty table that is normally 1, but if row 1
            # was removed while other accounts exist, the new row gets another
            # id and this check keeps failing (adding one more account per
            # call). Passing id=1 would fix that, but may desync auto-increment
            # sequences on some backends — confirm before changing.
            s.add(PaperAccount(name="默认模拟盘", initial_cash=1_000_000.0, cash=1_000_000.0))
            s.commit()


def list_accounts() -> list[dict]:
    """Return every account as a plain dict, ordered by id."""
    with get_session() as s:
        rows = s.execute(select(PaperAccount).order_by(PaperAccount.id)).scalars().all()
        return [
            {
                "id": r.id,
                "name": r.name,
                "initial_cash": r.initial_cash,
                "cash": round(r.cash, 2),
                "is_active": r.is_active,
                "created_at": r.created_at.strftime("%Y-%m-%d"),
            }
            for r in rows
        ]


def create_account(name: str, initial_cash: float) -> dict:
    """Create an account whose cash starts equal to *initial_cash*.

    Returns ``{"ok": True, "id": <new account id>}``.
    """
    with get_session() as s:
        acc = PaperAccount(name=name, initial_cash=initial_cash, cash=initial_cash)
        s.add(acc)
        s.commit()
        return {"ok": True, "id": acc.id}


def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
    """Delete all trades of an account and restore its cash to the initial amount.

    If *initial_cash* is given, it replaces the account's stored initial cash
    before the reset.
    """
    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        if initial_cash is not None:
            acc.initial_cash = initial_cash
        acc.cash = acc.initial_cash
        for t in s.execute(
            select(PaperTrade).where(PaperTrade.account_id == account_id)
        ).scalars():
            s.delete(t)
        s.commit()
        return {"ok": True, "msg": "账户已重置"}


# ── Position calculation (internal) ───────────────────────


def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
    """Replay an account's trades in (date, id) order and return open positions.

    Each item is ``{"code", "name", "qty", "cost"}`` where ``cost`` is the
    remaining total cost basis (buy fees included). Sells remove cost at the
    current average price; a sell larger than the held quantity is clamped,
    and a sell with no position is ignored. Only positions with qty > 0 are
    returned.
    """
    trades = s.execute(
        select(PaperTrade)
        .where(PaperTrade.account_id == account_id)
        .order_by(PaperTrade.date, PaperTrade.id)
    ).scalars().all()

    pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
    for t in trades:
        p = pos[t.code]
        p["name"] = t.name or p["name"]
        if t.side == "buy":
            p["cost"] += t.price * t.qty + t.fee
            p["qty"] += t.qty
        elif p["qty"] > 0:
            avg = p["cost"] / p["qty"]
            sold = min(t.qty, p["qty"])
            p["qty"] -= sold
            if p["qty"] == 0:
                # Fully closed: reset exactly, so floating-point residue from
                # cost - avg*qty cannot leak into the average of a later re-buy.
                p["cost"] = 0.0
            else:
                p["cost"] -= avg * sold

    return [
        {"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]}
        for c, v in pos.items()
        if v["qty"] > 0
    ]


# ── Order placement ───────────────────────────────────────


def place_order(account_id: int, code: str, side: str, qty: int,
                price: float | None = None, reason: str = "") -> dict:
    """Fill a buy or sell order immediately and record it as a PaperTrade.

    If *price* is None or 0, the latest known price from ``_get_price`` is used.
    A fee of ``DEFAULT_FEE_RATE`` * notional (rounded to 2 decimals) is charged:
    added to the cash outlay on buys, deducted from the proceeds on sells.

    Returns ``{"ok": True, "id", "price", "fee", "cash_after"}`` on success,
    or ``{"ok": False, "msg"}`` on validation failure (bad qty/side/price,
    unknown account, insufficient cash or position).
    """
    if qty <= 0:
        return {"ok": False, "msg": "数量必须大于 0"}
    if side not in ("buy", "sell"):
        return {"ok": False, "msg": "side 只能是 buy 或 sell"}

    exec_price = price or _get_price(code)
    if not exec_price:
        return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
    if exec_price < 0:
        # A negative price would turn a buy into a cash credit (and a sell into
        # a debit), so it is rejected outright.
        return {"ok": False, "msg": "价格必须大于 0"}
    fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)

    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        sec = s.get(Security, code)
        name = sec.name if sec else code

        cash_before = acc.cash
        if side == "buy":
            cost = exec_price * qty + fee
            if acc.cash < cost:
                return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
            acc.cash -= cost
        else:
            holdings = _calc_holdings_in_session(account_id, s)
            pos = next((h for h in holdings if h["code"] == code), None)
            avail = pos["qty"] if pos else 0
            if avail < qty:
                return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty} 股"}
            acc.cash += exec_price * qty - fee

        trade = PaperTrade(
            account_id=account_id, date=dt.date.today(), code=code, name=name,
            side=side, price=exec_price, qty=qty, fee=fee,
            cash_before=cash_before, cash_after=acc.cash, reason=reason,
        )
        s.add(trade)
        s.commit()
        # Build the result while the session is still open, so reading
        # trade.id / acc.cash after commit does not touch detached objects.
        return {"ok": True, "id": trade.id, "price": exec_price, "fee": fee,
                "cash_after": round(acc.cash, 2)}


# ── Query interface ───────────────────────────────────────


def get_portfolio(account_id: int) -> dict:
    """Return a valuation summary and per-position details for an account.

    Positions are marked at StockMetric.close, falling back to the latest
    DailyQuote close. A falsy price (NULL / 0) counts as "no price", the same
    rule ``_get_price`` uses; with no price at all, the position is marked at
    its average cost (zero unrealized P&L). Holdings are sorted by unrealized
    P&L, highest first.
    """
    px: dict[str, float] = {}
    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        cash = acc.cash
        initial = acc.initial_cash
        holdings_raw = _calc_holdings_in_session(account_id, s)
        codes = [h["code"] for h in holdings_raw]

        if codes:
            for m in s.execute(
                select(StockMetric).where(StockMetric.code.in_(codes))
            ).scalars():
                # Skip NULL closes: storing None here would crash the
                # market-value arithmetic below.
                if m.close:
                    px[m.code] = float(m.close)
            for c in codes:
                if c not in px:
                    close = _latest_quote_close(s, c)
                    if close is not None:
                        px[c] = close

    holdings, mkt_val = [], 0.0
    for h in holdings_raw:
        avg = h["cost"] / h["qty"] if h["qty"] else 0.0
        cur = px.get(h["code"], avg)
        mv = cur * h["qty"]
        unreal = (cur - avg) * h["qty"]
        mkt_val += mv
        holdings.append({
            "code": h["code"],
            "name": h["name"],
            "qty": h["qty"],
            "avg_cost": round(avg, 3),
            "cur": round(cur, 3),
            "market_value": round(mv, 2),
            "unrealized": round(unreal, 2),
            "unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
        })
    holdings.sort(key=lambda x: x["unrealized"], reverse=True)

    total_assets = cash + mkt_val
    total_pnl = total_assets - initial
    return {
        "ok": True,
        "account_id": account_id,
        "summary": {
            "initial_cash": round(initial, 2),
            "cash": round(cash, 2),
            "market_value": round(mkt_val, 2),
            "total_assets": round(total_assets, 2),
            "total_pnl": round(total_pnl, 2),
            "total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
            "positions": len(holdings),
        },
        "holdings": holdings,
    }


def get_trades(account_id: int, limit: int = 100) -> list[dict]:
    """Return up to *limit* of the account's trades, newest (highest id) first."""
    with get_session() as s:
        rows = s.execute(
            select(PaperTrade)
            .where(PaperTrade.account_id == account_id)
            .order_by(PaperTrade.id.desc())
            .limit(limit)
        ).scalars().all()
        return [{
            "id": t.id,
            "date": t.date.isoformat(),
            "code": t.code,
            "name": t.name,
            "side": t.side,
            "price": t.price,
            "qty": t.qty,
            "fee": t.fee,
            "cash_before": round(t.cash_before, 2),
            "cash_after": round(t.cash_after, 2),
            "reason": t.reason,
        } for t in rows]