195 lines
7.7 KiB
Python
195 lines
7.7 KiB
Python
"""信号历史胜率回测 + AI 预测留痕核验。
|
||
|
||
- compute_signal_stats: 对全市场历史日线回测各技术信号「N 日后上涨概率/平均收益」,
|
||
作为 AI 证据链『历史命中率』的客观依据。
|
||
- record_prediction / verify_predictions / accuracy: 记录每次 AI 诊断,N 日后核验真实涨跌,
|
||
形成可回溯的『实测准确率』。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import datetime as dt
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from sqlalchemy import select, func, distinct
|
||
|
||
from db import get_session
|
||
from models import DailyQuote, StockMetric, SignalStat, Prediction
|
||
|
||
# Registry of the technical signals handled by this module.
# Keys are the signal identifiers: compute_signal_stats aggregates one bucket
# per key, _masks returns one boolean mask per key, and the same string is
# stored in SignalStat.signal. Values are the display labels that get_stats
# returns under "label".
SIGNAL_DEFS: dict[str, str] = {
    # Bullish moving-average alignment: MA5 > MA10 > MA20.
    "ma_bull": "均线多头排列(MA5>MA10>MA20)",
    # MACD golden cross.
    "macd_gold": "MACD 金叉",
    # Three consecutive rising closes.
    "up_streak3": "三连阳",
    # Rising close on heavy volume (volume ratio > 2).
    "vol_breakout": "放量上涨(量比>2)",
    # RSI oversold (< 30), bottom-fishing setup.
    "rsi_oversold": "RSI 超卖(<30)抄底",
    # New 60-day high.
    "new_high60": "创60日新高",
}
|
||
|
||
|
||
def _indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Add the indicator columns used by the signal masks to ``df``.

    The frame is modified in place and also returned. It must provide
    ``close`` and ``volume`` columns; rolling windows and shifts work on row
    position, so rows are expected to be in chronological order.

    Columns added:
        ma5, ma10, ma20: simple moving averages of the close.
        dif, dea: MACD lines (EMA12 - EMA26, and the 9-span EMA of that).
        rsi: 14-row RSI built from simple rolling means of gains and losses.
        vavg5: mean volume of the previous five rows (current row excluded).
        high60: highest close of the trailing 60 rows, current row included.

    Rows that do not yet have a full window hold NaN in the rolling columns.

    Args:
        df: Price history with ``close`` and ``volume`` columns.

    Returns:
        The same ``df`` object with the indicator columns attached.
    """
    close = df["close"]

    for window in (5, 10, 20):
        df[f"ma{window}"] = close.rolling(window).mean()

    ema_fast = close.ewm(span=12, adjust=False).mean()
    ema_slow = close.ewm(span=26, adjust=False).mean()
    df["dif"] = ema_fast - ema_slow
    df["dea"] = df["dif"].ewm(span=9, adjust=False).mean()

    change = close.diff()
    avg_gain = change.clip(lower=0).rolling(14).mean()
    avg_loss = (-change.clip(upper=0)).rolling(14).mean()
    # Replace a zero average loss with NaN so the ratio never divides by zero.
    rsi = 100 - 100 / (1 + avg_gain / avg_loss.replace(0, np.nan))
    # The guard above turns "gains but no losses" into NaN, yet that window is
    # the strongest possible reading: RSI is 100 there. A completely flat
    # window (no gains either) has no defined RSI and stays NaN.
    df["rsi"] = rsi.mask((avg_loss == 0) & (avg_gain > 0), 100.0)

    # shift(1) keeps today's volume out of its own baseline.
    df["vavg5"] = df["volume"].rolling(5).mean().shift(1)
    df["high60"] = close.rolling(60).max()
    return df
