claude强化功能

This commit is contained in:
2026-06-14 11:54:45 +08:00
parent cc8dff4e57
commit e524a3589a
43 changed files with 13421 additions and 73 deletions

View File

@@ -0,0 +1,499 @@
"""增强版回测引擎 — 多因子策略、仓位管理、参数优化。
支持功能:
1. 多因子组合策略(技术+基本面)
2. 仓位管理(固定、金字塔、凯利公式)
3. 止损止盈
4. 参数网格优化
5. 完整指标(夏普比率、最大回撤、卡玛比率等)
6. 交易明细导出
"""
import datetime as dt
from typing import Dict, List, Any, Optional, Callable
import numpy as np
from sqlalchemy import select
from db import get_session
from models import DailyQuote, StockMetric
class Position:
"""持仓记录"""
def __init__(self, date, price, shares, reason=""):
self.entry_date = date
self.entry_price = price
self.shares = shares
self.reason = reason
self.exit_date = None
self.exit_price = None
self.pnl = 0.0
self.pnl_pct = 0.0
self.hold_days = 0
class BacktestEngine:
"""增强回测引擎"""
def __init__(self, initial_capital: float = 100000.0, commission: float = 0.0005):
self.initial_capital = initial_capital
self.commission = commission
# 账户状态
self.cash = initial_capital
self.positions: List[Position] = []
self.closed_positions: List[Position] = []
# 净值曲线
self.equity_curve = []
self.dates = []
# 统计
self.trades = 0
self.wins = 0
self.total_pnl = 0.0
def get_position_value(self, price: float) -> float:
"""计算持仓市值"""
return sum(p.shares * price for p in self.positions)
def get_total_value(self, price: float) -> float:
"""计算总资产"""
return self.cash + self.get_position_value(price)
def buy(self, date, price: float, size: float, reason: str = ""):
"""买入
Args:
date: 交易日期
price: 买入价格
size: 仓位大小0-1相对于当前可用资金
reason: 买入理由
"""
if size <= 0 or size > 1:
return False
cost = self.cash * size
commission_fee = cost * self.commission
net_cost = cost - commission_fee
if net_cost <= 0:
return False
shares = net_cost / price
self.cash -= cost
pos = Position(date, price, shares, reason)
self.positions.append(pos)
self.trades += 1
return True
def sell(self, date, price: float, size: float = 1.0, reason: str = ""):
"""卖出
Args:
date: 交易日期
price: 卖出价格
size: 卖出比例0-1相对于持仓
reason: 卖出理由
"""
if not self.positions or size <= 0 or size > 1:
return False
# 按先进先出卖出
remaining = size
sold_positions = []
for pos in self.positions[:]:
if remaining <= 0:
break
sell_ratio = min(remaining, 1.0)
sell_shares = pos.shares * sell_ratio
proceeds = sell_shares * price
commission_fee = proceeds * self.commission
net_proceeds = proceeds - commission_fee
self.cash += net_proceeds
# 更新持仓
pos.shares -= sell_shares
if pos.shares < 0.01: # 清仓
pos.exit_date = date
pos.exit_price = price
pos.hold_days = (date - pos.entry_date).days
pos.pnl = (price - pos.entry_price) * (sell_shares / sell_ratio)
pos.pnl_pct = (price / pos.entry_price - 1) * 100
self.closed_positions.append(pos)
self.positions.remove(pos)
if pos.pnl > 0:
self.wins += 1
self.total_pnl += pos.pnl
remaining -= sell_ratio
sold_positions.append((pos, sell_shares))
return True
def record_state(self, date, price: float):
"""记录当前状态"""
self.dates.append(date)
self.equity_curve.append(self.get_total_value(price))
def get_metrics(self) -> Dict[str, Any]:
"""计算完整指标"""
if not self.equity_curve:
return {}
equity = np.array(self.equity_curve)
returns = np.diff(equity) / equity[:-1]
# 基础指标
total_return = (equity[-1] / equity[0] - 1) * 100
# 最大回撤
peak = np.maximum.accumulate(equity)
drawdown = (peak - equity) / peak
max_drawdown = np.max(drawdown) * 100
# 夏普比率年化假设252个交易日
if len(returns) > 1 and np.std(returns) > 0:
sharpe = np.mean(returns) / np.std(returns) * np.sqrt(252)
else:
sharpe = 0.0
# 卡玛比率(收益/最大回撤)
calmar = total_return / max_drawdown if max_drawdown > 0 else 0.0
# 胜率
closed = len(self.closed_positions)
win_rate = (self.wins / closed * 100) if closed > 0 else 0.0
# 盈亏比
winning_trades = [p.pnl for p in self.closed_positions if p.pnl > 0]
losing_trades = [abs(p.pnl) for p in self.closed_positions if p.pnl < 0]
avg_win = np.mean(winning_trades) if winning_trades else 0.0
avg_loss = np.mean(losing_trades) if losing_trades else 0.0
profit_factor = avg_win / avg_loss if avg_loss > 0 else 0.0
# 持仓天数
hold_days = [p.hold_days for p in self.closed_positions]
avg_hold = np.mean(hold_days) if hold_days else 0.0
return {
"total_return": round(total_return, 2),
"max_drawdown": round(max_drawdown, 2),
"sharpe_ratio": round(sharpe, 3),
"calmar_ratio": round(calmar, 3),
"trades": self.trades,
"closed_trades": closed,
"win_rate": round(win_rate, 1),
"profit_factor": round(profit_factor, 2),
"avg_win": round(avg_win, 2),
"avg_loss": round(avg_loss, 2),
"avg_hold_days": round(avg_hold, 1),
"total_pnl": round(self.total_pnl, 2),
}
class Strategy:
"""策略基类"""
def __init__(self, name: str):
self.name = name
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
"""每日回调"""
raise NotImplementedError
class MAStrategy(Strategy):
"""均线交叉策略(增强版)"""
def __init__(self, fast: int = 5, slow: int = 20,
position_size: float = 1.0,
stop_loss: float = 0.0,
take_profit: float = 0.0):
super().__init__(f"MA{fast}/{slow}")
self.fast = fast
self.slow = slow
self.position_size = position_size
self.stop_loss = stop_loss # 止损比例
self.take_profit = take_profit # 止盈比例
self.ma_fast_history = []
self.ma_slow_history = []
self.close_history = []
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
close = data["close"]
self.close_history.append(close)
# 计算均线
if len(self.close_history) >= self.fast:
self.ma_fast_history.append(np.mean(self.close_history[-self.fast:]))
else:
self.ma_fast_history.append(None)
if len(self.close_history) >= self.slow:
self.ma_slow_history.append(np.mean(self.close_history[-self.slow:]))
else:
self.ma_slow_history.append(None)
if len(self.ma_fast_history) < 2:
engine.record_state(date, close)
return
maf_curr = self.ma_fast_history[-1]
maf_prev = self.ma_fast_history[-2]
mas_curr = self.ma_slow_history[-1]
mas_prev = self.ma_slow_history[-2]
if maf_curr is None or mas_curr is None:
engine.record_state(date, close)
return
# 止损止盈检查
if engine.positions:
for pos in engine.positions[:]:
pnl_pct = (close / pos.entry_price - 1) * 100
# 止损
if self.stop_loss > 0 and pnl_pct <= -self.stop_loss:
engine.sell(date, close, 1.0, f"止损 {pnl_pct:.2f}%")
# 止盈
elif self.take_profit > 0 and pnl_pct >= self.take_profit:
engine.sell(date, close, 1.0, f"止盈 {pnl_pct:.2f}%")
# 金叉买入
if maf_prev <= mas_prev and maf_curr > mas_curr:
if not engine.positions:
engine.buy(date, close, self.position_size, "金叉")
# 死叉卖出
elif maf_prev >= mas_prev and maf_curr < mas_curr:
if engine.positions:
engine.sell(date, close, 1.0, "死叉")
engine.record_state(date, close)
class MultiFactorStrategy(Strategy):
"""多因子策略"""
def __init__(self, position_size: float = 1.0):
super().__init__("多因子")
self.position_size = position_size
self.close_history = []
self.volume_history = []
def calculate_rsi(self, n: int = 14) -> Optional[float]:
"""计算RSI"""
if len(self.close_history) < n + 1:
return None
changes = np.diff(self.close_history[-n-1:])
gains = np.where(changes > 0, changes, 0)
losses = np.where(changes < 0, -changes, 0)
avg_gain = np.mean(gains)
avg_loss = np.mean(losses)
if avg_loss == 0:
return 100.0
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return rsi
def on_data(self, engine: BacktestEngine, date, data: Dict[str, Any]) -> None:
close = data["close"]
volume = data.get("volume", 0)
self.close_history.append(close)
self.volume_history.append(volume)
if len(self.close_history) < 30:
engine.record_state(date, close)
return
# 计算因子
ma5 = np.mean(self.close_history[-5:])
ma20 = np.mean(self.close_history[-20:])
rsi = self.calculate_rsi(14)
# 量比
vol_avg = np.mean(self.volume_history[-20:-1])
vol_ratio = volume / vol_avg if vol_avg > 0 else 1.0
# 买入信号MA5 > MA20, RSI < 70, 放量
buy_signal = (ma5 > ma20 and
rsi is not None and rsi < 70 and
vol_ratio > 1.5)
# 卖出信号MA5 < MA20 或 RSI > 80
sell_signal = (ma5 < ma20 or
(rsi is not None and rsi > 80))
if buy_signal and not engine.positions:
engine.buy(date, close, self.position_size, "多因子买入")
if sell_signal and engine.positions:
engine.sell(date, close, 1.0, "多因子卖出")
engine.record_state(date, close)
def run_advanced_backtest(symbol: str,
strategy: Strategy,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
initial_capital: float = 100000.0,
commission: float = 0.0005) -> Dict[str, Any]:
"""运行增强回测
Args:
symbol: 股票代码
strategy: 策略实例
start_date: 开始日期
end_date: 结束日期
initial_capital: 初始资金
commission: 手续费率
Returns:
回测结果
"""
with get_session() as s:
query = select(DailyQuote.date, DailyQuote.close, DailyQuote.volume).where(
DailyQuote.code == symbol
)
if start_date:
query = query.where(DailyQuote.date >= dt.date.fromisoformat(start_date))
if end_date:
query = query.where(DailyQuote.date <= dt.date.fromisoformat(end_date))
query = query.order_by(DailyQuote.date)
rows = s.execute(query).all()
if not rows:
return {"ok": False, "msg": "无数据"}
engine = BacktestEngine(initial_capital, commission)
# 逐日回测
for row in rows:
date, close, volume = row
data = {"close": float(close), "volume": int(volume)}
strategy.on_data(engine, date, data)
# 计算基准(买入持有)
bench_curve = []
first_close = rows[0][1]
for row in rows:
bench_curve.append(float(row[1]) / float(first_close) * initial_capital)
metrics = engine.get_metrics()
# 交易明细
trades_detail = [{
"entry_date": p.entry_date.isoformat(),
"exit_date": p.exit_date.isoformat() if p.exit_date else "",
"entry_price": round(p.entry_price, 2),
"exit_price": round(p.exit_price, 2) if p.exit_price else 0,
"shares": round(p.shares, 2),
"hold_days": p.hold_days,
"pnl": round(p.pnl, 2),
"pnl_pct": round(p.pnl_pct, 2),
"reason": p.reason
} for p in engine.closed_positions]
return {
"ok": True,
"symbol": symbol,
"strategy": strategy.name,
"dates": [d.isoformat() for d in engine.dates],
"equity": [round(e, 2) for e in engine.equity_curve],
"bench": [round(b, 2) for b in bench_curve],
"metrics": metrics,
"trades": trades_detail,
"initial_capital": initial_capital,
}
def optimize_parameters(symbol: str,
param_grid: Dict[str, List],
strategy_class: type,
metric: str = "sharpe_ratio") -> List[Dict[str, Any]]:
"""参数网格优化
Args:
symbol: 股票代码
param_grid: 参数网格,如 {"fast": [3,5,10], "slow": [10,20,30]}
strategy_class: 策略类
metric: 优化目标指标
Returns:
优化结果列表,按指标降序排列
"""
import itertools
keys = list(param_grid.keys())
values = list(param_grid.values())
results = []
# 遍历所有参数组合
for combo in itertools.product(*values):
params = dict(zip(keys, combo))
try:
strategy = strategy_class(**params)
result = run_advanced_backtest(symbol, strategy)
if result["ok"]:
results.append({
"params": params,
"metrics": result["metrics"],
metric: result["metrics"].get(metric, 0)
})
except Exception as e:
print(f"优化失败 {params}: {e}")
continue
# 按目标指标排序
results.sort(key=lambda x: x[metric], reverse=True)
return results
def compare_strategies(symbol: str,
strategies: List[Strategy],
initial_capital: float = 100000.0) -> Dict[str, Any]:
"""策略对比
Args:
symbol: 股票代码
strategies: 策略列表
initial_capital: 初始资金
Returns:
对比结果
"""
results = []
for strategy in strategies:
result = run_advanced_backtest(symbol, strategy, initial_capital=initial_capital)
if result["ok"]:
results.append({
"strategy": strategy.name,
"equity": result["equity"],
"metrics": result["metrics"]
})
return {
"ok": True,
"symbol": symbol,
"dates": result["dates"] if results else [],
"strategies": results
}