Files
stock_cursor_v0/backend/signals.py

195 lines
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""信号历史胜率回测 + AI 预测留痕核验。
- compute_signal_stats: 对全市场历史日线回测各技术信号「N 日后上涨概率/平均收益」,
作为 AI 证据链『历史命中率』的客观依据。
- record_prediction / verify_predictions / accuracy: 记录每次 AI 诊断,并在 N 日后核验真实涨跌,
形成可回溯的『实测准确率』。
"""
from __future__ import annotations
import datetime as dt
import numpy as np
import pandas as pd
from sqlalchemy import select, func, distinct
from db import get_session
from models import DailyQuote, StockMetric, SignalStat, Prediction
# Signal key -> human-readable (Chinese) display label.
# Iteration order matters: compute_signal_stats builds its aggregation table
# (and therefore its "result" dict) by iterating over these keys.
SIGNAL_DEFS = dict(
    ma_bull="均线多头排列(MA5>MA10>MA20)",
    macd_gold="MACD 金叉",
    up_streak3="三连阳",
    vol_breakout="放量上涨(量比>2)",
    rsi_oversold="RSI 超卖(<30)抄底",
    new_high60="创60日新高",
)
def _indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Add technical-indicator columns to ``df`` in place and return it.

    Reads the ``close`` and ``volume`` columns (rows are assumed to be in
    ascending date order, as compute_signal_stats loads them) and adds:

    - ``ma5`` / ``ma10`` / ``ma20``: simple moving averages of close.
    - ``dif`` / ``dea``: MACD(12, 26, 9) lines built from ``adjust=False`` EMAs.
    - ``rsi``: 14-row RSI using *simple* rolling means of gains and losses
      (not Wilder smoothing). It is 100 when the window had gains but no
      losses, and NaN during warm-up or when the price did not move at all.
    - ``vavg5``: mean volume of the 5 *previous* rows (the current row is
      excluded by the ``shift(1)``).
    - ``high60``: rolling 60-row maximum close, including the current row.
    """
    c = df["close"]
    df["ma5"] = c.rolling(5).mean()
    df["ma10"] = c.rolling(10).mean()
    df["ma20"] = c.rolling(20).mean()

    e12 = c.ewm(span=12, adjust=False).mean()
    e26 = c.ewm(span=26, adjust=False).mean()
    df["dif"] = e12 - e26
    df["dea"] = df["dif"].ewm(span=9, adjust=False).mean()

    delta = c.diff()
    gain = delta.clip(lower=0)
    loss = -delta.clip(upper=0)
    ag = gain.rolling(14).mean()
    al = loss.rolling(14).mean()
    # Avoid dividing by zero: a zero average loss becomes NaN here ...
    rsi = 100 - 100 / (1 + ag / al.replace(0, np.nan))
    # ... but "gains and no losses" means RS is infinite, i.e. RSI = 100 by
    # definition, so restore that value instead of leaving NaN.
    df["rsi"] = rsi.mask((al == 0) & (ag > 0), 100.0)

    # Shifted so the current bar's volume is compared to the *prior* 5 bars.
    df["vavg5"] = df["volume"].rolling(5).mean().shift(1)
    df["high60"] = c.rolling(60).max()
    return df
def _masks(df: pd.DataFrame) -> dict:
    """Return one boolean Series per signal key, True on rows where it fires.

    ``df`` must already carry the columns added by ``_indicators``. Any
    comparison involving NaN (warm-up rows, missing shifts) evaluates to False.
    Keys are returned in the same order as SIGNAL_DEFS.
    """
    close = df["close"]
    prev1, prev2, prev3 = (close.shift(lag) for lag in (1, 2, 3))
    dif, dea = df["dif"], df["dea"]

    rising = close > prev1
    stacked_averages = (df["ma5"] > df["ma10"]) & (df["ma10"] > df["ma20"])
    # DIF was below DEA on the previous bar and is at/above it now.
    golden_cross = (dif.shift(1) < dea.shift(1)) & (dif >= dea)
    three_up = rising & (prev1 > prev2) & (prev2 > prev3)
    volume_surge = (df["volume"] / df["vavg5"] > 2) & rising

    return {
        "ma_bull": stacked_averages,
        "macd_gold": golden_cross,
        "up_streak3": three_up,
        "vol_breakout": volume_surge,
        "rsi_oversold": df["rsi"] < 30,
        "new_high60": close >= df["high60"],
    }