|
||
|
||
|
||
def _masks(df: pd.DataFrame) -> dict:
    """Build one boolean Series per signal, aligned with the rows of ``df``.

    ``df`` must carry ``close`` and ``volume`` plus the indicator columns
    ``ma5``, ``ma10``, ``ma20``, ``dif``, ``dea``, ``vavg5``, ``rsi`` and
    ``high60``. Comparisons against NaN evaluate to False, so rows without a
    full indicator window never fire a signal.

    Returns:
        Dict mapping signal key to its boolean mask.
    """
    close = df["close"]
    prev1 = close.shift(1)
    prev2 = close.shift(2)
    prev3 = close.shift(3)
    closed_higher = close > prev1

    # MA5 above MA10 above MA20.
    ma_aligned = (df["ma5"] > df["ma10"]) & (df["ma10"] > df["ma20"])

    # DIF was below DEA on the previous row and is at or above it now.
    was_below = df["dif"].shift(1) < df["dea"].shift(1)
    golden_cross = was_below & (df["dif"] >= df["dea"])

    # Three rising closes in a row, ending on the current row.
    three_up = closed_higher & (prev1 > prev2) & (prev2 > prev3)

    # Volume more than twice the prior five-row average, on a rising close.
    volume_ratio = df["volume"] / df["vavg5"]
    heavy_volume_up = (volume_ratio > 2) & closed_higher

    masks = {}
    masks["ma_bull"] = ma_aligned
    masks["macd_gold"] = golden_cross
    masks["up_streak3"] = three_up
    masks["vol_breakout"] = heavy_volume_up
    masks["rsi_oversold"] = df["rsi"] < 30
    masks["new_high60"] = close >= df["high60"]
    return masks
|
||
|
||
|
||
def compute_signal_stats(sample_limit: int = 500, horizon: int = 5):
    """Back-test every signal in ``SIGNAL_DEFS`` and store the aggregates.

    Up to ``sample_limit`` stock codes are read from ``StockMetric``; when that
    yields none, the distinct codes of ``DailyQuote`` are used instead. Codes
    with fewer than 80 daily rows are skipped. For every remaining code the
    indicators and signal masks are computed, and on each row where a signal
    fires the return ``horizon`` rows later is collected.

    Per signal, the sample count, the win rate (share of positive returns, in
    percent, one decimal) and the mean return (percent, two decimals) are
    written to ``SignalStat`` for the (signal, horizon) pair, updating an
    existing row or inserting a new one. Signals without any sample are not
    written, so a previously stored row for them is left as it is.

    Args:
        sample_limit: Maximum number of stock codes to back-test.
        horizon: Number of rows (trading days) to look ahead.

    Returns:
        ``{"ok": True, "horizon": ..., "sampled": <codes actually used>,
        "result": {signal: {"samples": n, "win_rate": pct or None}}}``.
    """
    with get_session() as s:
        codes = [r[0] for r in s.execute(select(StockMetric.code).limit(sample_limit)).all()]
        if not codes:
            # No factor snapshot available: fall back to any code with quotes.
            codes = [r[0] for r in s.execute(select(distinct(DailyQuote.code)).limit(sample_limit)).all()]

    agg = {k: {"n": 0, "win": 0, "sum": 0.0} for k in SIGNAL_DEFS}
    used = 0
    with get_session() as s:
        for code in codes:
            rows = s.execute(
                select(DailyQuote.date, DailyQuote.close, DailyQuote.volume)
                .where(DailyQuote.code == code)
                .order_by(DailyQuote.date)
            ).all()
            if len(rows) < 80:
                continue
            used += 1
            df = pd.DataFrame(rows, columns=["date", "close", "volume"])
            df["close"] = df["close"].astype(float)
            df["volume"] = df["volume"].astype(float)
            df = _indicators(df)
            fwd = df["close"].shift(-horizon) / df["close"] - 1
            # Keep finite returns only. NaN marks rows without a future close;
            # +/-inf appears when the base close is 0 and would otherwise pass
            # a plain notna() check and turn the summed return into inf.
            usable = np.isfinite(fwd)
            for k, m in _masks(df).items():
                r = fwd[m & usable]
                if len(r):
                    agg[k]["n"] += int(len(r))
                    agg[k]["win"] += int((r > 0).sum())
                    agg[k]["sum"] += float(r.sum())

    # Computed once and shared by the persisted rows and the returned summary.
    win_rates = {
        k: round(a["win"] / a["n"] * 100, 1) if a["n"] else None
        for k, a in agg.items()
    }

    with get_session() as s:
        for k, a in agg.items():
            if a["n"] == 0:
                continue
            wr = win_rates[k]
            ar = round(a["sum"] / a["n"] * 100, 2)
            row = s.execute(
                select(SignalStat).where(SignalStat.signal == k, SignalStat.horizon == horizon)
            ).scalar_one_or_none()
            if row:
                row.samples, row.win_rate, row.avg_ret, row.updated_at = a["n"], wr, ar, dt.datetime.now()
            else:
                s.add(SignalStat(signal=k, horizon=horizon, samples=a["n"], win_rate=wr, avg_ret=ar))
        s.commit()

    return {
        "ok": True,
        "horizon": horizon,
        "sampled": used,
        "result": {k: {"samples": a["n"], "win_rate": win_rates[k]} for k, a in agg.items()},
    }
