claude强化功能
This commit is contained in:
499
backend/backtest_advanced.py
Normal file
499
backend/backtest_advanced.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""增强版回测引擎 — 多因子策略、仓位管理、参数优化。
|
||||
|
||||
支持功能:
|
||||
1. 多因子组合策略(技术+基本面)
|
||||
2. 仓位管理(固定、金字塔、凯利公式)
|
||||
3. 止损止盈
|
||||
4. 参数网格优化
|
||||
5. 完整指标(夏普比率、最大回撤、卡玛比率等)
|
||||
6. 交易明细导出
|
||||
"""
|
||||
import datetime as dt
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
|
||||
from db import get_session
|
||||
from models import DailyQuote, StockMetric
|
||||
|
||||
|
||||
class Position:
|
||||
"""持仓记录"""
|
||||
def __init__(self, date, price, shares, reason=""):
|
||||
self.entry_date = date
|
||||
self.entry_price = price
|
||||
self.shares = shares
|
||||
self.reason = reason
|
||||
self.exit_date = None
|
||||
self.exit_price = None
|
||||
self.pnl = 0.0
|
||||
self.pnl_pct = 0.0
|
||||
self.hold_days = 0
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
"""增强回测引擎"""
|
||||
|
||||
def __init__(self, initial_capital: float = 100000.0, commission: float = 0.0005):
|
||||
self.initial_capital = initial_capital
|
||||
self.commission = commission
|
||||
|
||||
# 账户状态
|
||||
self.cash = initial_capital
|
||||
self.positions: List[Position] = []
|
||||
self.closed_positions: List[Position] = []
|
||||
|
||||
# 净值曲线
|
||||
self.equity_curve = []
|
||||
self.dates = []
|
||||
|
||||
# 统计
|
||||
self.trades = 0
|
||||
self.wins = 0
|
||||
self.total_pnl = 0.0
|
||||
|
||||
def get_position_value(self, price: float) -> float:
|
||||
"""计算持仓市值"""
|
||||
return sum(p.shares * price for p in self.positions)
|
||||
|
||||
def get_total_value(self, price: float) -> float:
|
||||
"""计算总资产"""
|
||||
return self.cash + self.get_position_value(price)
|
||||
|
||||
def buy(self, date, price: float, size: float, reason: str = ""):
|
||||
"""买入
|
||||
|
||||
Args:
|
||||
date: 交易日期
|
||||
price: 买入价格
|
||||
size: 仓位大小(0-1),相对于当前可用资金
|
||||
reason: 买入理由
|
||||
"""
|
||||
if size <= 0 or size > 1:
|
||||
return False
|
||||
|
||||
cost = self.cash * size
|
||||
commission_fee = cost * self.commission
|
||||
net_cost = cost - commission_fee
|
||||
|
||||
if net_cost <= 0:
|
||||
return False
|
||||
|
||||
shares = net_cost / price
|
||||
self.cash -= cost
|
||||
|
||||
pos = Position(date, price, shares, reason)
|
||||
self.positions.append(pos)
|
||||
self.trades += 1
|
||||
return True
|
||||
|
||||
def sell(self, date, price: float, size: float = 1.0, reason: str = ""):
|
||||
"""卖出
|
||||
|
||||
Args:
|
||||
date: 交易日期
|
||||
price: 卖出价格
|
||||
size: 卖出比例(0-1),相对于持仓
|
||||
reason: 卖出理由
|
||||
"""
|
||||
if not self.positions or size <= 0 or size > 1:
|
||||
return False
|
||||
|
||||
# 按先进先出卖出
|
||||
remaining = size
|
||||
sold_positions = []
|
||||
|
||||
for pos in self.positions[:]:
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
sell_ratio = min(remaining, 1.0)
|
||||
sell_shares = pos.shares * sell_ratio
|
||||
proceeds = sell_shares * price
|
||||
commission_fee = proceeds * self.commission
|
||||
net_proceeds = proceeds - commission_fee
|
||||
|
||||
self.cash += net_proceeds
|
||||
|
||||
# 更新持仓
|
||||
pos.shares -= sell_shares
|
||||
if pos.shares < 0.01: # 清仓
|
||||
pos.exit_date = date
|
||||
pos.exit_price = price
|
||||
pos.hold_days = (date - pos.entry_date).days
|
||||
pos.pnl = (price - pos.entry_price) * (sell_shares / sell_ratio)
|
||||
pos.pnl_pct = (price / pos.entry_price - 1) * 100
|
||||
|
||||
self.closed_positions.append(pos)
|
||||
self.positions.remove(pos)
|
||||
|
||||
if pos.pnl > 0:
|
||||
self.wins += 1
|
||||
self.total_pnl += pos.pnl
|
||||
|
||||
remaining -= sell_ratio
|
||||
sold_positions.append((pos, sell_shares))
|
||||
|
||||
return True
|
||||
|
||||
def record_state(self, date, price: float):
|
||||
"""记录当前状态"""
|
||||
self.dates.append(date)
|
||||
self.equity_curve.append(self.get_total_value(price))
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""计算完整指标"""
|
||||
if not self.equity_curve:
|
||||
return {}
|
||||
|
||||
equity = np.array(self.equity_curve)
|
||||
returns = np.diff(equity) / equity[:-1]
|
||||
|
||||
# 基础指标
|
||||
total_return = (equity[-1] / equity[0] - 1) * 100
|
||||
|
||||
# 最大回撤
|
||||
peak = np.maximum.accumulate(equity)
|
||||
drawdown = (peak - equity) / peak
|
||||
max_drawdown = np.max(drawdown) * 100
|
||||
|
||||
# 夏普比率(年化,假设252个交易日)
|
||||
if len(returns) > 1 and np.std(returns) > 0:
|
||||
sharpe = np.mean(returns) / np.std(returns) * np.sqrt(252)
|
||||
else:
|
||||
sharpe = 0.0
|
||||
|
||||
# 卡玛比率(收益/最大回撤)
|
||||
calmar = total_return / max_drawdown if max_drawdown > 0 else 0.0
|
||||
|
||||
# 胜率
|
||||
closed = len(self.closed_positions)
|
||||
win_rate = (self.wins / closed * 100) if closed > 0 else 0.0
|
||||
|
||||
# 盈亏比
|
||||
winning_trades = [p.pnl for p in self.closed_positions if p.pnl > 0]
|
||||
losing_trades = [abs(p.pnl) for p in self.closed_positions if p.pnl < 0]
|
||||
avg_win = np.mean(winning_trades) if winning_trades else 0.0
|
||||
avg_loss = np.mean(losing_trades) if losing_trades else 0.0
|
||||
profit_factor = avg_win / avg_loss if avg_loss > 0 else 0.0
|
||||
|
||||
# 持仓天数
|
||||
hold_days = [p.hold_days for p in self.closed_positions]
|
||||
avg_hold = np.mean(hold_days) if hold_days else 0.0
|
||||
|
||||
return {
|
||||
"total_return": round(total_return, 2),
|
||||
"max_drawdown": round(max_drawdown, 2),
|
||||
"sharpe_ratio": round(sharpe, 3),
|
||||
"calmar_ratio": round(calmar, 3),
|
||||
"trades": self.trades,
|
||||
"closed_trades": closed,
|
||||
"win_rate": round(win_rate, 1),
|
||||
"profit_factor": round(profit_factor, 2),
|
||||
"avg_win": round(avg_win, 2),
|
||||
"avg_loss": round(avg_loss, 2),
|
||||
"avg_hold_days": round(avg_hold, 1),
|
||||
"total_pnl": round(self.total_pnl, 2),
|
||||
}
|
||||
|
||||
|
||||
class Strategy:
|
||||
"""策略基类"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
|
||||
"""每日回调"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MAStrategy(Strategy):
|
||||
"""均线交叉策略(增强版)"""
|
||||
|
||||
def __init__(self, fast: int = 5, slow: int = 20,
|
||||
position_size: float = 1.0,
|
||||
stop_loss: float = 0.0,
|
||||
take_profit: float = 0.0):
|
||||
super().__init__(f"MA{fast}/{slow}")
|
||||
self.fast = fast
|
||||
self.slow = slow
|
||||
self.position_size = position_size
|
||||
self.stop_loss = stop_loss # 止损比例
|
||||
self.take_profit = take_profit # 止盈比例
|
||||
|
||||
self.ma_fast_history = []
|
||||
self.ma_slow_history = []
|
||||
self.close_history = []
|
||||
|
||||
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
|
||||
close = data["close"]
|
||||
self.close_history.append(close)
|
||||
|
||||
# 计算均线
|
||||
if len(self.close_history) >= self.fast:
|
||||
self.ma_fast_history.append(np.mean(self.close_history[-self.fast:]))
|
||||
else:
|
||||
self.ma_fast_history.append(None)
|
||||
|
||||
if len(self.close_history) >= self.slow:
|
||||
self.ma_slow_history.append(np.mean(self.close_history[-self.slow:]))
|
||||
else:
|
||||
self.ma_slow_history.append(None)
|
||||
|
||||
if len(self.ma_fast_history) < 2:
|
||||
engine.record_state(date, close)
|
||||
return
|
||||
|
||||
maf_curr = self.ma_fast_history[-1]
|
||||
maf_prev = self.ma_fast_history[-2]
|
||||
mas_curr = self.ma_slow_history[-1]
|
||||
mas_prev = self.ma_slow_history[-2]
|
||||
|
||||
if maf_curr is None or mas_curr is None:
|
||||
engine.record_state(date, close)
|
||||
return
|
||||
|
||||
# 止损止盈检查
|
||||
if engine.positions:
|
||||
for pos in engine.positions[:]:
|
||||
pnl_pct = (close / pos.entry_price - 1) * 100
|
||||
|
||||
# 止损
|
||||
if self.stop_loss > 0 and pnl_pct <= -self.stop_loss:
|
||||
engine.sell(date, close, 1.0, f"止损 {pnl_pct:.2f}%")
|
||||
# 止盈
|
||||
elif self.take_profit > 0 and pnl_pct >= self.take_profit:
|
||||
engine.sell(date, close, 1.0, f"止盈 {pnl_pct:.2f}%")
|
||||
|
||||
# 金叉买入
|
||||
if maf_prev <= mas_prev and maf_curr > mas_curr:
|
||||
if not engine.positions:
|
||||
engine.buy(date, close, self.position_size, "金叉")
|
||||
|
||||
# 死叉卖出
|
||||
elif maf_prev >= mas_prev and maf_curr < mas_curr:
|
||||
if engine.positions:
|
||||
engine.sell(date, close, 1.0, "死叉")
|
||||
|
||||
engine.record_state(date, close)
|
||||
|
||||
|
||||
class MultiFactorStrategy(Strategy):
|
||||
"""多因子策略"""
|
||||
|
||||
def __init__(self, position_size: float = 1.0):
|
||||
super().__init__("多因子")
|
||||
self.position_size = position_size
|
||||
self.close_history = []
|
||||
self.volume_history = []
|
||||
|
||||
def calculate_rsi(self, n: int = 14) -> Optional[float]:
|
||||
"""计算RSI"""
|
||||
if len(self.close_history) < n + 1:
|
||||
return None
|
||||
|
||||
changes = np.diff(self.close_history[-n-1:])
|
||||
gains = np.where(changes > 0, changes, 0)
|
||||
losses = np.where(changes < 0, -changes, 0)
|
||||
|
||||
avg_gain = np.mean(gains)
|
||||
avg_loss = np.mean(losses)
|
||||
|
||||
if avg_loss == 0:
|
||||
return 100.0
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi
|
||||
|
||||
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
|
||||
close = data["close"]
|
||||
volume = data.get("volume", 0)
|
||||
|
||||
self.close_history.append(close)
|
||||
self.volume_history.append(volume)
|
||||
|
||||
if len(self.close_history) < 30:
|
||||
engine.record_state(date, close)
|
||||
return
|
||||
|
||||
# 计算因子
|
||||
ma5 = np.mean(self.close_history[-5:])
|
||||
ma20 = np.mean(self.close_history[-20:])
|
||||
rsi = self.calculate_rsi(14)
|
||||
|
||||
# 量比
|
||||
vol_avg = np.mean(self.volume_history[-20:-1])
|
||||
vol_ratio = volume / vol_avg if vol_avg > 0 else 1.0
|
||||
|
||||
# 买入信号:MA5 > MA20, RSI < 70, 放量
|
||||
buy_signal = (ma5 > ma20 and
|
||||
rsi is not None and rsi < 70 and
|
||||
vol_ratio > 1.5)
|
||||
|
||||
# 卖出信号:MA5 < MA20 或 RSI > 80
|
||||
sell_signal = (ma5 < ma20 or
|
||||
(rsi is not None and rsi > 80))
|
||||
|
||||
if buy_signal and not engine.positions:
|
||||
engine.buy(date, close, self.position_size, "多因子买入")
|
||||
|
||||
if sell_signal and engine.positions:
|
||||
engine.sell(date, close, 1.0, "多因子卖出")
|
||||
|
||||
engine.record_state(date, close)
|
||||
|
||||
|
||||
def run_advanced_backtest(symbol: str,
|
||||
strategy: Strategy,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
initial_capital: float = 100000.0,
|
||||
commission: float = 0.0005) -> Dict[str, Any]:
|
||||
"""运行增强回测
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
strategy: 策略实例
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
initial_capital: 初始资金
|
||||
commission: 手续费率
|
||||
|
||||
Returns:
|
||||
回测结果
|
||||
"""
|
||||
with get_session() as s:
|
||||
query = select(DailyQuote.date, DailyQuote.close, DailyQuote.volume).where(
|
||||
DailyQuote.code == symbol
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.where(DailyQuote.date >= dt.date.fromisoformat(start_date))
|
||||
if end_date:
|
||||
query = query.where(DailyQuote.date <= dt.date.fromisoformat(end_date))
|
||||
|
||||
query = query.order_by(DailyQuote.date)
|
||||
rows = s.execute(query).all()
|
||||
|
||||
if not rows:
|
||||
return {"ok": False, "msg": "无数据"}
|
||||
|
||||
engine = BacktestEngine(initial_capital, commission)
|
||||
|
||||
# 逐日回测
|
||||
for row in rows:
|
||||
date, close, volume = row
|
||||
data = {"close": float(close), "volume": int(volume)}
|
||||
strategy.on_data(engine, date, data)
|
||||
|
||||
# 计算基准(买入持有)
|
||||
bench_curve = []
|
||||
first_close = rows[0][1]
|
||||
for row in rows:
|
||||
bench_curve.append(float(row[1]) / float(first_close) * initial_capital)
|
||||
|
||||
metrics = engine.get_metrics()
|
||||
|
||||
# 交易明细
|
||||
trades_detail = [{
|
||||
"entry_date": p.entry_date.isoformat(),
|
||||
"exit_date": p.exit_date.isoformat() if p.exit_date else "",
|
||||
"entry_price": round(p.entry_price, 2),
|
||||
"exit_price": round(p.exit_price, 2) if p.exit_price else 0,
|
||||
"shares": round(p.shares, 2),
|
||||
"hold_days": p.hold_days,
|
||||
"pnl": round(p.pnl, 2),
|
||||
"pnl_pct": round(p.pnl_pct, 2),
|
||||
"reason": p.reason
|
||||
} for p in engine.closed_positions]
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"symbol": symbol,
|
||||
"strategy": strategy.name,
|
||||
"dates": [d.isoformat() for d in engine.dates],
|
||||
"equity": [round(e, 2) for e in engine.equity_curve],
|
||||
"bench": [round(b, 2) for b in bench_curve],
|
||||
"metrics": metrics,
|
||||
"trades": trades_detail,
|
||||
"initial_capital": initial_capital,
|
||||
}
|
||||
|
||||
|
||||
def optimize_parameters(symbol: str,
|
||||
param_grid: Dict[str, List],
|
||||
strategy_class: type,
|
||||
metric: str = "sharpe_ratio") -> List[Dict[str, Any]]:
|
||||
"""参数网格优化
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
param_grid: 参数网格,如 {"fast": [3,5,10], "slow": [10,20,30]}
|
||||
strategy_class: 策略类
|
||||
metric: 优化目标指标
|
||||
|
||||
Returns:
|
||||
优化结果列表,按指标降序排列
|
||||
"""
|
||||
import itertools
|
||||
|
||||
keys = list(param_grid.keys())
|
||||
values = list(param_grid.values())
|
||||
|
||||
results = []
|
||||
|
||||
# 遍历所有参数组合
|
||||
for combo in itertools.product(*values):
|
||||
params = dict(zip(keys, combo))
|
||||
|
||||
try:
|
||||
strategy = strategy_class(**params)
|
||||
result = run_advanced_backtest(symbol, strategy)
|
||||
|
||||
if result["ok"]:
|
||||
results.append({
|
||||
"params": params,
|
||||
"metrics": result["metrics"],
|
||||
metric: result["metrics"].get(metric, 0)
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"优化失败 {params}: {e}")
|
||||
continue
|
||||
|
||||
# 按目标指标排序
|
||||
results.sort(key=lambda x: x[metric], reverse=True)
|
||||
return results
|
||||
|
||||
|
||||
def compare_strategies(symbol: str,
|
||||
strategies: List[Strategy],
|
||||
initial_capital: float = 100000.0) -> Dict[str, Any]:
|
||||
"""策略对比
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
strategies: 策略列表
|
||||
initial_capital: 初始资金
|
||||
|
||||
Returns:
|
||||
对比结果
|
||||
"""
|
||||
results = []
|
||||
|
||||
for strategy in strategies:
|
||||
result = run_advanced_backtest(symbol, strategy, initial_capital=initial_capital)
|
||||
if result["ok"]:
|
||||
results.append({
|
||||
"strategy": strategy.name,
|
||||
"equity": result["equity"],
|
||||
"metrics": result["metrics"]
|
||||
})
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"symbol": symbol,
|
||||
"dates": result["dates"] if results else [],
|
||||
"strategies": results
|
||||
}
|
||||
Reference in New Issue
Block a user