500 lines
16 KiB
Python
500 lines
16 KiB
Python
"""增强版回测引擎 — 多因子策略、仓位管理、参数优化。
|
||
|
||
支持功能:
|
||
1. 多因子组合策略(技术+基本面)
|
||
2. 仓位管理(固定、金字塔、凯利公式)
|
||
3. 止损止盈
|
||
4. 参数网格优化
|
||
5. 完整指标(夏普比率、最大回撤、卡玛比率等)
|
||
6. 交易明细导出
|
||
"""
|
||
import datetime as dt
|
||
from typing import Dict, List, Any, Optional, Callable
|
||
import numpy as np
|
||
from sqlalchemy import select
|
||
|
||
from db import get_session
|
||
from models import DailyQuote, StockMetric
|
||
|
||
|
||
class Position:
    """A single long lot: entry details plus exit fields filled in on close."""

    def __init__(self, date, price, shares, reason=""):
        # Entry side of the lot.
        self.entry_date, self.entry_price = date, price
        self.shares = shares
        self.reason = reason
        # Exit side; populated by BacktestEngine.sell when the lot is closed.
        self.exit_date = self.exit_price = None
        self.pnl = self.pnl_pct = 0.0
        self.hold_days = 0
|
||
|
||
|
||
class BacktestEngine:
    """Backtest engine: cash/lot bookkeeping, trade execution and metrics.

    Valuation methods take the current price explicitly and apply it to every
    open lot, so all open positions are assumed to be in the same symbol.
    """

    # Lots whose remaining share count would fall below this are closed out
    # entirely (the leftover "dust" is sold too, so its cash is not lost).
    _DUST_SHARES = 0.01

    def __init__(self, initial_capital: float = 100000.0, commission: float = 0.0005):
        """Create an engine.

        Args:
            initial_capital: Starting cash.
            commission: Commission rate applied to the traded notional of
                every buy and sell (e.g. 0.0005 = 0.05%).
        """
        self.initial_capital = initial_capital
        self.commission = commission

        # Account state
        self.cash = initial_capital
        self.positions: List[Position] = []         # open lots, oldest first (FIFO)
        self.closed_positions: List[Position] = []  # fully exited lots

        # Equity curve: one entry per record_state() call
        self.equity_curve = []
        self.dates = []

        # Statistics
        self.trades = 0       # number of successful buy() calls
        self.wins = 0         # closed lots whose accumulated pnl is positive
        self.total_pnl = 0.0  # realized pnl from price moves; commissions excluded

    def get_position_value(self, price: float) -> float:
        """Market value of all open lots at ``price``."""
        return sum(p.shares * price for p in self.positions)

    def get_total_value(self, price: float) -> float:
        """Cash plus market value of open lots at ``price``."""
        return self.cash + self.get_position_value(price)

    def buy(self, date, price: float, size: float, reason: str = ""):
        """Open a new lot using a fraction of the currently available cash.

        Args:
            date: Trade date.
            price: Buy price; must be positive.
            size: Fraction (0, 1] of current cash to allocate.
            reason: Entry reason stored on the lot.

        Returns:
            True if a lot was opened; False if the request was rejected
            (size out of range, non-positive price, or nothing to spend).
        """
        if size <= 0 or size > 1:
            return False
        if price <= 0:
            # Would divide by zero or create a negative share count.
            return False

        cost = self.cash * size
        # Commission is deducted from the allocated amount, so the lot buys
        # slightly fewer shares than cost / price.
        net_cost = cost - cost * self.commission
        if net_cost <= 0:
            return False

        self.cash -= cost
        self.positions.append(Position(date, price, net_cost / price, reason))
        self.trades += 1
        return True

    def sell(self, date, price: float, size: float = 1.0, reason: str = ""):
        """Sell a fraction of the total holding, consuming lots first-in-first-out.

        Args:
            date: Trade date; ``(date - lot.entry_date).days`` must be valid.
            price: Sell price.
            size: Fraction (0, 1] of the total shares currently held.
                ``1.0`` liquidates every open lot.
            reason: Exit reason (not stored by the engine; kept for API
                compatibility).

        Returns:
            True if the order was processed; False if nothing is held or
            ``size`` is out of range.
        """
        if not self.positions or size <= 0 or size > 1:
            return False

        to_sell = sum(p.shares for p in self.positions) * size

        for pos in list(self.positions):
            # Small tolerance absorbs float residue left after earlier lots.
            if to_sell <= 1e-9:
                break

            qty = min(pos.shares, to_sell)
            closing = pos.shares - qty < self._DUST_SHARES
            if closing:
                qty = pos.shares  # sweep the dust instead of stranding it

            proceeds = qty * price
            self.cash += proceeds - proceeds * self.commission

            # Realized pnl accrues on every sale, partial ones included.
            realized = (price - pos.entry_price) * qty
            pos.pnl += realized
            self.total_pnl += realized
            to_sell -= qty

            if closing:
                # pos.shares is intentionally left at the quantity held right
                # before this final sale, so trade reports show a real size.
                pos.exit_date = date
                pos.exit_price = price
                pos.hold_days = (date - pos.entry_date).days
                pos.pnl_pct = (price / pos.entry_price - 1) * 100
                self.positions.remove(pos)
                self.closed_positions.append(pos)
                if pos.pnl > 0:
                    self.wins += 1
            else:
                pos.shares -= qty

        return True

    def record_state(self, date, price: float):
        """Append ``date`` and the total account value at ``price`` to the curve."""
        self.dates.append(date)
        self.equity_curve.append(self.get_total_value(price))

    def get_metrics(self) -> Dict[str, Any]:
        """Compute performance statistics from the equity curve and closed lots.

        Returns an empty dict if no state has been recorded. Notes:
          - returns / drawdowns are in percent;
          - sharpe_ratio is mean/std of per-record returns times sqrt(252),
            with no risk-free rate subtracted;
          - profit_factor is average winning pnl / average losing pnl
            (a payoff ratio), not gross profit / gross loss.
        """
        if not self.equity_curve:
            return {}

        equity = np.array(self.equity_curve)
        returns = np.diff(equity) / equity[:-1]

        # Basic return
        total_return = (equity[-1] / equity[0] - 1) * 100

        # Maximum drawdown relative to the running peak
        peak = np.maximum.accumulate(equity)
        drawdown = (peak - equity) / peak
        max_drawdown = np.max(drawdown) * 100

        # Annualized Sharpe ratio (252 trading days)
        if len(returns) > 1 and np.std(returns) > 0:
            sharpe = np.mean(returns) / np.std(returns) * np.sqrt(252)
        else:
            sharpe = 0.0

        # Calmar ratio (total return / max drawdown)
        calmar = total_return / max_drawdown if max_drawdown > 0 else 0.0

        # Win rate over closed lots
        closed = len(self.closed_positions)
        win_rate = (self.wins / closed * 100) if closed > 0 else 0.0

        # Payoff ratio
        winning_trades = [p.pnl for p in self.closed_positions if p.pnl > 0]
        losing_trades = [abs(p.pnl) for p in self.closed_positions if p.pnl < 0]
        avg_win = np.mean(winning_trades) if winning_trades else 0.0
        avg_loss = np.mean(losing_trades) if losing_trades else 0.0
        profit_factor = avg_win / avg_loss if avg_loss > 0 else 0.0

        # Holding period
        hold_days = [p.hold_days for p in self.closed_positions]
        avg_hold = np.mean(hold_days) if hold_days else 0.0

        return {
            "total_return": round(total_return, 2),
            "max_drawdown": round(max_drawdown, 2),
            "sharpe_ratio": round(sharpe, 3),
            "calmar_ratio": round(calmar, 3),
            "trades": self.trades,
            "closed_trades": closed,
            "win_rate": round(win_rate, 1),
            "profit_factor": round(profit_factor, 2),
            "avg_win": round(avg_win, 2),
            "avg_loss": round(avg_loss, 2),
            "avg_hold_days": round(avg_hold, 1),
            "total_pnl": round(self.total_pnl, 2),
        }
