claude强化功能
This commit is contained in:
542
backend/ai_chat.py
Normal file
542
backend/ai_chat.py
Normal file
@@ -0,0 +1,542 @@
|
||||
"""AI 对话式分析 — 自然语言交互的炒股助手。
|
||||
|
||||
功能:
|
||||
1. 自然语言选股
|
||||
2. 持仓诊断对话
|
||||
3. 策略建议
|
||||
4. 实时问答
|
||||
5. 上下文记忆(多轮对话)
|
||||
"""
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, date
|
||||
|
||||
import llm
|
||||
import smart_selector as selector
|
||||
import portfolio as pf
|
||||
import ai
|
||||
import sector_rotation as sector
|
||||
from db import get_session
|
||||
from models import StockMetric, Trade
|
||||
|
||||
# 会话上下文存储
|
||||
_SESSIONS = {} # {session_id: {"messages": [], "context": {}}}
|
||||
|
||||
|
||||
def get_or_create_session(session_id: str) -> Dict:
|
||||
"""获取或创建会话"""
|
||||
if session_id not in _SESSIONS:
|
||||
_SESSIONS[session_id] = {
|
||||
"messages": [],
|
||||
"context": {},
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
return _SESSIONS[session_id]
|
||||
|
||||
|
||||
def chat(session_id: str, user_message: str) -> Dict[str, Any]:
|
||||
"""AI对话主入口
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
AI回复
|
||||
"""
|
||||
if not llm.enabled():
|
||||
return {
|
||||
"ok": False,
|
||||
"msg": "大模型未配置,请在 backend/.env 中配置 LLM_API_KEY",
|
||||
"text": "抱歉,AI对话功能需要配置大模型。您可以:\n1. 配置 .env 中的 LLM_API_KEY\n2. 使用其他功能模块(选股、回测、板块分析等)"
|
||||
}
|
||||
|
||||
session = get_or_create_session(session_id)
|
||||
|
||||
# 添加用户消息到历史
|
||||
session["messages"].append({"role": "user", "content": user_message})
|
||||
|
||||
# 意图识别 + Function Calling
|
||||
try:
|
||||
response = _process_message(session, user_message)
|
||||
|
||||
# 添加助手回复到历史
|
||||
session["messages"].append({"role": "assistant", "content": response["text"]})
|
||||
|
||||
# 限制历史长度(保留最近20轮)
|
||||
if len(session["messages"]) > 40:
|
||||
session["messages"] = session["messages"][-40:]
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"msg": str(e),
|
||||
"text": f"处理消息时出错:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _process_message(session: Dict, message: str) -> Dict[str, Any]:
|
||||
"""处理用户消息,识别意图并调用相应功能"""
|
||||
|
||||
# 构建系统提示
|
||||
system_prompt = """你是Blackdata股票分析助手,擅长A股分析和投资建议。
|
||||
|
||||
你可以调用以下功能(通过JSON格式返回):
|
||||
|
||||
1. 选股功能
|
||||
格式:{"action": "select_stocks", "conditions": {"涨幅": ">10", "量比": ">2", ...}, "description": "..."}
|
||||
示例:"帮我找近期突破的科技股" -> 识别为选股需求
|
||||
|
||||
2. 持仓诊断
|
||||
格式:{"action": "diagnose_portfolio"}
|
||||
示例:"我的持仓有什么风险?"
|
||||
|
||||
3. 策略建议
|
||||
格式:{"action": "strategy_advice"}
|
||||
示例:"当前市场适合什么策略?"
|
||||
|
||||
4. 个股分析
|
||||
格式:{"action": "analyze_stock", "code": "600519"}
|
||||
示例:"分析一下贵州茅台"
|
||||
|
||||
5. 板块分析
|
||||
格式:{"action": "analyze_sector", "name": "半导体"}
|
||||
示例:"半导体板块怎么样?"
|
||||
|
||||
6. 普通对话
|
||||
格式:{"action": "chat", "text": "..."}
|
||||
示例:闲聊、问候等
|
||||
|
||||
请根据用户问题,先判断意图,然后:
|
||||
- 如果需要调用功能,返回JSON格式的action
|
||||
- 如果是普通对话,直接回答
|
||||
|
||||
重要:
|
||||
- 如果用户问题包含"找"、"选"、"筛选"等词,考虑选股功能
|
||||
- 如果问"我的持仓"、"风险",调用持仓诊断
|
||||
- 如果问"策略"、"怎么操作",给策略建议
|
||||
- 股票代码格式:6位数字
|
||||
"""
|
||||
|
||||
# 构建对话历史
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# 添加最近的对话历史(最多10轮)
|
||||
recent_messages = session["messages"][-20:] if len(session["messages"]) > 20 else session["messages"]
|
||||
messages.extend(recent_messages)
|
||||
|
||||
# 调用大模型
|
||||
try:
|
||||
response_text = llm.ask_with_messages(messages, temperature=0.7, max_tokens=1500)
|
||||
|
||||
# 尝试解析为JSON
|
||||
action = _parse_action(response_text)
|
||||
|
||||
if action:
|
||||
return _execute_action(action, session)
|
||||
else:
|
||||
# 纯文本回复
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "chat",
|
||||
"text": response_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"AI处理失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _parse_action(text: str) -> Optional[Dict]:
|
||||
"""解析AI回复中的action"""
|
||||
try:
|
||||
# 查找JSON块
|
||||
if "{" in text and "}" in text:
|
||||
start = text.find("{")
|
||||
end = text.rfind("}") + 1
|
||||
json_str = text[start:end]
|
||||
return json.loads(json_str)
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _execute_action(action: Dict, session: Dict) -> Dict[str, Any]:
|
||||
"""执行具体功能"""
|
||||
|
||||
action_type = action.get("action")
|
||||
|
||||
if action_type == "select_stocks":
|
||||
return _handle_select_stocks(action, session)
|
||||
|
||||
elif action_type == "diagnose_portfolio":
|
||||
return _handle_diagnose_portfolio(session)
|
||||
|
||||
elif action_type == "strategy_advice":
|
||||
return _handle_strategy_advice(session)
|
||||
|
||||
elif action_type == "analyze_stock":
|
||||
return _handle_analyze_stock(action, session)
|
||||
|
||||
elif action_type == "analyze_sector":
|
||||
return _handle_analyze_sector(action, session)
|
||||
|
||||
elif action_type == "chat":
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "chat",
|
||||
"text": action.get("text", "我是Blackdata AI助手,有什么可以帮你?")
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": "抱歉,我不太理解您的问题。您可以问我:\n- 帮我选股\n- 我的持仓怎么样\n- 给我策略建议\n- 分析某个股票或板块"
|
||||
}
|
||||
|
||||
|
||||
def _handle_select_stocks(action: Dict, session: Dict) -> Dict[str, Any]:
|
||||
"""处理选股请求"""
|
||||
|
||||
# 从自然语言提取条件
|
||||
description = action.get("description", "")
|
||||
conditions = action.get("conditions", {})
|
||||
|
||||
# 构建选股策略
|
||||
strategy = selector.Strategy("AI选股", description)
|
||||
|
||||
# 将条件转换为选股条件
|
||||
for field, op_value in conditions.items():
|
||||
field_map = {
|
||||
"涨幅": "pct",
|
||||
"5日涨幅": "ret5",
|
||||
"20日涨幅": "ret20",
|
||||
"量比": "vol_ratio",
|
||||
"成交额": "amount",
|
||||
"RSI": "rsi14",
|
||||
"价格": "close"
|
||||
}
|
||||
|
||||
if field in field_map:
|
||||
actual_field = field_map[field]
|
||||
|
||||
# 解析操作符和值
|
||||
if isinstance(op_value, str):
|
||||
if op_value.startswith(">"):
|
||||
op = ">"
|
||||
val = float(op_value[1:].strip())
|
||||
elif op_value.startswith("<"):
|
||||
op = "<"
|
||||
val = float(op_value[1:].strip())
|
||||
else:
|
||||
continue
|
||||
|
||||
strategy.add_condition(actual_field, op, val)
|
||||
|
||||
# 如果没有条件,添加默认条件
|
||||
if not strategy.conditions:
|
||||
strategy.add_condition("ret5", ">", 5)
|
||||
strategy.add_condition("vol_ratio", ">", 1.5)
|
||||
|
||||
# 执行选股
|
||||
result = selector.run_selector(strategy)
|
||||
|
||||
if not result["ok"]:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"选股失败:{result.get('msg', '未知错误')}"
|
||||
}
|
||||
|
||||
# 保存选股结果到上下文
|
||||
session["context"]["last_selection"] = result["results"][:10]
|
||||
|
||||
# 格式化回复
|
||||
stocks = result["results"][:10]
|
||||
if not stocks:
|
||||
text = "根据您的条件,暂时没有找到符合的股票。您可以:\n1. 放宽筛选条件\n2. 尝试其他板块\n3. 等待市场出现机会"
|
||||
else:
|
||||
text = f"为您找到 {result['count']} 只股票,以下是前10只:\n\n"
|
||||
for i, s in enumerate(stocks, 1):
|
||||
text += f"{i}. {s['name']}({s['code']})\n"
|
||||
text += f" 现价:{s['close']}元 涨跌:{s['pct']:+.2f}% 5日:{s['ret5']:+.2f}%\n"
|
||||
text += f" 量比:{s['vol_ratio']:.2f} 成交额:{s['amount']:.1f}亿\n\n"
|
||||
|
||||
text += "💡 您可以继续问我:\n- 分析某只股票(如\"分析第1只\")\n- 回测这个策略\n- 看看其他板块"
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "select_stocks",
|
||||
"text": text,
|
||||
"data": stocks
|
||||
}
|
||||
|
||||
|
||||
def _handle_diagnose_portfolio(session: Dict) -> Dict[str, Any]:
|
||||
"""处理持仓诊断"""
|
||||
|
||||
try:
|
||||
portfolio = pf.compute()
|
||||
|
||||
if not portfolio["holdings"]:
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "diagnose",
|
||||
"text": "您当前没有持仓。建议:\n1. 先在「交易日志」录入交易记录\n2. 或者问我\"帮我选股\"来寻找投资机会"
|
||||
}
|
||||
|
||||
summary = portfolio["summary"]
|
||||
holdings = portfolio["holdings"]
|
||||
|
||||
# 分析持仓
|
||||
total_unrealized = summary["unrealized"]
|
||||
win_rate = summary["win_rate"]
|
||||
|
||||
# 风险诊断
|
||||
risks = []
|
||||
|
||||
# 1. 浮亏检查
|
||||
losing_positions = [h for h in holdings if h["unrealized"] < 0]
|
||||
if len(losing_positions) > len(holdings) / 2:
|
||||
risks.append(f"⚠️ 超过一半的持仓处于浮亏状态({len(losing_positions)}/{len(holdings)}只)")
|
||||
|
||||
# 2. 集中度检查
|
||||
if len(holdings) < 3:
|
||||
risks.append("⚠️ 持仓过于集中,建议分散投资")
|
||||
|
||||
# 3. 胜率检查
|
||||
if win_rate < 40:
|
||||
risks.append(f"⚠️ 历史胜率较低({win_rate}%),建议反思选股策略")
|
||||
|
||||
# 构建回复
|
||||
text = f"📊 持仓诊断报告\n\n"
|
||||
text += f"持仓数量:{summary['positions']} 只\n"
|
||||
text += f"持仓市值:{summary['market_value']:.2f} 元\n"
|
||||
text += f"浮动盈亏:{total_unrealized:+.2f} 元\n"
|
||||
text += f"历史胜率:{win_rate}%\n\n"
|
||||
|
||||
if risks:
|
||||
text += "⚠️ 风险提示:\n"
|
||||
for risk in risks:
|
||||
text += f"{risk}\n"
|
||||
text += "\n"
|
||||
|
||||
# 前5大持仓
|
||||
text += "📈 前5大持仓:\n"
|
||||
for i, h in enumerate(holdings[:5], 1):
|
||||
pnl_sign = "+" if h["unrealized"] >= 0 else ""
|
||||
text += f"{i}. {h['name']} {pnl_sign}{h['unrealized_pct']:.2f}% {pnl_sign}{h['unrealized']:.0f}元\n"
|
||||
|
||||
text += "\n💡 建议:\n"
|
||||
if risks:
|
||||
text += "- 考虑止损浮亏较大的股票\n"
|
||||
text += "- 增加持仓分散度\n"
|
||||
else:
|
||||
text += "- 当前持仓状况良好,继续关注\n"
|
||||
text += "- 定期复盘,总结经验\n"
|
||||
|
||||
# 保存到上下文
|
||||
session["context"]["portfolio"] = holdings
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "diagnose",
|
||||
"text": text,
|
||||
"data": portfolio
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"持仓诊断失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _handle_strategy_advice(session: Dict) -> Dict[str, Any]:
|
||||
"""处理策略建议"""
|
||||
|
||||
try:
|
||||
# 获取市场情绪
|
||||
summary = sector.get_rotation_summary()
|
||||
|
||||
if not summary.get("ok"):
|
||||
return {
|
||||
"ok": False,
|
||||
"text": "暂时无法获取市场数据,请稍后再试"
|
||||
}
|
||||
|
||||
strongest = summary.get("strongest_sectors", [])
|
||||
weakest = summary.get("weakest_sectors", [])
|
||||
|
||||
# 构建策略建议
|
||||
text = "📋 当前市场策略建议\n\n"
|
||||
|
||||
text += "🔥 强势板块:\n"
|
||||
for s in strongest[:3]:
|
||||
text += f"- {s['name']} {s['return_10d']:+.2f}%\n"
|
||||
text += "\n"
|
||||
|
||||
text += "📉 弱势板块:\n"
|
||||
for s in weakest[:3]:
|
||||
text += f"- {s['name']} {s['return_10d']:+.2f}%\n"
|
||||
text += "\n"
|
||||
|
||||
# 策略建议
|
||||
avg_return = sum(s['return_10d'] for s in strongest[:3]) / 3 if strongest else 0
|
||||
|
||||
if avg_return > 10:
|
||||
text += "💡 策略建议:\n"
|
||||
text += "- 市场情绪较好,适合进攻型策略\n"
|
||||
text += "- 可关注强势板块的龙头股\n"
|
||||
text += "- 设置好止盈点,及时落袋为安\n"
|
||||
elif avg_return > 0:
|
||||
text += "💡 策略建议:\n"
|
||||
text += "- 市场震荡,适合波段操作\n"
|
||||
text += "- 追踪强势板块,低吸高抛\n"
|
||||
text += "- 控制仓位,分批建仓\n"
|
||||
else:
|
||||
text += "💡 策略建议:\n"
|
||||
text += "- 市场偏弱,以防守为主\n"
|
||||
text += "- 减仓观望,等待机会\n"
|
||||
text += "- 关注超跌板块的反弹机会\n"
|
||||
|
||||
text += "\n🎯 具体操作:\n"
|
||||
text += "- 可以问我\"帮我找[强势板块]的股票\"\n"
|
||||
text += "- 或\"分析[某个板块]\"\n"
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "strategy",
|
||||
"text": text,
|
||||
"data": summary
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"策略建议失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _handle_analyze_stock(action: Dict, session: Dict) -> Dict[str, Any]:
|
||||
"""处理个股分析"""
|
||||
|
||||
code = action.get("code", "").strip()
|
||||
|
||||
if not code:
|
||||
# 从上下文中获取
|
||||
last_selection = session["context"].get("last_selection", [])
|
||||
if last_selection:
|
||||
code = last_selection[0]["code"]
|
||||
else:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": "请指定股票代码,例如\"分析600519\""
|
||||
}
|
||||
|
||||
try:
|
||||
result = ai.diagnose(code)
|
||||
|
||||
if not result["ok"]:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"分析失败:{result.get('msg', '未知错误')}"
|
||||
}
|
||||
|
||||
# 格式化回复
|
||||
text = f"📊 {result['name']}({result['symbol']})AI诊断\n\n"
|
||||
text += f"综合评分:{result['total']}分\n"
|
||||
text += f"预测方向:{'看多' if result['direction'] == 'up' else ('看空' if result['direction'] == 'down' else '中性')}\n"
|
||||
text += f"置信度:{result['confidence']}%\n\n"
|
||||
|
||||
text += "📈 各维度评分:\n"
|
||||
for dim, score in result["scores"].items():
|
||||
text += f"- {dim}:{score}分\n"
|
||||
|
||||
text += f"\n💬 {result['text'][:300]}...\n"
|
||||
text += "\n💡 完整分析请在「AI分析 → 个股诊断」页面查看"
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "analyze_stock",
|
||||
"text": text,
|
||||
"data": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"分析失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _handle_analyze_sector(action: Dict, session: Dict) -> Dict[str, Any]:
|
||||
"""处理板块分析"""
|
||||
|
||||
sector_name = action.get("name", "").strip()
|
||||
|
||||
if not sector_name:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": "请指定板块名称,例如\"分析半导体板块\""
|
||||
}
|
||||
|
||||
try:
|
||||
result = sector.analyze_lifecycle(sector_name, days=60)
|
||||
|
||||
if not result["ok"]:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"分析失败:{result.get('msg', '未知错误')}"
|
||||
}
|
||||
|
||||
# 格式化回复
|
||||
text = f"📊 {result['sector']} 板块分析\n\n"
|
||||
text += f"生命周期:{result['phase']}\n"
|
||||
text += f"{result['description']}\n\n"
|
||||
|
||||
metrics = result["metrics"]
|
||||
text += f"📈 近期表现:\n"
|
||||
text += f"- 5日涨幅:{metrics['return_5d']:+.2f}%\n"
|
||||
text += f"- 20日涨幅:{metrics['return_20d']:+.2f}%\n"
|
||||
text += f"- 成交额变化:{metrics['amount_change']:+.2f}%\n\n"
|
||||
|
||||
# 龙头股
|
||||
leaders = sector.identify_leaders(sector_name, limit=5)
|
||||
if leaders["ok"] and leaders["leaders"]:
|
||||
text += "🏆 龙头股:\n"
|
||||
for i, l in enumerate(leaders["leaders"][:3], 1):
|
||||
text += f"{i}. {l['name']} {l['ret20']:+.2f}%\n"
|
||||
|
||||
text += "\n💡 您可以继续问:\n"
|
||||
text += f"- 帮我找{sector_name}板块的股票\n"
|
||||
text += f"- {sector_name}龙头股有哪些\n"
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "analyze_sector",
|
||||
"text": text,
|
||||
"data": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"text": f"分析失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def clear_session(session_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in _SESSIONS:
|
||||
del _SESSIONS[session_id]
|
||||
|
||||
|
||||
def get_session_history(session_id: str) -> List[Dict]:
|
||||
"""获取会话历史"""
|
||||
session = _SESSIONS.get(session_id)
|
||||
if session:
|
||||
return session["messages"]
|
||||
return []
|
||||
Reference in New Issue
Block a user