|
||
|
||
|
||
def get_stats(horizon: int = 5) -> dict:
    """Return the stored back-test statistics for one look-ahead horizon.

    Args:
        horizon: Look-ahead length whose ``SignalStat`` rows are wanted.

    Returns:
        Dict keyed by signal identifier. Each value holds ``label`` (the
        ``SIGNAL_DEFS`` text, or the identifier itself when it is unknown),
        ``win_rate``, ``avg_ret`` and ``samples`` as stored.
    """
    query = select(SignalStat).where(SignalStat.horizon == horizon)
    stats = {}
    with get_session() as s:
        for stat in s.execute(query).scalars().all():
            stats[stat.signal] = {
                "label": SIGNAL_DEFS.get(stat.signal, stat.signal),
                "win_rate": stat.win_rate,
                "avg_ret": stat.avg_ret,
                "samples": stat.samples,
            }
        return stats
|
||
|
||
|
||
def active_signals(m: StockMetric) -> list[str]:
    """List the signals that are currently active for one factor snapshot.

    Reads ``ma_bull``, ``macd_gold``, ``up_streak``, ``vol_ratio``, ``pct``,
    ``rsi14`` and ``pos60`` from ``m``. A numeric factor that is ``None`` is
    treated as "signal not active" instead of raising ``TypeError`` on the
    comparison.

    Args:
        m: Latest factor snapshot of a stock.

    Returns:
        Signal keys in the fixed order ma_bull, macd_gold, up_streak3,
        vol_breakout, rsi_oversold, new_high60 (only the active ones).
    """
    out = []
    if m.ma_bull:
        out.append("ma_bull")
    if m.macd_gold:
        out.append("macd_gold")
    if m.up_streak is not None and m.up_streak >= 3:
        out.append("up_streak3")
    # Heavy volume only counts on an up day, so both factors must be present.
    if (
        m.vol_ratio is not None
        and m.pct is not None
        and m.vol_ratio > 2
        and m.pct > 0
    ):
        out.append("vol_breakout")
    if m.rsi14 is not None and m.rsi14 < 30:
        out.append("rsi_oversold")
    if m.pos60 is not None and m.pos60 >= 0.99:
        out.append("new_high60")
    return out
|
||
|
||
|
||
# ---------------- 预测留痕与核验 ----------------
|
||
def record_prediction(code, name, date, score, confidence, direction, base_close, horizon=5, kind="diagnose"):
    """Store one prediction unless the same (code, date, kind) already exists.

    Args:
        code: Stock code the prediction is about.
        name: Stock name stored alongside the code.
        date: Date of the prediction.
        score: Score stored with the prediction.
        confidence: Confidence stored with the prediction.
        direction: Predicted direction; ``verify_predictions`` treats "up" and
            "down" specially and anything else as a flat call.
        base_close: Close price the later return is measured against.
        horizon: Number of trading days after which the prediction is checked.
        kind: Category of the prediction.

    Returns:
        True when a new row was inserted, False when a matching row existed.
    """
    duplicate_query = select(Prediction).where(
        Prediction.code == code,
        Prediction.date == date,
        Prediction.kind == kind,
    )
    with get_session() as s:
        already_recorded = s.execute(duplicate_query).scalar_one_or_none()
        if already_recorded:
            return False

        prediction = Prediction(
            date=date,
            code=code,
            name=name,
            kind=kind,
            score=score,
            confidence=confidence,
            direction=direction,
            horizon=horizon,
            base_close=base_close,
        )
        s.add(prediction)
        s.commit()
    return True