|
||
|
||
|
||
class Strategy:
    """Base class for strategies that run_advanced_backtest drives bar by bar.

    Subclasses must override ``on_data``.
    """

    def __init__(self, name: str):
        # Display name reported in backtest results.
        self.name = name

    def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
        """Handle one bar of market data.

        Args:
            engine: Engine to trade through and record state on.
            date: Bar date.
            data: Bar fields, e.g. ``{"close": ..., "volume": ...}``.
        """
        raise NotImplementedError()
|
||
|
||
|
||
class MAStrategy(Strategy):
    """Moving-average crossover strategy with optional stop-loss / take-profit.

    Enters (only when flat) on a golden cross of the fast MA over the slow MA
    and exits on a death cross or when a stop threshold is hit. Per-bar
    history is stored on the instance, so use a fresh instance per backtest.
    """

    def __init__(self, fast: int = 5, slow: int = 20,
                 position_size: float = 1.0,
                 stop_loss: float = 0.0,
                 take_profit: float = 0.0):
        """
        Args:
            fast: Fast MA window in bars.
            slow: Slow MA window in bars.
            position_size: Fraction (0, 1] of available cash used per entry.
            stop_loss: Stop-loss threshold in percent (5 means exit at -5%);
                0 disables it.
            take_profit: Take-profit threshold in percent; 0 disables it.
        """
        super().__init__(f"MA{fast}/{slow}")
        self.fast = fast
        self.slow = slow
        self.position_size = position_size
        self.stop_loss = stop_loss
        self.take_profit = take_profit

        self.ma_fast_history = []
        self.ma_slow_history = []
        self.close_history = []

    def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
        """Process one bar: update MAs, check stops, act on crossovers, record equity."""
        close = data["close"]
        self.close_history.append(close)
        n = len(self.close_history)

        # One MA entry per bar; None until the window is full.
        self.ma_fast_history.append(
            np.mean(self.close_history[-self.fast:]) if n >= self.fast else None)
        self.ma_slow_history.append(
            np.mean(self.close_history[-self.slow:]) if n >= self.slow else None)

        if n < 2:
            engine.record_state(date, close)
            return

        maf_prev, maf_curr = self.ma_fast_history[-2:]
        mas_prev, mas_curr = self.ma_slow_history[-2:]

        # A crossover needs both MAs on both bars. Checking only the current
        # bar let the first bar with a valid slow MA compare against None
        # (mas_prev) and raise TypeError.
        if any(v is None for v in (maf_prev, maf_curr, mas_prev, mas_curr)):
            engine.record_state(date, close)
            return

        # Stop-loss / take-profit checks; pnl is measured in percent.
        for pos in list(engine.positions):
            pnl_pct = (close / pos.entry_price - 1) * 100
            if self.stop_loss > 0 and pnl_pct <= -self.stop_loss:
                engine.sell(date, close, 1.0, f"止损 {pnl_pct:.2f}%")
            elif self.take_profit > 0 and pnl_pct >= self.take_profit:
                engine.sell(date, close, 1.0, f"止盈 {pnl_pct:.2f}%")

        # Golden cross: enter when flat
        if maf_prev <= mas_prev and maf_curr > mas_curr:
            if not engine.positions:
                engine.buy(date, close, self.position_size, "金叉")
        # Death cross: exit when holding
        elif maf_prev >= mas_prev and maf_curr < mas_curr:
            if engine.positions:
                engine.sell(date, close, 1.0, "死叉")

        engine.record_state(date, close)