def compute_signal_stats(sample_limit: int = 500, horizon: int = 5):
    """Backtest every signal in SIGNAL_DEFS on sampled stocks and upsert SignalStat rows.

    Up to ``sample_limit`` codes are taken from StockMetric. If StockMetric is
    empty, they come from the distinct codes in DailyQuote instead. Only codes
    with at least 80 daily rows are evaluated. On every row where a signal
    fires, one sample is recorded: the forward return
    ``close[t + horizon] / close[t] - 1``. Non-finite returns are ignored: NaN
    at the tail or from missing closes, and inf from a zero close.

    For each signal with at least one sample, the following are stored under
    (signal, horizon):

    - ``samples``;
    - ``win_rate``: % of samples with a positive forward return, 1 dp;
    - ``avg_ret``: mean forward return in %, 2 dp.

    Signals with zero samples are not written, so any previous SignalStat row
    for them is left unchanged.

    Args:
        sample_limit: maximum number of stock codes to sample.
        horizon: how many rows (trading days) ahead the forward return looks.

    Returns:
        {"ok": True, "horizon": horizon, "sampled": <codes evaluated>,
         "result": {signal: {"samples": int, "win_rate": float | None}}}
    """
    # 1) Choose the sample of stock codes.
    with get_session() as s:
        codes = [r[0] for r in s.execute(select(StockMetric.code).limit(sample_limit)).all()]
        if not codes:
            codes = [r[0] for r in s.execute(select(distinct(DailyQuote.code)).limit(sample_limit)).all()]

    # 2) Accumulate per-signal sample count, win count and summed return.
    agg = {k: {"n": 0, "win": 0, "sum": 0.0} for k in SIGNAL_DEFS}
    used = 0
    with get_session() as s:
        for code in codes:
            rows = s.execute(
                select(DailyQuote.date, DailyQuote.close, DailyQuote.volume)
                .where(DailyQuote.code == code)
                .order_by(DailyQuote.date)
            ).all()
            if len(rows) < 80:  # skip short histories
                continue
            used += 1
            df = pd.DataFrame(rows, columns=["date", "close", "volume"])
            df["close"] = df["close"].astype(float)
            df["volume"] = df["volume"].astype(float)
            df = _indicators(df)
            fwd = df["close"].shift(-horizon) / df["close"] - 1
            # notna() alone would keep +/-inf produced by a zero close, and a
            # single inf turns the accumulated sum (hence avg_ret) into inf/NaN.
            valid = np.isfinite(fwd)
            for k, m in _masks(df).items():
                r = fwd[m & valid]
                if len(r):
                    agg[k]["n"] += int(len(r))
                    agg[k]["win"] += int((r > 0).sum())
                    agg[k]["sum"] += float(r.sum())

    # 3) Summarise each signal once; the same numbers are persisted and returned.
    summary = {
        k: (a["n"], round(a["win"] / a["n"] * 100, 1), round(a["sum"] / a["n"] * 100, 2))
        for k, a in agg.items()
        if a["n"]
    }
    with get_session() as s:
        for k, (n, wr, ar) in summary.items():
            row = s.execute(
                select(SignalStat).where(SignalStat.signal == k, SignalStat.horizon == horizon)
            ).scalar_one_or_none()
            if row:
                row.samples, row.win_rate, row.avg_ret, row.updated_at = n, wr, ar, dt.datetime.now()
            else:
                s.add(SignalStat(signal=k, horizon=horizon, samples=n, win_rate=wr, avg_ret=ar))
        s.commit()

    return {
        "ok": True,
        "horizon": horizon,
        "sampled": used,
        "result": {
            k: {"samples": a["n"], "win_rate": summary[k][1] if k in summary else None}
            for k, a in agg.items()
        },
    }
def get_stats(horizon: int = 5) -> dict:
    """Return the persisted SignalStat rows for ``horizon``, keyed by signal name.

    Each value holds the display label from SIGNAL_DEFS (the raw signal key is
    used when the signal is no longer defined), plus win_rate, avg_ret and
    samples as stored.
    """
    query = select(SignalStat).where(SignalStat.horizon == horizon)
    with get_session() as s:
        stats = {}
        for stat in s.execute(query).scalars().all():
            stats[stat.signal] = {
                "label": SIGNAL_DEFS.get(stat.signal, stat.signal),
                "win_rate": stat.win_rate,
                "avg_ret": stat.avg_ret,
                "samples": stat.samples,
            }
    return stats
def active_signals(m: StockMetric) -> list[str]:
    """Return the signal keys active in the latest factor snapshot ``m``.

    Checks, in SIGNAL_DEFS order:

    - ma_bull: ``m.ma_bull`` is truthy;
    - macd_gold: ``m.macd_gold`` is truthy;
    - up_streak3: ``m.up_streak >= 3``;
    - vol_breakout: ``m.vol_ratio > 2`` and ``m.pct > 0``;
    - rsi_oversold: ``m.rsi14 < 30``;
    - new_high60: ``m.pos60 >= 0.99``. This presumably approximates the
      backtest's "close >= 60-day high" in _masks; confirm what pos60 means.

    A factor that is None (not computed yet, e.g. for a short history) makes
    its signal inactive instead of raising TypeError on the comparison.
    """
    def ge(value, threshold):
        return value is not None and value >= threshold

    def gt(value, threshold):
        return value is not None and value > threshold

    def lt(value, threshold):
        return value is not None and value < threshold

    out = []
    if m.ma_bull:
        out.append("ma_bull")
    if m.macd_gold:
        out.append("macd_gold")
    if ge(m.up_streak, 3):
        out.append("up_streak3")
    if gt(m.vol_ratio, 2) and gt(m.pct, 0):
        out.append("vol_breakout")
    if lt(m.rsi14, 30):
        out.append("rsi_oversold")
    if ge(m.pos60, 0.99):
        out.append("new_high60")
    return out
