"""信号历史胜率回测 + AI 预测留痕核验。 - compute_signal_stats: 对全市场历史日线回测各技术信号「N 日后上涨概率/平均收益」, 作为 AI 证据链『历史命中率』的客观依据。 - record_prediction / verify_predictions / accuracy: 记录每次 AI 诊断,N 日后核验真实涨跌, 形成可回溯的『实测准确率』。 """ from __future__ import annotations import datetime as dt import numpy as np import pandas as pd from sqlalchemy import select, func, distinct from db import get_session from models import DailyQuote, StockMetric, SignalStat, Prediction SIGNAL_DEFS = { "ma_bull": "均线多头排列(MA5>MA10>MA20)", "macd_gold": "MACD 金叉", "up_streak3": "三连阳", "vol_breakout": "放量上涨(量比>2)", "rsi_oversold": "RSI 超卖(<30)抄底", "new_high60": "创60日新高", } def _indicators(df: pd.DataFrame) -> pd.DataFrame: c = df["close"] df["ma5"] = c.rolling(5).mean() df["ma10"] = c.rolling(10).mean() df["ma20"] = c.rolling(20).mean() e12 = c.ewm(span=12, adjust=False).mean() e26 = c.ewm(span=26, adjust=False).mean() df["dif"] = e12 - e26 df["dea"] = df["dif"].ewm(span=9, adjust=False).mean() delta = c.diff() gain = delta.clip(lower=0) loss = -delta.clip(upper=0) ag = gain.rolling(14).mean() al = loss.rolling(14).mean() df["rsi"] = 100 - 100 / (1 + ag / al.replace(0, np.nan)) df["vavg5"] = df["volume"].rolling(5).mean().shift(1) df["high60"] = c.rolling(60).max() return df def _masks(df: pd.DataFrame) -> dict: c = df["close"] c1 = c.shift(1) return { "ma_bull": (df.ma5 > df.ma10) & (df.ma10 > df.ma20), "macd_gold": (df.dif.shift(1) < df.dea.shift(1)) & (df.dif >= df.dea), "up_streak3": (c > c1) & (c1 > c.shift(2)) & (c.shift(2) > c.shift(3)), "vol_breakout": (df.volume / df.vavg5 > 2) & (c > c1), "rsi_oversold": df.rsi < 30, "new_high60": c >= df.high60, } def compute_signal_stats(sample_limit: int = 500, horizon: int = 5): """对样本股历史回测各信号 N 日后表现并落库。""" with get_session() as s: codes = [r[0] for r in s.execute(select(StockMetric.code).limit(sample_limit)).all()] if not codes: codes = [r[0] for r in s.execute(select(distinct(DailyQuote.code)).limit(sample_limit)).all()] agg = {k: {"n": 0, "win": 0, "sum": 0.0} for k in SIGNAL_DEFS} used = 0 with get_session() as s: for code in codes: rows = s.execute(select(DailyQuote.date, DailyQuote.close, DailyQuote.volume) .where(DailyQuote.code == code).order_by(DailyQuote.date)).all() if len(rows) < 80: continue used += 1 df = pd.DataFrame(rows, columns=["date", "close", "volume"]) df["close"] = df["close"].astype(float) df["volume"] = df["volume"].astype(float) df = _indicators(df) fwd = df["close"].shift(-horizon) / df["close"] - 1 for k, m in _masks(df).items(): m = m & fwd.notna() r = fwd[m] if len(r): agg[k]["n"] += int(len(r)) agg[k]["win"] += int((r > 0).sum()) agg[k]["sum"] += float(r.sum()) with get_session() as s: for k, a in agg.items(): if a["n"] == 0: continue wr = round(a["win"] / a["n"] * 100, 1) ar = round(a["sum"] / a["n"] * 100, 2) row = s.execute(select(SignalStat).where(SignalStat.signal == k, SignalStat.horizon == horizon)).scalar_one_or_none() if row: row.samples, row.win_rate, row.avg_ret, row.updated_at = a["n"], wr, ar, dt.datetime.now() else: s.add(SignalStat(signal=k, horizon=horizon, samples=a["n"], win_rate=wr, avg_ret=ar)) s.commit() return {"ok": True, "horizon": horizon, "sampled": used, "result": {k: {"samples": a["n"], "win_rate": round(a["win"] / a["n"] * 100, 1) if a["n"] else None} for k, a in agg.items()}} def get_stats(horizon: int = 5) -> dict: with get_session() as s: rows = s.execute(select(SignalStat).where(SignalStat.horizon == horizon)).scalars().all() return {r.signal: {"label": SIGNAL_DEFS.get(r.signal, r.signal), "win_rate": r.win_rate, "avg_ret": r.avg_ret, "samples": r.samples} for r in rows} def active_signals(m: StockMetric) -> list[str]: """根据最新因子快照判断当前激活的信号。""" out = [] if m.ma_bull: out.append("ma_bull") if m.macd_gold: out.append("macd_gold") if m.up_streak >= 3: out.append("up_streak3") if m.vol_ratio > 2 and m.pct > 0: out.append("vol_breakout") if m.rsi14 < 30: out.append("rsi_oversold") if m.pos60 >= 0.99: out.append("new_high60") return out # ---------------- 预测留痕与核验 ---------------- def record_prediction(code, name, date, score, confidence, direction, base_close, horizon=5, kind="diagnose"): with get_session() as s: exist = s.execute(select(Prediction).where( Prediction.code == code, Prediction.date == date, Prediction.kind == kind)).scalar_one_or_none() if exist: return False s.add(Prediction(date=date, code=code, name=name, kind=kind, score=score, confidence=confidence, direction=direction, horizon=horizon, base_close=base_close)) s.commit() return True def verify_predictions(): """对到期(已有 horizon 个交易日)的 open 预测核验真实涨跌。""" with get_session() as s: opens = s.execute(select(Prediction).where(Prediction.status == "open")).scalars().all() closed = 0 for p in opens: future = s.execute(select(DailyQuote.close).where( DailyQuote.code == p.code, DailyQuote.date > p.date) .order_by(DailyQuote.date).limit(p.horizon)).all() if len(future) < p.horizon: continue end = float(future[-1][0]) ret = (end / p.base_close - 1) * 100 if p.base_close else 0.0 if p.direction == "up": hit = ret > 0 elif p.direction == "down": hit = ret < 0 else: hit = abs(ret) < 2 p.actual_ret = round(ret, 2) p.hit = bool(hit) p.status = "closed" closed += 1 s.commit() return {"ok": True, "closed": closed} def accuracy(): with get_session() as s: rows = s.execute(select(Prediction).where(Prediction.status == "closed")).scalars().all() opens = s.execute(select(func.count()).select_from(Prediction).where(Prediction.status == "open")).scalar() n = len(rows) hits = sum(1 for r in rows if r.hit) by_dir = {} for r in rows: d = by_dir.setdefault(r.direction, {"n": 0, "hit": 0}) d["n"] += 1 d["hit"] += 1 if r.hit else 0 recent = sorted(rows, key=lambda r: (r.date, r.id), reverse=True)[:25] return { "closed": n, "open": opens or 0, "hit_rate": round(hits / n * 100, 1) if n else None, "avg_ret": round(sum(r.actual_ret for r in rows) / n, 2) if n else None, "by_direction": {k: {"n": v["n"], "hit_rate": round(v["hit"] / v["n"] * 100, 1)} for k, v in by_dir.items()}, "recent": [{"date": r.date.isoformat(), "code": r.code, "name": r.name, "direction": r.direction, "score": r.score, "confidence": r.confidence, "actual_ret": r.actual_ret, "hit": r.hit, "horizon": r.horizon} for r in recent], }