|
||
|
||
|
||
class MultiFactorStrategy(Strategy):
    """Long-only strategy combining MA trend, RSI and volume ratio.

    Buy (when flat): MA5 > MA20, RSI(14) < 70 and today's volume is more than
    1.5x the average of the previous 19 bars.
    Sell (when holding): MA5 < MA20 or RSI(14) > 80.
    No signals are evaluated until 30 bars of history exist. History is
    stored on the instance, so use a fresh instance per backtest.
    """

    def __init__(self, position_size: float = 1.0):
        """
        Args:
            position_size: Fraction (0, 1] of available cash used per entry.
        """
        super().__init__("多因子")
        self.position_size = position_size
        self.close_history = []
        self.volume_history = []

    def calculate_rsi(self, n: int = 14) -> Optional[float]:
        """RSI over the last ``n`` close-to-close changes, using simple averages.

        Returns None until ``n + 1`` closes are available. Only gains gives
        100; a window with no price movement at all gives the neutral 50.
        """
        if len(self.close_history) < n + 1:
            return None

        changes = np.diff(self.close_history[-n - 1:])
        avg_gain = np.mean(np.where(changes > 0, changes, 0))
        avg_loss = np.mean(np.where(changes < 0, -changes, 0))

        if avg_loss == 0:
            # A flat window is not "maximally overbought"; returning 100 there
            # would trip the RSI > 80 exit on a motionless stock.
            return 100.0 if avg_gain > 0 else 50.0

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
        """Process one bar: update history, evaluate factors, trade, record equity."""
        close = data["close"]
        volume = data.get("volume", 0)

        self.close_history.append(close)
        self.volume_history.append(volume)

        if len(self.close_history) < 30:
            engine.record_state(date, close)
            return

        # Factors
        ma5 = np.mean(self.close_history[-5:])
        ma20 = np.mean(self.close_history[-20:])
        rsi = self.calculate_rsi(14)

        # Volume ratio: today vs. the previous 19 bars (today excluded)
        vol_avg = np.mean(self.volume_history[-20:-1])
        vol_ratio = volume / vol_avg if vol_avg > 0 else 1.0

        buy_signal = (ma5 > ma20 and
                      rsi is not None and rsi < 70 and
                      vol_ratio > 1.5)
        sell_signal = (ma5 < ma20 or
                       (rsi is not None and rsi > 80))

        if buy_signal and not engine.positions:
            engine.buy(date, close, self.position_size, "多因子买入")

        if sell_signal and engine.positions:
            engine.sell(date, close, 1.0, "多因子卖出")

        engine.record_state(date, close)