# ---------------- 预测留痕与核验 ----------------
def record_prediction(code, name, date, score, confidence, direction, base_close, horizon=5, kind="diagnose"):
    """Store one AI prediction for later verification, at most once per (code, date, kind).

    Returns:
        True when a new Prediction row was added and committed; False when a
        row with the same code, date and kind already existed (nothing written).
    """
    duplicate_query = select(Prediction).where(
        Prediction.code == code,
        Prediction.date == date,
        Prediction.kind == kind,
    )
    with get_session() as s:
        if s.execute(duplicate_query).scalar_one_or_none():
            return False
        s.add(
            Prediction(
                date=date,
                code=code,
                name=name,
                kind=kind,
                score=score,
                confidence=confidence,
                direction=direction,
                horizon=horizon,
                base_close=base_close,
            )
        )
        s.commit()
    return True
def verify_predictions():
    """Score every open prediction whose horizon has elapsed and mark it closed.

    A prediction is due once at least ``p.horizon`` DailyQuote rows exist for
    its code strictly after ``p.date``. The realised return (in %) runs from
    the base price to the close of the ``horizon``-th such row.

    Hit rules:

    - "up" hits when ret > 0;
    - "down" hits when ret < 0;
    - any other direction hits when |ret| < 2.

    The base price is ``p.base_close``. When that is missing or zero, the close
    on ``p.date`` itself is used. If there is still no usable base price, or the
    final close is missing, the prediction is left open. Scoring it against a
    fabricated 0% return would skew the measured accuracy, and crashing would
    block every other prediction in the batch.

    Returns:
        {"ok": True, "closed": <number of predictions closed in this run>}
    """
    closed = 0
    with get_session() as s:
        opens = s.execute(select(Prediction).where(Prediction.status == "open")).scalars().all()
        for p in opens:
            future = s.execute(
                select(DailyQuote.close)
                .where(DailyQuote.code == p.code, DailyQuote.date > p.date)
                .order_by(DailyQuote.date)
                .limit(p.horizon)
            ).all()
            if len(future) < p.horizon:
                continue  # not enough trading days yet
            end_close = future[-1][0]
            if end_close is None:
                continue  # bad quote row: retry on a later run
            base = p.base_close
            if not base:
                base = s.execute(
                    select(DailyQuote.close).where(
                        DailyQuote.code == p.code, DailyQuote.date == p.date
                    )
                ).scalar()
            if not base:
                continue  # no usable base price; cannot score honestly
            # float() on both sides also copes with Decimal-typed price columns.
            ret = (float(end_close) / float(base) - 1) * 100
            if p.direction == "up":
                hit = ret > 0
            elif p.direction == "down":
                hit = ret < 0
            else:
                hit = abs(ret) < 2
            p.actual_ret = round(ret, 2)
            p.hit = bool(hit)
            p.status = "closed"
            closed += 1
        s.commit()
    return {"ok": True, "closed": closed}
def accuracy():
    """Summarise verified (closed) predictions.

    Returns a dict with:
        closed: number of closed predictions.
        open: number of open predictions (0 if the count query yields None).
        hit_rate: % of closed predictions whose ``hit`` is truthy, or None.
        avg_ret: mean ``actual_ret`` of closed predictions, or None.
        by_direction: per direction, sample count ``n`` and ``hit_rate`` in %.
        recent: up to 25 closed predictions, newest first (by date, then id).
    """
    with get_session() as s:
        finished = s.execute(select(Prediction).where(Prediction.status == "closed")).scalars().all()
        pending = s.execute(
            select(func.count()).select_from(Prediction).where(Prediction.status == "open")
        ).scalar()

        total = len(finished)
        if total:
            hit_rate = round(sum(1 for p in finished if p.hit) / total * 100, 1)
            avg_ret = round(sum(p.actual_ret for p in finished) / total, 2)
        else:
            hit_rate = None
            avg_ret = None

        # direction -> (sample count, hit count), in first-seen order.
        tallies = {}
        for p in finished:
            seen, hits = tallies.get(p.direction, (0, 0))
            tallies[p.direction] = (seen + 1, hits + (1 if p.hit else 0))
        by_direction = {
            direction: {"n": seen, "hit_rate": round(hits / seen * 100, 1)}
            for direction, (seen, hits) in tallies.items()
        }

        newest = sorted(finished, key=lambda p: (p.date, p.id), reverse=True)[:25]
        recent = [
            {
                "date": p.date.isoformat(),
                "code": p.code,
                "name": p.name,
                "direction": p.direction,
                "score": p.score,
                "confidence": p.confidence,
                "actual_ret": p.actual_ret,
                "hit": p.hit,
                "horizon": p.horizon,
            }
            for p in newest
        ]

    return {
        "closed": total,
        "open": pending or 0,
        "hit_rate": hit_rate,
        "avg_ret": avg_ret,
        "by_direction": by_direction,
        "recent": recent,
    }