|
||
|
||
|
||
def verify_predictions():
    """Close every open prediction whose horizon has fully elapsed.

    For each prediction with status "open", the closes of the first
    ``horizon`` quotes dated after the prediction date are fetched. If fewer
    are available the prediction stays open. Otherwise the percentage return
    of the last of those closes against ``base_close`` is stored (0.0 when
    ``base_close`` is empty), the hit flag is set and the status becomes
    "closed".

    Hit rule: "up" needs a positive return, "down" a negative one, and any
    other direction counts as a hit when the absolute return is below 2%.

    Returns:
        ``{"ok": True, "closed": <number of predictions closed>}``.
    """
    closed = 0
    with get_session() as s:
        pending = s.execute(select(Prediction).where(Prediction.status == "open")).scalars().all()
        for pred in pending:
            later_closes = s.execute(
                select(DailyQuote.close)
                .where(DailyQuote.code == pred.code, DailyQuote.date > pred.date)
                .order_by(DailyQuote.date)
                .limit(pred.horizon)
            ).all()
            if len(later_closes) < pred.horizon:
                # Not enough trading days have passed yet.
                continue

            end_close = float(later_closes[-1][0])
            if pred.base_close:
                ret = (end_close / pred.base_close - 1) * 100
            else:
                ret = 0.0

            directional_outcome = {"up": ret > 0, "down": ret < 0}
            hit = directional_outcome.get(pred.direction, abs(ret) < 2)

            pred.actual_ret = round(ret, 2)
            pred.hit = bool(hit)
            pred.status = "closed"
            closed += 1
        s.commit()
    return {"ok": True, "closed": closed}
|
||
|
||
|
||
def accuracy():
    """Summarise how the verified (closed) predictions performed.

    Returns:
        Dict with:
        ``closed``: number of closed predictions.
        ``open``: number of predictions still open.
        ``hit_rate``: share of hits in percent (one decimal), None if nothing
        is closed.
        ``avg_ret``: mean ``actual_ret`` (two decimals), None if nothing is
        closed.
        ``by_direction``: per direction, the count ``n`` and its ``hit_rate``.
        ``recent``: the 25 latest closed predictions, ordered by (date, id)
        descending.
    """
    with get_session() as s:
        closed_rows = s.execute(select(Prediction).where(Prediction.status == "closed")).scalars().all()
        open_count = s.execute(
            select(func.count()).select_from(Prediction).where(Prediction.status == "open")
        ).scalar()

        total = len(closed_rows)
        hit_total = len([p for p in closed_rows if p.hit])

        # Tally sample and hit counts per predicted direction.
        per_direction = {}
        for p in closed_rows:
            if p.direction not in per_direction:
                per_direction[p.direction] = {"n": 0, "hit": 0}
            bucket = per_direction[p.direction]
            bucket["n"] += 1
            if p.hit:
                bucket["hit"] += 1

        if total:
            overall_hit_rate = round(hit_total / total * 100, 1)
            mean_ret = round(sum(p.actual_ret for p in closed_rows) / total, 2)
        else:
            overall_hit_rate = None
            mean_ret = None

        direction_summary = {}
        for direction, tally in per_direction.items():
            direction_summary[direction] = {
                "n": tally["n"],
                "hit_rate": round(tally["hit"] / tally["n"] * 100, 1),
            }

        newest_first = list(closed_rows)
        newest_first.sort(key=lambda p: (p.date, p.id), reverse=True)
        recent_items = []
        for p in newest_first[:25]:
            recent_items.append({
                "date": p.date.isoformat(),
                "code": p.code,
                "name": p.name,
                "direction": p.direction,
                "score": p.score,
                "confidence": p.confidence,
                "actual_ret": p.actual_ret,
                "hit": p.hit,
                "horizon": p.horizon,
            })

        return {
            "closed": total,
            "open": open_count or 0,
            "hit_rate": overall_hit_rate,
            "avg_ret": mean_ret,
            "by_direction": direction_summary,
            "recent": recent_items,
        }
|