|
||
|
||
|
||
def run_advanced_backtest(symbol: str,
                          strategy: Strategy,
                          start_date: Optional[str] = None,
                          end_date: Optional[str] = None,
                          initial_capital: float = 100000.0,
                          commission: float = 0.0005) -> Dict[str, Any]:
    """Run a single-symbol backtest of ``strategy`` over daily quotes.

    Args:
        symbol: Stock code, matched against ``DailyQuote.code``.
        strategy: Strategy instance. The bundled strategies keep per-bar
            history on the instance, so pass a fresh instance for each run.
        start_date: Inclusive start date, ISO format (YYYY-MM-DD).
        end_date: Inclusive end date, ISO format.
        initial_capital: Starting cash.
        commission: Commission rate per trade.

    Returns:
        ``{"ok": False, "msg": ...}`` when no usable quotes exist; otherwise
        a dict with the strategy equity curve, a buy-and-hold benchmark curve
        scaled to ``initial_capital``, metrics and closed trades. Lots still
        open at the end are valued in the equity curve but not listed in
        ``trades``.
    """
    with get_session() as s:
        query = (select(DailyQuote.date, DailyQuote.close, DailyQuote.volume)
                 .where(DailyQuote.code == symbol))
        if start_date:
            query = query.where(DailyQuote.date >= dt.date.fromisoformat(start_date))
        if end_date:
            query = query.where(DailyQuote.date <= dt.date.fromisoformat(end_date))
        rows = s.execute(query.order_by(DailyQuote.date)).all()

    # Rows without a close price cannot be valued; drop them up front so the
    # strategy curve and the benchmark stay aligned bar for bar.
    rows = [r for r in rows if r[1] is not None]
    if not rows:
        return {"ok": False, "msg": "无数据"}

    engine = BacktestEngine(initial_capital, commission)
    first_close = float(rows[0][1])
    bench_curve = []

    # Replay bars in date order; the benchmark is buy-and-hold from bar one.
    for date, close, volume in rows:
        close = float(close)
        # A missing volume is treated as 0 instead of crashing int(None).
        data = {"close": close, "volume": int(volume) if volume is not None else 0}
        strategy.on_data(engine, date, data)
        bench_curve.append(close / first_close * initial_capital)

    metrics = engine.get_metrics()

    trades_detail = [{
        "entry_date": p.entry_date.isoformat(),
        "exit_date": p.exit_date.isoformat() if p.exit_date else "",
        "entry_price": round(p.entry_price, 2),
        "exit_price": round(p.exit_price, 2) if p.exit_price else 0,
        "shares": round(p.shares, 2),
        "hold_days": p.hold_days,
        "pnl": round(p.pnl, 2),
        "pnl_pct": round(p.pnl_pct, 2),
        "reason": p.reason
    } for p in engine.closed_positions]

    return {
        "ok": True,
        "symbol": symbol,
        "strategy": strategy.name,
        "dates": [d.isoformat() for d in engine.dates],
        "equity": [round(e, 2) for e in engine.equity_curve],
        "bench": [round(b, 2) for b in bench_curve],
        "metrics": metrics,
        "trades": trades_detail,
        "initial_capital": initial_capital,
    }
