Files
stock_cursor_v0/backend/sector_rotation.py
2026-06-14 11:54:45 +08:00

484 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""板块轮动分析 — 追踪板块强弱、资金流向、生命周期。
功能:
1. 板块强弱排名趋势
2. 资金流向分析
3. 板块生命周期判断
4. 龙头股识别
5. 板块联动性分析
"""
import datetime as dt
from typing import Dict, List, Any, Optional
import numpy as np
from sqlalchemy import select, func, and_
from db import get_session
from models import SectorDaily, FundFlowDaily, DailyQuote, StockMetric
def get_sector_trend(days: int = 20, top_n: int = 15) -> Dict[str, Any]:
"""获取板块强弱趋势
Args:
days: 统计天数
top_n: 返回前N个板块
Returns:
板块趋势数据
"""
with get_session() as s:
# 获取最近N天的日期
latest_date = s.execute(select(func.max(SectorDaily.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无板块数据"}
start_date = latest_date - dt.timedelta(days=days)
# 查询板块数据
rows = s.execute(
select(SectorDaily)
.where(SectorDaily.date >= start_date)
.order_by(SectorDaily.date, SectorDaily.name)
).scalars().all()
if not rows:
return {"ok": False, "msg": "数据不足"}
# 按板块聚合
sector_data = {}
for row in rows:
if row.name not in sector_data:
sector_data[row.name] = {
"name": row.name,
"dates": [],
"pcts": [],
"amounts": []
}
sector_data[row.name]["dates"].append(row.date.isoformat())
sector_data[row.name]["pcts"].append(float(row.pct))
sector_data[row.name]["amounts"].append(float(row.amount))
# 计算累计涨跌幅和平均成交额
sector_stats = []
for name, data in sector_data.items():
pcts = data["pcts"]
amounts = data["amounts"]
# 累计收益(复利)
cumulative = 1.0
for p in pcts:
cumulative *= (1 + p / 100)
cumulative_return = (cumulative - 1) * 100
# 近5日、10日、20日收益
returns = {
"5d": sum(pcts[-5:]) if len(pcts) >= 5 else 0,
"10d": sum(pcts[-10:]) if len(pcts) >= 10 else 0,
"20d": cumulative_return
}
# 平均成交额
avg_amount = np.mean(amounts) if amounts else 0
# 波动率(标准差)
volatility = np.std(pcts) if len(pcts) > 1 else 0
sector_stats.append({
"name": name,
"returns": returns,
"avg_amount": round(avg_amount, 2),
"volatility": round(volatility, 2),
"dates": data["dates"],
"pcts": [round(p, 2) for p in pcts]
})
# 按20日收益排序
sector_stats.sort(key=lambda x: x["returns"]["20d"], reverse=True)
return {
"ok": True,
"date": latest_date.isoformat(),
"days": days,
"sectors": sector_stats[:top_n]
}
def analyze_fund_flow(days: int = 5) -> Dict[str, Any]:
"""分析资金流向(板块间流动)
Args:
days: 分析天数
Returns:
资金流向数据(桑基图格式)
"""
with get_session() as s:
latest_date = s.execute(select(func.max(FundFlowDaily.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无资金流数据"}
start_date = latest_date - dt.timedelta(days=days)
# 查询资金流数据
rows = s.execute(
select(FundFlowDaily)
.where(FundFlowDaily.date >= start_date)
.order_by(FundFlowDaily.date, FundFlowDaily.name)
).scalars().all()
if not rows:
return {"ok": False, "msg": "数据不足"}
# 按板块聚合净流入
flow_data = {}
for row in rows:
if row.name not in flow_data:
flow_data[row.name] = 0
flow_data[row.name] += float(row.net)
# 分类:流入 vs 流出
inflows = [(k, v) for k, v in flow_data.items() if v > 0]
outflows = [(k, abs(v)) for k, v in flow_data.items() if v < 0]
inflows.sort(key=lambda x: x[1], reverse=True)
outflows.sort(key=lambda x: x[1], reverse=True)
# 构造桑基图数据
nodes = []
links = []
# 流出节点(左侧)
for i, (name, amount) in enumerate(outflows[:8]):
nodes.append({"name": f"{name}(流出)"})
# 流向"资金池"
links.append({
"source": len(nodes) - 1,
"target": len(outflows[:8]), # 资金池索引
"value": round(amount, 2)
})
# 资金池(中间)
nodes.append({"name": "资金池"})
# 流入节点(右侧)
for i, (name, amount) in enumerate(inflows[:8]):
nodes.append({"name": f"{name}(流入)"})
# 从"资金池"流入
links.append({
"source": len(outflows[:8]), # 资金池索引
"target": len(nodes) - 1,
"value": round(amount, 2)
})
return {
"ok": True,
"date": latest_date.isoformat(),
"days": days,
"total_inflow": round(sum(v for _, v in inflows), 2),
"total_outflow": round(sum(v for _, v in outflows), 2),
"top_inflow": inflows[:8],
"top_outflow": outflows[:8],
"sankey": {
"nodes": nodes,
"links": links
}
}
def analyze_lifecycle(sector_name: str, days: int = 60) -> Dict[str, Any]:
"""分析板块生命周期
Args:
sector_name: 板块名称
days: 分析天数
Returns:
生命周期判断
"""
with get_session() as s:
latest_date = s.execute(select(func.max(SectorDaily.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无数据"}
start_date = latest_date - dt.timedelta(days=days)
rows = s.execute(
select(SectorDaily)
.where(
and_(
SectorDaily.name == sector_name,
SectorDaily.date >= start_date
)
)
.order_by(SectorDaily.date)
).scalars().all()
if len(rows) < 20:
return {"ok": False, "msg": "数据不足"}
# 提取数据
dates = [r.date.isoformat() for r in rows]
pcts = [float(r.pct) for r in rows]
amounts = [float(r.amount) for r in rows]
# 计算指标
# 1. 近期涨跌幅趋势
recent_5 = sum(pcts[-5:])
recent_10 = sum(pcts[-10:])
recent_20 = sum(pcts[-20:])
# 2. 成交额趋势
amount_5 = np.mean(amounts[-5:])
amount_20 = np.mean(amounts[-20:])
amount_change = (amount_5 / amount_20 - 1) * 100 if amount_20 > 0 else 0
# 3. 动量(价格变化加速度)
momentum = recent_5 - recent_10
# 生命周期判断
if recent_20 > 0 and momentum > 0 and amount_change > 20:
phase = "启动期"
description = "板块刚开始上涨,资金流入加速,可能是介入时机"
elif recent_20 > 5 and recent_10 > recent_20 / 2 and amount_change > 0:
phase = "加速期"
description = "板块持续上涨且加速,成交活跃,主升浪阶段"
elif recent_20 > 0 and momentum < 0:
phase = "衰退期"
description = "板块涨幅收窄或开始回调,资金开始流出,注意风险"
elif recent_20 < -5:
phase = "下跌期"
description = "板块持续下跌,避免介入"
else:
phase = "震荡期"
description = "板块横盘整理,方向不明"
return {
"ok": True,
"sector": sector_name,
"phase": phase,
"description": description,
"metrics": {
"return_5d": round(recent_5, 2),
"return_10d": round(recent_10, 2),
"return_20d": round(recent_20, 2),
"momentum": round(momentum, 2),
"amount_change": round(amount_change, 2)
},
"dates": dates,
"pcts": [round(p, 2) for p in pcts]
}
def identify_leaders(sector_name: str, days: int = 20, limit: int = 10) -> Dict[str, Any]:
"""识别板块龙头股
Args:
sector_name: 板块名称
days: 统计天数
limit: 返回数量
Returns:
龙头股列表
"""
# 注意:需要股票-板块映射表,这里简化为通过名称匹配
# 实际应该有 stock_sector 映射表
with get_session() as s:
# 获取最近N天表现最好的股票
latest_date = s.execute(select(func.max(StockMetric.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无股票数据"}
# 查询高涨幅、高成交额股票
rows = s.execute(
select(StockMetric)
.where(
and_(
StockMetric.date == latest_date,
StockMetric.ret20 > 0,
StockMetric.amount > 5 # 成交额 > 5亿
)
)
.order_by(
StockMetric.ret20.desc(),
StockMetric.amount.desc()
)
.limit(limit * 3) # 多取一些,后续筛选
).scalars().all()
# 简化:根据名称关键词匹配板块(实际应该查询映射表)
sector_keywords = {
"半导体": ["芯片", "半导体", "集成电路"],
"新能源": ["新能源", "锂电", "光伏", "储能"],
"医药": ["医药", "生物", "医疗", "药业"],
"白酒": ["", "茅台", "五粮液"],
"军工": ["军工", "航天", "航空", "兵器"],
"AI": ["人工智能", "AI", "算力", "云计算"],
}
keywords = sector_keywords.get(sector_name, [sector_name])
leaders = []
for row in rows:
if any(kw in row.name for kw in keywords):
leaders.append({
"code": row.code,
"name": row.name,
"close": round(row.close, 2),
"pct": round(row.pct, 2),
"ret5": round(row.ret5, 2),
"ret20": round(row.ret20, 2),
"amount": round(row.amount, 2),
"vol_ratio": round(row.vol_ratio, 2)
})
if len(leaders) >= limit:
break
return {
"ok": True,
"sector": sector_name,
"date": latest_date.isoformat(),
"leaders": leaders
}
def analyze_correlation(days: int = 60, top_n: int = 20) -> Dict[str, Any]:
"""板块联动性分析(相关系数矩阵)
Args:
days: 计算天数
top_n: 分析前N个板块
Returns:
相关系数矩阵(热力图数据)
"""
with get_session() as s:
latest_date = s.execute(select(func.max(SectorDaily.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无数据"}
start_date = latest_date - dt.timedelta(days=days)
rows = s.execute(
select(SectorDaily)
.where(SectorDaily.date >= start_date)
.order_by(SectorDaily.date, SectorDaily.name)
).scalars().all()
if not rows:
return {"ok": False, "msg": "数据不足"}
# 按板块聚合涨跌幅
sector_returns = {}
for row in rows:
if row.name not in sector_returns:
sector_returns[row.name] = []
sector_returns[row.name].append(float(row.pct))
# 筛选数据完整的板块
valid_sectors = {k: v for k, v in sector_returns.items() if len(v) >= days * 0.8}
if len(valid_sectors) < 5:
return {"ok": False, "msg": "有效板块不足"}
# 选择前N个板块按最近涨幅
sector_list = []
for name, rets in valid_sectors.items():
recent_return = sum(rets[-min(10, len(rets)):])
sector_list.append((name, recent_return, rets))
sector_list.sort(key=lambda x: x[1], reverse=True)
selected = sector_list[:top_n]
# 计算相关系数矩阵
names = [s[0] for s in selected]
returns_matrix = np.array([s[2][:days] for s in selected])
# 填充短数据用0
max_len = max(len(r) for r in returns_matrix)
padded = []
for r in returns_matrix:
if len(r) < max_len:
r = list(r) + [0] * (max_len - len(r))
padded.append(r[:max_len])
returns_matrix = np.array(padded)
# 计算相关系数
corr_matrix = np.corrcoef(returns_matrix)
# 转换为热力图数据
heatmap_data = []
for i in range(len(names)):
for j in range(len(names)):
heatmap_data.append({
"x": j,
"y": i,
"value": round(float(corr_matrix[i][j]), 3)
})
# 找出高度相关的板块对(相关系数 > 0.7
high_corr = []
for i in range(len(names)):
for j in range(i + 1, len(names)):
corr = float(corr_matrix[i][j])
if corr > 0.7:
high_corr.append({
"sector1": names[i],
"sector2": names[j],
"correlation": round(corr, 3)
})
high_corr.sort(key=lambda x: x["correlation"], reverse=True)
return {
"ok": True,
"days": days,
"sectors": names,
"matrix": corr_matrix.tolist(),
"heatmap": heatmap_data,
"high_correlation": high_corr[:10]
}
def get_rotation_summary() -> Dict[str, Any]:
"""获取板块轮动综合摘要
Returns:
轮动摘要
"""
# 获取最强和最弱板块
trend = get_sector_trend(days=10, top_n=20)
if not trend.get("ok"):
return {"ok": False, "msg": "数据不足"}
sectors = trend["sectors"]
strongest = sectors[:3]
weakest = sectors[-3:]
# 资金流向
flow = analyze_fund_flow(days=5)
summary = {
"ok": True,
"date": trend["date"],
"strongest_sectors": [
{
"name": s["name"],
"return_10d": s["returns"]["10d"]
} for s in strongest
],
"weakest_sectors": [
{
"name": s["name"],
"return_10d": s["returns"]["10d"]
} for s in weakest
],
"fund_flow": {
"top_inflow": flow.get("top_inflow", [])[:3] if flow.get("ok") else [],
"top_outflow": flow.get("top_outflow", [])[:3] if flow.get("ok") else []
}
}
return summary