Files
stock_cursor_v0/backend/ai_chat.py
2026-06-14 11:54:45 +08:00

543 lines
17 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 对话式分析 — 自然语言交互的炒股助手。
功能:
1. 自然语言选股
2. 持仓诊断对话
3. 策略建议
4. 实时问答
5. 上下文记忆(多轮对话)
"""
import json
from typing import List, Dict, Any, Optional
from datetime import datetime, date
import llm
import smart_selector as selector
import portfolio as pf
import ai
import sector_rotation as sector
from db import get_session
from models import StockMetric, Trade
# 会话上下文存储
_SESSIONS = {} # {session_id: {"messages": [], "context": {}}}
def get_or_create_session(session_id: str) -> Dict:
"""获取或创建会话"""
if session_id not in _SESSIONS:
_SESSIONS[session_id] = {
"messages": [],
"context": {},
"created_at": datetime.now().isoformat()
}
return _SESSIONS[session_id]
def chat(session_id: str, user_message: str) -> Dict[str, Any]:
"""AI对话主入口
Args:
session_id: 会话ID
user_message: 用户消息
Returns:
AI回复
"""
if not llm.enabled():
return {
"ok": False,
"msg": "大模型未配置,请在 backend/.env 中配置 LLM_API_KEY",
"text": "抱歉AI对话功能需要配置大模型。您可以\n1. 配置 .env 中的 LLM_API_KEY\n2. 使用其他功能模块(选股、回测、板块分析等)"
}
session = get_or_create_session(session_id)
# 添加用户消息到历史
session["messages"].append({"role": "user", "content": user_message})
# 意图识别 + Function Calling
try:
response = _process_message(session, user_message)
# 添加助手回复到历史
session["messages"].append({"role": "assistant", "content": response["text"]})
# 限制历史长度保留最近20轮
if len(session["messages"]) > 40:
session["messages"] = session["messages"][-40:]
return response
except Exception as e:
return {
"ok": False,
"msg": str(e),
"text": f"处理消息时出错:{str(e)}"
}
def _process_message(session: Dict, message: str) -> Dict[str, Any]:
"""处理用户消息,识别意图并调用相应功能"""
# 构建系统提示
system_prompt = """你是Blackdata股票分析助手擅长A股分析和投资建议。
你可以调用以下功能通过JSON格式返回
1. 选股功能
格式:{"action": "select_stocks", "conditions": {"涨幅": ">10", "量比": ">2", ...}, "description": "..."}
示例:"帮我找近期突破的科技股" -> 识别为选股需求
2. 持仓诊断
格式:{"action": "diagnose_portfolio"}
示例:"我的持仓有什么风险?"
3. 策略建议
格式:{"action": "strategy_advice"}
示例:"当前市场适合什么策略?"
4. 个股分析
格式:{"action": "analyze_stock", "code": "600519"}
示例:"分析一下贵州茅台"
5. 板块分析
格式:{"action": "analyze_sector", "name": "半导体"}
示例:"半导体板块怎么样?"
6. 普通对话
格式:{"action": "chat", "text": "..."}
示例:闲聊、问候等
请根据用户问题,先判断意图,然后:
- 如果需要调用功能返回JSON格式的action
- 如果是普通对话,直接回答
重要:
- 如果用户问题包含"""""筛选"等词,考虑选股功能
- 如果问"我的持仓""风险",调用持仓诊断
- 如果问"策略""怎么操作",给策略建议
- 股票代码格式6位数字
"""
# 构建对话历史
messages = [{"role": "system", "content": system_prompt}]
# 添加最近的对话历史最多10轮
recent_messages = session["messages"][-20:] if len(session["messages"]) > 20 else session["messages"]
messages.extend(recent_messages)
# 调用大模型
try:
response_text = llm.ask_with_messages(messages, temperature=0.7, max_tokens=1500)
# 尝试解析为JSON
action = _parse_action(response_text)
if action:
return _execute_action(action, session)
else:
# 纯文本回复
return {
"ok": True,
"type": "chat",
"text": response_text
}
except Exception as e:
return {
"ok": False,
"text": f"AI处理失败{str(e)}"
}
def _parse_action(text: str) -> Optional[Dict]:
"""解析AI回复中的action"""
try:
# 查找JSON块
if "{" in text and "}" in text:
start = text.find("{")
end = text.rfind("}") + 1
json_str = text[start:end]
return json.loads(json_str)
except:
pass
return None
def _execute_action(action: Dict, session: Dict) -> Dict[str, Any]:
"""执行具体功能"""
action_type = action.get("action")
if action_type == "select_stocks":
return _handle_select_stocks(action, session)
elif action_type == "diagnose_portfolio":
return _handle_diagnose_portfolio(session)
elif action_type == "strategy_advice":
return _handle_strategy_advice(session)
elif action_type == "analyze_stock":
return _handle_analyze_stock(action, session)
elif action_type == "analyze_sector":
return _handle_analyze_sector(action, session)
elif action_type == "chat":
return {
"ok": True,
"type": "chat",
"text": action.get("text", "我是Blackdata AI助手有什么可以帮你")
}
else:
return {
"ok": False,
"text": "抱歉,我不太理解您的问题。您可以问我:\n- 帮我选股\n- 我的持仓怎么样\n- 给我策略建议\n- 分析某个股票或板块"
}
def _handle_select_stocks(action: Dict, session: Dict) -> Dict[str, Any]:
"""处理选股请求"""
# 从自然语言提取条件
description = action.get("description", "")
conditions = action.get("conditions", {})
# 构建选股策略
strategy = selector.Strategy("AI选股", description)
# 将条件转换为选股条件
for field, op_value in conditions.items():
field_map = {
"涨幅": "pct",
"5日涨幅": "ret5",
"20日涨幅": "ret20",
"量比": "vol_ratio",
"成交额": "amount",
"RSI": "rsi14",
"价格": "close"
}
if field in field_map:
actual_field = field_map[field]
# 解析操作符和值
if isinstance(op_value, str):
if op_value.startswith(">"):
op = ">"
val = float(op_value[1:].strip())
elif op_value.startswith("<"):
op = "<"
val = float(op_value[1:].strip())
else:
continue
strategy.add_condition(actual_field, op, val)
# 如果没有条件,添加默认条件
if not strategy.conditions:
strategy.add_condition("ret5", ">", 5)
strategy.add_condition("vol_ratio", ">", 1.5)
# 执行选股
result = selector.run_selector(strategy)
if not result["ok"]:
return {
"ok": False,
"text": f"选股失败:{result.get('msg', '未知错误')}"
}
# 保存选股结果到上下文
session["context"]["last_selection"] = result["results"][:10]
# 格式化回复
stocks = result["results"][:10]
if not stocks:
text = "根据您的条件,暂时没有找到符合的股票。您可以:\n1. 放宽筛选条件\n2. 尝试其他板块\n3. 等待市场出现机会"
else:
text = f"为您找到 {result['count']} 只股票以下是前10只\n\n"
for i, s in enumerate(stocks, 1):
text += f"{i}. {s['name']}{s['code']}\n"
text += f" 现价:{s['close']}元 涨跌:{s['pct']:+.2f}% 5日{s['ret5']:+.2f}%\n"
text += f" 量比:{s['vol_ratio']:.2f} 成交额:{s['amount']:.1f}亿\n\n"
text += "💡 您可以继续问我:\n- 分析某只股票(如\"分析第1只\"\n- 回测这个策略\n- 看看其他板块"
return {
"ok": True,
"type": "select_stocks",
"text": text,
"data": stocks
}
def _handle_diagnose_portfolio(session: Dict) -> Dict[str, Any]:
"""处理持仓诊断"""
try:
portfolio = pf.compute()
if not portfolio["holdings"]:
return {
"ok": True,
"type": "diagnose",
"text": "您当前没有持仓。建议:\n1. 先在「交易日志」录入交易记录\n2. 或者问我\"帮我选股\"来寻找投资机会"
}
summary = portfolio["summary"]
holdings = portfolio["holdings"]
# 分析持仓
total_unrealized = summary["unrealized"]
win_rate = summary["win_rate"]
# 风险诊断
risks = []
# 1. 浮亏检查
losing_positions = [h for h in holdings if h["unrealized"] < 0]
if len(losing_positions) > len(holdings) / 2:
risks.append(f"⚠️ 超过一半的持仓处于浮亏状态({len(losing_positions)}/{len(holdings)}只)")
# 2. 集中度检查
if len(holdings) < 3:
risks.append("⚠️ 持仓过于集中,建议分散投资")
# 3. 胜率检查
if win_rate < 40:
risks.append(f"⚠️ 历史胜率较低({win_rate}%),建议反思选股策略")
# 构建回复
text = f"📊 持仓诊断报告\n\n"
text += f"持仓数量:{summary['positions']}\n"
text += f"持仓市值:{summary['market_value']:.2f}\n"
text += f"浮动盈亏:{total_unrealized:+.2f}\n"
text += f"历史胜率:{win_rate}%\n\n"
if risks:
text += "⚠️ 风险提示:\n"
for risk in risks:
text += f"{risk}\n"
text += "\n"
# 前5大持仓
text += "📈 前5大持仓\n"
for i, h in enumerate(holdings[:5], 1):
pnl_sign = "+" if h["unrealized"] >= 0 else ""
text += f"{i}. {h['name']} {pnl_sign}{h['unrealized_pct']:.2f}% {pnl_sign}{h['unrealized']:.0f}\n"
text += "\n💡 建议:\n"
if risks:
text += "- 考虑止损浮亏较大的股票\n"
text += "- 增加持仓分散度\n"
else:
text += "- 当前持仓状况良好,继续关注\n"
text += "- 定期复盘,总结经验\n"
# 保存到上下文
session["context"]["portfolio"] = holdings
return {
"ok": True,
"type": "diagnose",
"text": text,
"data": portfolio
}
except Exception as e:
return {
"ok": False,
"text": f"持仓诊断失败:{str(e)}"
}
def _handle_strategy_advice(session: Dict) -> Dict[str, Any]:
"""处理策略建议"""
try:
# 获取市场情绪
summary = sector.get_rotation_summary()
if not summary.get("ok"):
return {
"ok": False,
"text": "暂时无法获取市场数据,请稍后再试"
}
strongest = summary.get("strongest_sectors", [])
weakest = summary.get("weakest_sectors", [])
# 构建策略建议
text = "📋 当前市场策略建议\n\n"
text += "🔥 强势板块:\n"
for s in strongest[:3]:
text += f"- {s['name']} {s['return_10d']:+.2f}%\n"
text += "\n"
text += "📉 弱势板块:\n"
for s in weakest[:3]:
text += f"- {s['name']} {s['return_10d']:+.2f}%\n"
text += "\n"
# 策略建议
avg_return = sum(s['return_10d'] for s in strongest[:3]) / 3 if strongest else 0
if avg_return > 10:
text += "💡 策略建议:\n"
text += "- 市场情绪较好,适合进攻型策略\n"
text += "- 可关注强势板块的龙头股\n"
text += "- 设置好止盈点,及时落袋为安\n"
elif avg_return > 0:
text += "💡 策略建议:\n"
text += "- 市场震荡,适合波段操作\n"
text += "- 追踪强势板块,低吸高抛\n"
text += "- 控制仓位,分批建仓\n"
else:
text += "💡 策略建议:\n"
text += "- 市场偏弱,以防守为主\n"
text += "- 减仓观望,等待机会\n"
text += "- 关注超跌板块的反弹机会\n"
text += "\n🎯 具体操作:\n"
text += "- 可以问我\"帮我找[强势板块]的股票\"\n"
text += "- 或\"分析[某个板块]\"\n"
return {
"ok": True,
"type": "strategy",
"text": text,
"data": summary
}
except Exception as e:
return {
"ok": False,
"text": f"策略建议失败:{str(e)}"
}
def _handle_analyze_stock(action: Dict, session: Dict) -> Dict[str, Any]:
"""处理个股分析"""
code = action.get("code", "").strip()
if not code:
# 从上下文中获取
last_selection = session["context"].get("last_selection", [])
if last_selection:
code = last_selection[0]["code"]
else:
return {
"ok": False,
"text": "请指定股票代码,例如\"分析600519\""
}
try:
result = ai.diagnose(code)
if not result["ok"]:
return {
"ok": False,
"text": f"分析失败:{result.get('msg', '未知错误')}"
}
# 格式化回复
text = f"📊 {result['name']}{result['symbol']}AI诊断\n\n"
text += f"综合评分:{result['total']}\n"
text += f"预测方向:{'看多' if result['direction'] == 'up' else ('看空' if result['direction'] == 'down' else '中性')}\n"
text += f"置信度:{result['confidence']}%\n\n"
text += "📈 各维度评分:\n"
for dim, score in result["scores"].items():
text += f"- {dim}{score}\n"
text += f"\n💬 {result['text'][:300]}...\n"
text += "\n💡 完整分析请在「AI分析 → 个股诊断」页面查看"
return {
"ok": True,
"type": "analyze_stock",
"text": text,
"data": result
}
except Exception as e:
return {
"ok": False,
"text": f"分析失败:{str(e)}"
}
def _handle_analyze_sector(action: Dict, session: Dict) -> Dict[str, Any]:
"""处理板块分析"""
sector_name = action.get("name", "").strip()
if not sector_name:
return {
"ok": False,
"text": "请指定板块名称,例如\"分析半导体板块\""
}
try:
result = sector.analyze_lifecycle(sector_name, days=60)
if not result["ok"]:
return {
"ok": False,
"text": f"分析失败:{result.get('msg', '未知错误')}"
}
# 格式化回复
text = f"📊 {result['sector']} 板块分析\n\n"
text += f"生命周期:{result['phase']}\n"
text += f"{result['description']}\n\n"
metrics = result["metrics"]
text += f"📈 近期表现:\n"
text += f"- 5日涨幅{metrics['return_5d']:+.2f}%\n"
text += f"- 20日涨幅{metrics['return_20d']:+.2f}%\n"
text += f"- 成交额变化:{metrics['amount_change']:+.2f}%\n\n"
# 龙头股
leaders = sector.identify_leaders(sector_name, limit=5)
if leaders["ok"] and leaders["leaders"]:
text += "🏆 龙头股:\n"
for i, l in enumerate(leaders["leaders"][:3], 1):
text += f"{i}. {l['name']} {l['ret20']:+.2f}%\n"
text += "\n💡 您可以继续问:\n"
text += f"- 帮我找{sector_name}板块的股票\n"
text += f"- {sector_name}龙头股有哪些\n"
return {
"ok": True,
"type": "analyze_sector",
"text": text,
"data": result
}
except Exception as e:
return {
"ok": False,
"text": f"分析失败:{str(e)}"
}
def clear_session(session_id: str):
"""清空会话"""
if session_id in _SESSIONS:
del _SESSIONS[session_id]
def get_session_history(session_id: str) -> List[Dict]:
"""获取会话历史"""
session = _SESSIONS.get(session_id)
if session:
return session["messages"]
return []