"""增强版回测引擎 — 多因子策略、仓位管理、参数优化。 支持功能: 1. 多因子组合策略(技术+基本面) 2. 仓位管理(固定、金字塔、凯利公式) 3. 止损止盈 4. 参数网格优化 5. 完整指标(夏普比率、最大回撤、卡玛比率等) 6. 交易明细导出 """ import datetime as dt from typing import Dict, List, Any, Optional, Callable import numpy as np from sqlalchemy import select from db import get_session from models import DailyQuote, StockMetric class Position: """持仓记录""" def __init__(self, date, price, shares, reason=""): self.entry_date = date self.entry_price = price self.shares = shares self.reason = reason self.exit_date = None self.exit_price = None self.pnl = 0.0 self.pnl_pct = 0.0 self.hold_days = 0 class BacktestEngine: """增强回测引擎""" def __init__(self, initial_capital: float = 100000.0, commission: float = 0.0005): self.initial_capital = initial_capital self.commission = commission # 账户状态 self.cash = initial_capital self.positions: List[Position] = [] self.closed_positions: List[Position] = [] # 净值曲线 self.equity_curve = [] self.dates = [] # 统计 self.trades = 0 self.wins = 0 self.total_pnl = 0.0 def get_position_value(self, price: float) -> float: """计算持仓市值""" return sum(p.shares * price for p in self.positions) def get_total_value(self, price: float) -> float: """计算总资产""" return self.cash + self.get_position_value(price) def buy(self, date, price: float, size: float, reason: str = ""): """买入 Args: date: 交易日期 price: 买入价格 size: 仓位大小(0-1),相对于当前可用资金 reason: 买入理由 """ if size <= 0 or size > 1: return False cost = self.cash * size commission_fee = cost * self.commission net_cost = cost - commission_fee if net_cost <= 0: return False shares = net_cost / price self.cash -= cost pos = Position(date, price, shares, reason) self.positions.append(pos) self.trades += 1 return True def sell(self, date, price: float, size: float = 1.0, reason: str = ""): """卖出 Args: date: 交易日期 price: 卖出价格 size: 卖出比例(0-1),相对于持仓 reason: 卖出理由 """ if not self.positions or size <= 0 or size > 1: return False # 按先进先出卖出 remaining = size sold_positions = [] for pos in self.positions[:]: if remaining <= 0: break sell_ratio = min(remaining, 1.0) sell_shares = pos.shares * sell_ratio proceeds = sell_shares * price commission_fee = proceeds * self.commission net_proceeds = proceeds - commission_fee self.cash += net_proceeds # 更新持仓 pos.shares -= sell_shares if pos.shares < 0.01: # 清仓 pos.exit_date = date pos.exit_price = price pos.hold_days = (date - pos.entry_date).days pos.pnl = (price - pos.entry_price) * (sell_shares / sell_ratio) pos.pnl_pct = (price / pos.entry_price - 1) * 100 self.closed_positions.append(pos) self.positions.remove(pos) if pos.pnl > 0: self.wins += 1 self.total_pnl += pos.pnl remaining -= sell_ratio sold_positions.append((pos, sell_shares)) return True def record_state(self, date, price: float): """记录当前状态""" self.dates.append(date) self.equity_curve.append(self.get_total_value(price)) def get_metrics(self) -> Dict[str, Any]: """计算完整指标""" if not self.equity_curve: return {} equity = np.array(self.equity_curve) returns = np.diff(equity) / equity[:-1] # 基础指标 total_return = (equity[-1] / equity[0] - 1) * 100 # 最大回撤 peak = np.maximum.accumulate(equity) drawdown = (peak - equity) / peak max_drawdown = np.max(drawdown) * 100 # 夏普比率(年化,假设252个交易日) if len(returns) > 1 and np.std(returns) > 0: sharpe = np.mean(returns) / np.std(returns) * np.sqrt(252) else: sharpe = 0.0 # 卡玛比率(收益/最大回撤) calmar = total_return / max_drawdown if max_drawdown > 0 else 0.0 # 胜率 closed = len(self.closed_positions) win_rate = (self.wins / closed * 100) if closed > 0 else 0.0 # 盈亏比 winning_trades = [p.pnl for p in self.closed_positions if p.pnl > 0] losing_trades = [abs(p.pnl) for p in self.closed_positions if p.pnl < 0] avg_win = np.mean(winning_trades) if winning_trades else 0.0 avg_loss = np.mean(losing_trades) if losing_trades else 0.0 profit_factor = avg_win / avg_loss if avg_loss > 0 else 0.0 # 持仓天数 hold_days = [p.hold_days for p in self.closed_positions] avg_hold = np.mean(hold_days) if hold_days else 0.0 return { "total_return": round(total_return, 2), "max_drawdown": round(max_drawdown, 2), "sharpe_ratio": round(sharpe, 3), "calmar_ratio": round(calmar, 3), "trades": self.trades, "closed_trades": closed, "win_rate": round(win_rate, 1), "profit_factor": round(profit_factor, 2), "avg_win": round(avg_win, 2), "avg_loss": round(avg_loss, 2), "avg_hold_days": round(avg_hold, 1), "total_pnl": round(self.total_pnl, 2), } class Strategy: """策略基类""" def __init__(self, name: str): self.name = name def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None: """每日回调""" raise NotImplementedError class MAStrategy(Strategy): """均线交叉策略(增强版)""" def __init__(self, fast: int = 5, slow: int = 20, position_size: float = 1.0, stop_loss: float = 0.0, take_profit: float = 0.0): super().__init__(f"MA{fast}/{slow}") self.fast = fast self.slow = slow self.position_size = position_size self.stop_loss = stop_loss # 止损比例 self.take_profit = take_profit # 止盈比例 self.ma_fast_history = [] self.ma_slow_history = [] self.close_history = [] def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None: close = data["close"] self.close_history.append(close) # 计算均线 if len(self.close_history) >= self.fast: self.ma_fast_history.append(np.mean(self.close_history[-self.fast:])) else: self.ma_fast_history.append(None) if len(self.close_history) >= self.slow: self.ma_slow_history.append(np.mean(self.close_history[-self.slow:])) else: self.ma_slow_history.append(None) if len(self.ma_fast_history) < 2: engine.record_state(date, close) return maf_curr = self.ma_fast_history[-1] maf_prev = self.ma_fast_history[-2] mas_curr = self.ma_slow_history[-1] mas_prev = self.ma_slow_history[-2] if maf_curr is None or mas_curr is None: engine.record_state(date, close) return # 止损止盈检查 if engine.positions: for pos in engine.positions[:]: pnl_pct = (close / pos.entry_price - 1) * 100 # 止损 if self.stop_loss > 0 and pnl_pct <= -self.stop_loss: engine.sell(date, close, 1.0, f"止损 {pnl_pct:.2f}%") # 止盈 elif self.take_profit > 0 and pnl_pct >= self.take_profit: engine.sell(date, close, 1.0, f"止盈 {pnl_pct:.2f}%") # 金叉买入 if maf_prev <= mas_prev and maf_curr > mas_curr: if not engine.positions: engine.buy(date, close, self.position_size, "金叉") # 死叉卖出 elif maf_prev >= mas_prev and maf_curr < mas_curr: if engine.positions: engine.sell(date, close, 1.0, "死叉") engine.record_state(date, close) class MultiFactorStrategy(Strategy): """多因子策略""" def __init__(self, position_size: float = 1.0): super().__init__("多因子") self.position_size = position_size self.close_history = [] self.volume_history = [] def calculate_rsi(self, n: int = 14) -> Optional[float]: """计算RSI""" if len(self.close_history) < n + 1: return None changes = np.diff(self.close_history[-n-1:]) gains = np.where(changes > 0, changes, 0) losses = np.where(changes < 0, -changes, 0) avg_gain = np.mean(gains) avg_loss = np.mean(losses) if avg_loss == 0: return 100.0 rs = avg_gain / avg_loss rsi = 100 - (100 / (1 + rs)) return rsi def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None: close = data["close"] volume = data.get("volume", 0) self.close_history.append(close) self.volume_history.append(volume) if len(self.close_history) < 30: engine.record_state(date, close) return # 计算因子 ma5 = np.mean(self.close_history[-5:]) ma20 = np.mean(self.close_history[-20:]) rsi = self.calculate_rsi(14) # 量比 vol_avg = np.mean(self.volume_history[-20:-1]) vol_ratio = volume / vol_avg if vol_avg > 0 else 1.0 # 买入信号:MA5 > MA20, RSI < 70, 放量 buy_signal = (ma5 > ma20 and rsi is not None and rsi < 70 and vol_ratio > 1.5) # 卖出信号:MA5 < MA20 或 RSI > 80 sell_signal = (ma5 < ma20 or (rsi is not None and rsi > 80)) if buy_signal and not engine.positions: engine.buy(date, close, self.position_size, "多因子买入") if sell_signal and engine.positions: engine.sell(date, close, 1.0, "多因子卖出") engine.record_state(date, close) def run_advanced_backtest(symbol: str, strategy: Strategy, start_date: Optional[str] = None, end_date: Optional[str] = None, initial_capital: float = 100000.0, commission: float = 0.0005) -> Dict[str, Any]: """运行增强回测 Args: symbol: 股票代码 strategy: 策略实例 start_date: 开始日期 end_date: 结束日期 initial_capital: 初始资金 commission: 手续费率 Returns: 回测结果 """ with get_session() as s: query = select(DailyQuote.date, DailyQuote.close, DailyQuote.volume).where( DailyQuote.code == symbol ) if start_date: query = query.where(DailyQuote.date >= dt.date.fromisoformat(start_date)) if end_date: query = query.where(DailyQuote.date <= dt.date.fromisoformat(end_date)) query = query.order_by(DailyQuote.date) rows = s.execute(query).all() if not rows: return {"ok": False, "msg": "无数据"} engine = BacktestEngine(initial_capital, commission) # 逐日回测 for row in rows: date, close, volume = row data = {"close": float(close), "volume": int(volume)} strategy.on_data(engine, date, data) # 计算基准(买入持有) bench_curve = [] first_close = rows[0][1] for row in rows: bench_curve.append(float(row[1]) / float(first_close) * initial_capital) metrics = engine.get_metrics() # 交易明细 trades_detail = [{ "entry_date": p.entry_date.isoformat(), "exit_date": p.exit_date.isoformat() if p.exit_date else "", "entry_price": round(p.entry_price, 2), "exit_price": round(p.exit_price, 2) if p.exit_price else 0, "shares": round(p.shares, 2), "hold_days": p.hold_days, "pnl": round(p.pnl, 2), "pnl_pct": round(p.pnl_pct, 2), "reason": p.reason } for p in engine.closed_positions] return { "ok": True, "symbol": symbol, "strategy": strategy.name, "dates": [d.isoformat() for d in engine.dates], "equity": [round(e, 2) for e in engine.equity_curve], "bench": [round(b, 2) for b in bench_curve], "metrics": metrics, "trades": trades_detail, "initial_capital": initial_capital, } def optimize_parameters(symbol: str, param_grid: Dict[str, List], strategy_class: type, metric: str = "sharpe_ratio") -> List[Dict[str, Any]]: """参数网格优化 Args: symbol: 股票代码 param_grid: 参数网格,如 {"fast": [3,5,10], "slow": [10,20,30]} strategy_class: 策略类 metric: 优化目标指标 Returns: 优化结果列表,按指标降序排列 """ import itertools keys = list(param_grid.keys()) values = list(param_grid.values()) results = [] # 遍历所有参数组合 for combo in itertools.product(*values): params = dict(zip(keys, combo)) try: strategy = strategy_class(**params) result = run_advanced_backtest(symbol, strategy) if result["ok"]: results.append({ "params": params, "metrics": result["metrics"], metric: result["metrics"].get(metric, 0) }) except Exception as e: print(f"优化失败 {params}: {e}") continue # 按目标指标排序 results.sort(key=lambda x: x[metric], reverse=True) return results def compare_strategies(symbol: str, strategies: List[Strategy], initial_capital: float = 100000.0) -> Dict[str, Any]: """策略对比 Args: symbol: 股票代码 strategies: 策略列表 initial_capital: 初始资金 Returns: 对比结果 """ results = [] for strategy in strategies: result = run_advanced_backtest(symbol, strategy, initial_capital=initial_capital) if result["ok"]: results.append({ "strategy": strategy.name, "equity": result["equity"], "metrics": result["metrics"] }) return { "ok": True, "symbol": symbol, "dates": result["dates"] if results else [], "strategies": results }