|
||
|
||
|
||
def optimize_parameters(symbol: str,
                        param_grid: Dict[str, List],
                        strategy_class: type,
                        metric: str = "sharpe_ratio",
                        *,
                        start_date: Optional[str] = None,
                        end_date: Optional[str] = None,
                        initial_capital: float = 100000.0,
                        commission: float = 0.0005,
                        higher_is_better: bool = True) -> List[Dict[str, Any]]:
    """Grid-search strategy parameters.

    Args:
        symbol: Stock code.
        param_grid: Mapping of constructor kwarg -> candidate values, e.g.
            ``{"fast": [3, 5, 10], "slow": [10, 20, 30]}``.
        strategy_class: Strategy class, instantiated once per combination.
        metric: Key of the metrics dict to rank by.
        start_date: Optional inclusive ISO start date for every run.
        end_date: Optional inclusive ISO end date for every run.
        initial_capital: Starting cash for every run.
        commission: Commission rate for every run.
        higher_is_better: Sort descending when True (e.g. sharpe_ratio);
            pass False for metrics where lower is better (e.g. max_drawdown).

    Returns:
        One dict per successful combination with ``params``, ``metrics`` and
        the ranked ``metric`` value, best first. Each combination re-runs the
        backtest (and therefore re-queries quotes).
    """
    import itertools

    keys = list(param_grid)
    results = []

    for combo in itertools.product(*param_grid.values()):
        params = dict(zip(keys, combo))

        try:
            strategy = strategy_class(**params)
            result = run_advanced_backtest(symbol, strategy,
                                           start_date=start_date,
                                           end_date=end_date,
                                           initial_capital=initial_capital,
                                           commission=commission)
        except Exception as e:
            # Best effort: one invalid combination must not abort the grid.
            print(f"优化失败 {params}: {e}")
            continue

        if result["ok"]:
            results.append({
                "params": params,
                "metrics": result["metrics"],
                metric: result["metrics"].get(metric, 0)
            })

    results.sort(key=lambda x: x[metric], reverse=higher_is_better)
    return results
|
||
|
||
|
||
def compare_strategies(symbol: str,
                       strategies: List[Strategy],
                       initial_capital: float = 100000.0,
                       *,
                       start_date: Optional[str] = None,
                       end_date: Optional[str] = None,
                       commission: float = 0.0005) -> Dict[str, Any]:
    """Backtest several strategies on one symbol and collect their curves.

    Args:
        symbol: Stock code.
        strategies: Strategy instances; each is run once.
        initial_capital: Starting cash for every run.
        start_date: Optional inclusive ISO start date for every run.
        end_date: Optional inclusive ISO end date for every run.
        commission: Commission rate for every run.

    Returns:
        ``{"ok": True, "symbol", "dates", "strategies"}``. ``dates`` comes
        from the first successful run (empty if none succeeded); runs that
        return ``ok=False`` are omitted from ``strategies``.
    """
    results = []
    dates: List[str] = []

    for strategy in strategies:
        result = run_advanced_backtest(symbol, strategy,
                                       start_date=start_date,
                                       end_date=end_date,
                                       initial_capital=initial_capital,
                                       commission=commission)
        if not result["ok"]:
            continue
        if not results:
            # Take the date axis from a successful run; the last iteration's
            # result may be a failure without a "dates" key.
            dates = result["dates"]
        results.append({
            "strategy": strategy.name,
            "equity": result["equity"],
            "metrics": result["metrics"]
        })

    return {
        "ok": True,
        "symbol": symbol,
        "dates": dates,
        "strategies": results
    }
|