222 lines
8.3 KiB
Python
222 lines
8.3 KiB
Python
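"""Paper-trading (simulated account) core logic.

- Multiple accounts are supported (the default account has id=1).
- Buy/sell orders are filled at the latest price (or closing price), and cash
  is debited/credited automatically.
- Positions are tracked with the moving weighted-average cost method.
"""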
|
||
from __future__ import annotations
|
||
|
||
import datetime as dt
|
||
from collections import defaultdict
|
||
|
||
from sqlalchemy import select
|
||
|
||
from db import get_session
|
||
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote
|
||
|
||
# Fee rate applied to each trade's notional (price * qty) in place_order,
# on both buys and sells; 0.0003 == 0.03%.
DEFAULT_FEE_RATE = 0.0003
|
||
|
||
|
||
def _get_price(code: str) -> float | None:
    """Look up the latest known price for *code*.

    ``StockMetric.close`` is preferred. When it is missing or falsy, the
    ``close`` of the most recent ``DailyQuote`` row (by date) is used instead.
    Returns ``None`` when neither source yields a truthy value.
    """
    with get_session() as session:
        metric_close = session.execute(
            select(StockMetric.close).where(StockMetric.code == code)
        ).scalar_one_or_none()
        if metric_close:
            return float(metric_close)

        latest_quote_stmt = (
            select(DailyQuote.close)
            .where(DailyQuote.code == code)
            .order_by(DailyQuote.date.desc())
            .limit(1)
        )
        quote_close = session.execute(latest_quote_stmt).scalar_one_or_none()
        if not quote_close:
            return None
        return float(quote_close)
|
||
|
||
|
||
# ── 账户管理 ──────────────────────────────────────────────
|
||
|
||
def ensure_default_account():
    """Create the default paper account (id=1) if it does not exist yet.

    Meant to be called at startup. The account is named "默认模拟盘" and
    starts with 1,000,000 in both ``initial_cash`` and ``cash``.
    """
    with get_session() as session:
        if session.get(PaperAccount, 1):
            return
        # NOTE(review): no explicit id is passed, so this relies on the
        # database assigning id=1 (e.g. on an empty table). If account 1 were
        # ever removed while others exist, the new row would likely get a
        # different id — confirm this cannot happen.
        starting_cash = 1_000_000.0
        session.add(
            PaperAccount(name="默认模拟盘", initial_cash=starting_cash, cash=starting_cash)
        )
        session.commit()
|
||
|
||
|
||
def list_accounts() -> list[dict]:
    """Return every paper account, ordered by id, as plain dicts.

    ``cash`` is rounded to 2 decimals and ``created_at`` is formatted as
    ``YYYY-MM-DD``.
    """
    def as_dict(account) -> dict:
        return {
            "id": account.id,
            "name": account.name,
            "initial_cash": account.initial_cash,
            "cash": round(account.cash, 2),
            "is_active": account.is_active,
            "created_at": account.created_at.strftime("%Y-%m-%d"),
        }

    with get_session() as session:
        accounts = session.execute(
            select(PaperAccount).order_by(PaperAccount.id)
        ).scalars().all()
        # Serialize while the session is still open.
        return [as_dict(account) for account in accounts]
|
||
|
||
|
||
def create_account(name: str, initial_cash: float) -> dict:
    """Create a new paper account whose cash starts at ``initial_cash``.

    Returns ``{"ok": True, "id": <new account id>}`` after committing.
    """
    with get_session() as session:
        new_account = PaperAccount(
            name=name,
            initial_cash=initial_cash,
            cash=initial_cash,
        )
        session.add(new_account)
        session.commit()
        new_id = new_account.id
        return {"ok": True, "id": new_id}
|
||
|
||
|
||
def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
    """Reset an account: restore its cash and delete all of its trades.

    If ``initial_cash`` is given, it first replaces the account's stored
    initial cash. ``cash`` is then set back to the (possibly updated)
    initial cash.

    Returns:
        ``{"ok": False, "msg": ...}`` if the account does not exist, otherwise
        ``{"ok": True, "msg": ...}`` after committing.
    """
    with get_session() as session:
        account = session.get(PaperAccount, account_id)
        if not account:
            return {"ok": False, "msg": "账户不存在"}

        if initial_cash is not None:
            account.initial_cash = initial_cash
        account.cash = account.initial_cash

        trades_stmt = select(PaperTrade).where(PaperTrade.account_id == account_id)
        for trade in session.execute(trades_stmt).scalars():
            session.delete(trade)
        session.commit()
    return {"ok": True, "msg": "账户已重置"}
|
||
|
||
|
||
# ── 持仓计算(内部)────────────────────────────────────────
|
||
|
||
def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
    """Rebuild an account's open positions by replaying its trade history.

    Trades are replayed in ``(date, id)`` order using the moving
    weighted-average cost method:

    - A buy adds ``price * qty + fee`` to the position's cost and ``qty`` to
      its quantity.
    - A sell (any side other than ``"buy"``) removes ``avg_cost * sold_qty``
      from the cost. The sell fee does not affect the cost basis. The sold
      quantity is clipped to the current position, and sells with no open
      position are ignored.

    Args:
        account_id: account whose trades are replayed.
        s: the caller's open session; the trade query runs through it.

    Returns:
        One dict per position with ``qty > 0``, containing ``code``, ``name``
        (the last non-empty trade name), ``qty`` and total ``cost`` (buy fees
        included).
    """
    trades = s.execute(
        select(PaperTrade)
        .where(PaperTrade.account_id == account_id)
        .order_by(PaperTrade.date, PaperTrade.id)
    ).scalars().all()

    pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
    for t in trades:
        p = pos[t.code]
        p["name"] = t.name or p["name"]
        if t.side == "buy":
            p["cost"] += t.price * t.qty + t.fee
            p["qty"] += t.qty
        elif p["qty"] > 0:
            avg = p["cost"] / p["qty"]
            qty = min(t.qty, p["qty"])
            p["cost"] -= avg * qty
            p["qty"] -= qty
            if p["qty"] == 0:
                # Position fully closed. (cost / qty) * qty need not equal cost
                # exactly in floating point, so drop any residue rather than
                # carry it into the cost basis of a later re-entry.
                p["cost"] = 0.0

    return [{"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]}
            for c, v in pos.items() if v["qty"] > 0]
|
||
|
||
|
||
# ── 下单 ──────────────────────────────────────────────────
|
||
|
||
def place_order(account_id: int, code: str, side: str, qty: int,
|
||
price: float | None = None, reason: str = "") -> dict:
|
||
if qty <= 0:
|
||
return {"ok": False, "msg": "数量必须大于 0"}
|
||
if side not in ("buy", "sell"):
|
||
return {"ok": False, "msg": "side 只能是 buy 或 sell"}
|
||
|
||
exec_price = price or _get_price(code)
|
||
if not exec_price:
|
||
return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
|
||
|
||
fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)
|
||
|
||
with get_session() as s:
|
||
acc = s.get(PaperAccount, account_id)
|
||
if not acc:
|
||
return {"ok": False, "msg": "账户不存在"}
|
||
|
||
sec = s.get(Security, code)
|
||
name = sec.name if sec else code
|
||
|
||
if side == "buy":
|
||
cost = exec_price * qty + fee
|
||
if acc.cash < cost:
|
||
return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
|
||
cash_before = acc.cash
|
||
acc.cash -= cost
|
||
else:
|
||
holdings = _calc_holdings_in_session(account_id, s)
|
||
pos = next((h for h in holdings if h["code"] == code), None)
|
||
avail = pos["qty"] if pos else 0
|
||
if avail < qty:
|
||
return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty} 股"}
|
||
cash_before = acc.cash
|
||
acc.cash += exec_price * qty - fee
|
||
|
||
trade = PaperTrade(
|
||
account_id=account_id,
|
||
date=dt.date.today(),
|
||
code=code, name=name, side=side,
|
||
price=exec_price, qty=qty, fee=fee,
|
||
cash_before=cash_before, cash_after=acc.cash,
|
||
reason=reason,
|
||
)
|
||
s.add(trade)
|
||
s.commit()
|
||
return {"ok": True, "id": trade.id, "price": exec_price,
|
||
"fee": fee, "cash_after": round(acc.cash, 2)}
|
||
|
||
|
||
# ── 查询接口 ──────────────────────────────────────────────
|
||
|
||
def _latest_prices_in_session(codes: list[str], s) -> dict[str, float]:
    """Map each code in *codes* to its latest price, using the same rules as ``_get_price``.

    ``StockMetric.close`` is used when truthy. Otherwise the ``close`` of the
    most recent ``DailyQuote`` row is used. Codes with no truthy price are
    omitted from the result.
    """
    px: dict[str, float] = {}
    if not codes:
        return px
    for m in s.execute(
        select(StockMetric).where(StockMetric.code.in_(codes))
    ).scalars():
        # Like _get_price: convert to float, and let NULL/zero closes fall
        # through to DailyQuote. Storing a raw None here would make the
        # valuation arithmetic in get_portfolio raise TypeError.
        if m.close:
            px[m.code] = float(m.close)
    for c in codes:
        if c in px:
            continue
        row = s.execute(
            select(DailyQuote.close).where(DailyQuote.code == c)
            .order_by(DailyQuote.date.desc()).limit(1)
        ).scalar_one_or_none()
        if row:
            px[c] = float(row)
    return px


def get_portfolio(account_id: int) -> dict:
    """Summarize an account: cash, open positions, market value and P&L.

    Each position is valued at its latest price (see
    ``_latest_prices_in_session``). When no price is available, the position
    is valued at its own average cost, i.e. zero unrealized P&L. Holdings are
    sorted by unrealized P&L, largest first.

    Returns:
        ``{"ok": False, "msg": ...}`` if the account does not exist. Otherwise
        ``{"ok": True, "account_id": ..., "summary": {...}, "holdings": [...]}``.
    """
    with get_session() as s:
        acc = s.get(PaperAccount, account_id)
        if not acc:
            return {"ok": False, "msg": "账户不存在"}
        cash = acc.cash
        initial = acc.initial_cash
        holdings_raw = _calc_holdings_in_session(account_id, s)
        # Reuse the same session for price lookups instead of opening another.
        px = _latest_prices_in_session([h["code"] for h in holdings_raw], s)

    holdings, mkt_val = [], 0.0
    for h in holdings_raw:
        qty = h["qty"]
        avg = h["cost"] / qty if qty else 0.0
        cur = px.get(h["code"], avg)
        mv = cur * qty
        unreal = (cur - avg) * qty
        mkt_val += mv
        holdings.append({
            "code": h["code"], "name": h["name"], "qty": qty,
            "avg_cost": round(avg, 3), "cur": round(cur, 3),
            "market_value": round(mv, 2),
            "unrealized": round(unreal, 2),
            "unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
        })
    holdings.sort(key=lambda x: x["unrealized"], reverse=True)

    total_assets = cash + mkt_val
    total_pnl = total_assets - initial
    return {
        "ok": True,
        "account_id": account_id,
        "summary": {
            "initial_cash": round(initial, 2),
            "cash": round(cash, 2),
            "market_value": round(mkt_val, 2),
            "total_assets": round(total_assets, 2),
            "total_pnl": round(total_pnl, 2),
            "total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
            "positions": len(holdings),
        },
        "holdings": holdings,
    }
|
||
|
||
|
||
def get_trades(account_id: int, limit: int = 100) -> list[dict]:
    """Return up to ``limit`` trades of an account, newest (highest id) first.

    ``date`` is ISO-formatted, and ``cash_before``/``cash_after`` are rounded
    to 2 decimals. The other fields are passed through unchanged.
    """
    recent_trades_stmt = (
        select(PaperTrade)
        .where(PaperTrade.account_id == account_id)
        .order_by(PaperTrade.id.desc())
        .limit(limit)
    )
    with get_session() as session:
        trades = session.execute(recent_trades_stmt).scalars().all()
        # Serialize while the session is still open.
        return [
            {
                "id": trade.id,
                "date": trade.date.isoformat(),
                "code": trade.code,
                "name": trade.name,
                "side": trade.side,
                "price": trade.price,
                "qty": trade.qty,
                "fee": trade.fee,
                "cash_before": round(trade.cash_before, 2),
                "cash_after": round(trade.cash_after, 2),
                "reason": trade.reason,
            }
            for trade in trades
        ]