Files
stock_cursor_v0/backend/position_cost.py
2026-06-15 01:26:39 +08:00

348 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""持仓成本可视化增强"""
import datetime as dt
from typing import Dict, List, Optional
from collections import defaultdict
from sqlalchemy import select, func
from db import get_session
from models import Trade, DailyQuote, StockMetric
# Trading-cost parameters for A-share orders (all rates are fractions of turnover).
COST_CONFIG = dict(
    stamp_tax=0.001,  # stamp duty 0.1%, charged on sells only
    commission_rate=0.0003,  # broker commission 0.03%
    commission_min=5.0,  # minimum commission per trade: 5 yuan
    transfer_fee=0.00001,  # transfer fee 0.001% (Shanghai market)
)
def calculate_trade_cost(price: float, qty: int, side: str, is_sh: bool = True) -> Dict:
    """Compute the itemised transaction cost of a single trade.

    Args:
        price: Execution price per share.
        qty: Number of shares traded.
        side: ``"sell"`` adds stamp duty; any other value is charged no
            stamp duty.
        is_sh: Whether the trade is on the Shanghai market. The transfer
            fee is only charged when this is true.

    Returns:
        A dict with ``amount``, ``commission``, ``stamp_tax``,
        ``transfer_fee`` and ``total_cost`` (each rounded to 2 decimals)
        plus ``cost_rate``: total cost as a percentage of the turnover,
        rounded to 4 decimals, or ``0.0`` when the turnover is not positive.
    """
    rates = COST_CONFIG
    turnover = price * qty

    # Commission is charged on both sides and floored at the configured minimum.
    proportional = turnover * rates["commission_rate"]
    commission = max(proportional, rates["commission_min"])

    # Stamp duty is levied on the selling side only.
    stamp_tax = 0.0
    if side == "sell":
        stamp_tax = turnover * rates["stamp_tax"]

    # Transfer fee applies to Shanghai trades on either side.
    transfer_fee = 0.0
    if is_sh:
        transfer_fee = turnover * rates["transfer_fee"]

    fees = commission + stamp_tax + transfer_fee

    cost_rate = 0.0
    if turnover > 0:
        cost_rate = round(fees / turnover * 100, 4)

    breakdown = {
        "amount": round(turnover, 2),
        "commission": round(commission, 2),
        "stamp_tax": round(stamp_tax, 2),
        "transfer_fee": round(transfer_fee, 2),
        "total_cost": round(fees, 2),
        "cost_rate": cost_rate,
    }
    return breakdown
def get_position_cost_lines(code: str) -> Dict:
    """Build the cost-basis timeline of one stock for K-line chart annotation.

    All trades of ``code`` are replayed in ``(date, id)`` order with a
    weighted-average cost basis: a buy adds ``price * qty + fee`` to the
    basis, a sell removes ``avg_cost * sold_qty``. Sells that arrive while
    nothing is held are skipped, and a sell larger than the holding is
    capped at the held quantity.

    Args:
        code: Stock code whose trades are replayed.

    Returns:
        ``{"ok": False, "msg": ...}`` when the stock has no trades.
        Otherwise::

            {
                "ok": True,
                "code": "600519",
                "name": "贵州茅台",
                "current_position": {        # None when nothing is held
                    "qty": 100,
                    "avg_cost": 1680.5,
                    "total_cost": 168050.0,
                    "current_price": 1700.0,
                    "market_value": 170000.0,
                    "unrealized_pnl": 1950.0,
                    "unrealized_pct": 1.16,
                    "trades_count": 3        # number of buy trades
                },
                "cost_history": [
                    {"date": "2024-01-15", "cost": 1650.0, "qty": 100,
                     "action": "买入", "trade_price": 1650.0, "trade_qty": 100},
                    ...                      # sell entries also carry "pnl"
                ]
            }
    """
    with get_session() as s:
        trades = s.execute(
            select(Trade).where(Trade.code == code)
            .order_by(Trade.date, Trade.id)
        ).scalars().all()
        if not trades:
            return {"ok": False, "msg": "该股票无交易记录"}

        # Replay the trades to track how the average cost evolves.
        qty = 0
        cost = 0.0
        cost_history = []
        for t in trades:
            if t.side == "buy":
                # Label depends on whether a position existed before this buy.
                action = "补仓" if qty > 0 else "买入"
                qty += t.qty
                cost += t.price * t.qty + t.fee
                avg_cost = cost / qty if qty > 0 else 0
                cost_history.append({
                    "date": t.date.isoformat(),
                    "cost": round(avg_cost, 2),
                    "qty": qty,
                    "action": action,
                    "trade_price": t.price,
                    "trade_qty": t.qty,
                })
            else:  # sell
                if qty <= 0:
                    # Nothing held: the sell cannot be matched, skip it.
                    continue
                avg_cost = cost / qty
                sell_qty = min(t.qty, qty)
                qty -= sell_qty
                if qty == 0:
                    # Reset exactly to zero so floating-point residue does not
                    # leak into the cost basis of a later re-entry.
                    cost = 0.0
                else:
                    cost -= avg_cost * sell_qty
                cost_history.append({
                    "date": t.date.isoformat(),
                    "cost": round(cost / qty, 2) if qty > 0 else 0,
                    "qty": qty,
                    "action": "清仓" if qty == 0 else "减仓",
                    "trade_price": t.price,
                    "trade_qty": sell_qty,
                    "pnl": round((t.price - avg_cost) * sell_qty - t.fee, 2),
                })

        # Snapshot of the position that is still open, if any.
        current_position = None
        if qty > 0:
            avg_cost = cost / qty
            # NOTE(review): scalar_one_or_none() raises if StockMetric holds
            # more than one row for this code - confirm the table is keyed by code.
            metric = s.execute(
                select(StockMetric).where(StockMetric.code == code)
            ).scalar_one_or_none()
            # Fall back to the average cost when no usable quote exists
            # (missing row or missing close), which yields zero unrealized P&L.
            has_quote = metric is not None and metric.close is not None
            current_price = metric.close if has_quote else avg_cost
            current_position = {
                "qty": qty,
                "avg_cost": round(avg_cost, 2),
                "total_cost": round(cost, 2),
                "current_price": round(current_price, 2),
                "market_value": round(current_price * qty, 2),
                "unrealized_pnl": round((current_price - avg_cost) * qty, 2),
                "unrealized_pct": round((current_price / avg_cost - 1) * 100, 2) if avg_cost > 0 else 0,
                "trades_count": sum(1 for t in trades if t.side == "buy"),
            }

        return {
            "ok": True,
            "code": code,
            "name": trades[0].name,
            "current_position": current_position,
            "cost_history": cost_history,
        }
def get_position_cost_distribution() -> Dict:
    """Classify every open position by unrealized profit or loss.

    All trades are replayed in ``(date, id)`` order with a weighted-average
    cost basis per stock code. Each position that is still open is priced
    with the close from ``StockMetric`` (falling back to its own average
    cost when no usable close exists) and placed into one of three buckets
    by unrealized percentage: above 0.5% is profitable, below -0.5% is
    unprofitable, anything in between is breakeven.

    Returns:
        {
            "ok": True,
            "profitable": [...],     # sorted by unrealized P&L, largest first
            "unprofitable": [...],   # sorted by unrealized P&L, largest loss first
            "breakeven": [...],
            "summary": {
                "total_positions", "profitable_count", "unprofitable_count",
                "breakeven_count", "win_rate"   # win_rate in percent
            }
        }
        The same shape (empty lists, zeroed summary) is returned when no
        position is open.
    """
    with get_session() as s:
        trades = s.execute(
            select(Trade).order_by(Trade.date, Trade.id)
        ).scalars().all()

        # Replay trades to obtain the current holding of every code.
        pos = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
        for t in trades:
            p = pos[t.code]
            p["name"] = t.name or p["name"]
            if t.side == "buy":
                p["cost"] += t.price * t.qty + t.fee
                p["qty"] += t.qty
            elif p["qty"] > 0:
                avg = p["cost"] / p["qty"]
                sold = min(t.qty, p["qty"])  # cap oversells at the holding
                p["qty"] -= sold
                if p["qty"] == 0:
                    # Reset exactly to zero so floating-point residue does not
                    # leak into the cost basis of a later re-entry.
                    p["cost"] = 0.0
                else:
                    p["cost"] -= avg * sold

        codes = [c for c, v in pos.items() if v["qty"] > 0]
        if not codes:
            # Keep the response shape identical to the populated case.
            return {
                "ok": True,
                "profitable": [],
                "unprofitable": [],
                "breakeven": [],
                "summary": {
                    "total_positions": 0,
                    "profitable_count": 0,
                    "unprofitable_count": 0,
                    "breakeven_count": 0,
                    "win_rate": 0,
                },
            }

        # Latest prices; rows without a close are ignored so the position
        # falls back to its own average cost below.
        metrics = s.execute(
            select(StockMetric).where(StockMetric.code.in_(codes))
        ).scalars().all()
        price_map = {m.code: m.close for m in metrics if m.close is not None}

        profitable = []
        unprofitable = []
        breakeven = []
        for code in codes:
            p = pos[code]
            avg_cost = p["cost"] / p["qty"]
            current_price = price_map.get(code, avg_cost)
            unrealized = (current_price - avg_cost) * p["qty"]
            unrealized_pct = (current_price / avg_cost - 1) * 100 if avg_cost > 0 else 0
            item = {
                "code": code,
                "name": p["name"],
                "qty": p["qty"],
                "avg_cost": round(avg_cost, 2),
                "current_price": round(current_price, 2),
                "market_value": round(current_price * p["qty"], 2),
                "cost_value": round(p["cost"], 2),
                "unrealized": round(unrealized, 2),
                "unrealized_pct": round(unrealized_pct, 2),
            }
            # A +/-0.5% band around zero counts as breakeven.
            if unrealized_pct > 0.5:
                profitable.append(item)
            elif unrealized_pct < -0.5:
                unprofitable.append(item)
            else:
                breakeven.append(item)

        # Biggest winners first, biggest losers first.
        profitable.sort(key=lambda x: x["unrealized"], reverse=True)
        unprofitable.sort(key=lambda x: x["unrealized"])

        return {
            "ok": True,
            "profitable": profitable,
            "unprofitable": unprofitable,
            "breakeven": breakeven,
            "summary": {
                "total_positions": len(codes),
                "profitable_count": len(profitable),
                "unprofitable_count": len(unprofitable),
                "breakeven_count": len(breakeven),
                "win_rate": round(len(profitable) / len(codes) * 100, 1),
            },
        }
def estimate_trade_cost(code: str, price: float, qty: int, side: str) -> Dict:
    """Estimate the cost of an order before it is placed.

    Args:
        code: Stock code; a code starting with ``"6"`` is treated as a
            Shanghai listing when computing the transfer fee.
        price: Expected execution price.
        qty: Number of shares to trade.
        side: ``"buy"`` adds the costs to the turnover; any other value is
            handled as a sell and the costs are deducted from the turnover.

    Returns:
        A dict echoing the inputs together with ``cost_detail`` (the result
        of ``calculate_trade_cost``), ``net_amount`` (cash paid for a buy or
        received for a sell, rounded to 2 decimals) and a human-readable
        ``message``.
    """
    detail = calculate_trade_cost(price, qty, side, code.startswith("6"))
    fees = detail["total_cost"]
    turnover = detail["amount"]

    if side != "buy":
        net = round(turnover - fees, 2)
        text = f"卖出可获得: {net} 元(扣除交易成本 {fees} 元)"
    else:
        net = round(turnover + fees, 2)
        text = f"买入需支付: {net} 元(含交易成本 {fees} 元)"

    result = {"ok": True, "code": code, "price": price, "qty": qty, "side": side}
    result["cost_detail"] = detail
    result["net_amount"] = net
    result["message"] = text
    return result
def get_cost_breakdown_for_position(code: str) -> Dict:
    """Break the purchase cost of one stock down into its fee components.

    Every buy trade of ``code`` is re-priced with ``calculate_trade_cost``
    (the fee recorded on the trade itself is not used) and the per-trade
    components are summed. Sell trades are not considered.

    Args:
        code: Stock code; a code starting with ``"6"`` is treated as a
            Shanghai listing when computing the transfer fee.

    Returns:
        ``{"ok": False, "msg": ...}`` when the stock has no buy trades.
        Otherwise::

            {
                "ok": True,
                "code": "600519",
                "name": "贵州茅台",
                "total_cost": 168500.0,       # purchase amount plus all fees
                "purchase_amount": 168050.0,  # sum of price * qty
                "commission": 350.0,          # accumulated commission
                "stamp_tax": 0.0,             # accumulated stamp duty (none on buys)
                "transfer_fee": 100.0,        # accumulated transfer fee
                "cost_rate": 0.2678,          # fees as a percentage of purchase amount
                "trades": [...]               # per-trade detail
            }
    """
    with get_session() as s:
        # Order by (date, id), like the other trade queries in this module,
        # so same-day trades and the reported name come out deterministically.
        trades = s.execute(
            select(Trade).where(Trade.code == code, Trade.side == "buy")
            .order_by(Trade.date, Trade.id)
        ).scalars().all()
        if not trades:
            return {"ok": False, "msg": "该股票无买入记录"}

        is_sh = code.startswith("6")
        total_purchase = 0.0
        total_commission = 0.0
        total_stamp = 0.0
        total_transfer = 0.0
        trade_details = []
        for t in trades:
            cost = calculate_trade_cost(t.price, t.qty, "buy", is_sh)
            total_purchase += cost["amount"]
            total_commission += cost["commission"]
            total_stamp += cost["stamp_tax"]
            total_transfer += cost["transfer_fee"]
            trade_details.append({
                "date": t.date.isoformat(),
                "price": t.price,
                "qty": t.qty,
                "amount": cost["amount"],
                "cost_detail": cost,
            })

        total_fees = total_commission + total_stamp + total_transfer
        total_cost = total_purchase + total_commission + total_stamp + total_transfer
        return {
            "ok": True,
            "code": code,
            "name": trades[0].name,
            "total_cost": round(total_cost, 2),
            "purchase_amount": round(total_purchase, 2),
            "commission": round(total_commission, 2),
            "stamp_tax": round(total_stamp, 2),
            "transfer_fee": round(total_transfer, 2),
            # Same quantity as (total_cost - total_purchase) but without the
            # cancellation error of subtracting two large sums.
            "cost_rate": round(total_fees / total_purchase * 100, 4) if total_purchase > 0 else 0,
            "trades": trade_details,
        }