From 964c17c2000ee18d5f9e727e373ce24fe89b5fff Mon Sep 17 00:00:00 2001 From: Lukas Date: Mon, 15 Jun 2026 01:26:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=9F=E8=83=BD=E7=BB=86=E8=8A=82=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 336 ++++++----- backend/CHECKLIST.md | 237 ++++++++ backend/ENV_CONFIG.md | 118 ++++ backend/UPGRADE_GUIDE.md | 287 +++++++++ backend/ai.py | 170 ++++++ backend/akshare_service.py | 142 ++++- backend/auth.py | 88 +++ backend/cli.py | 4 + backend/config.py | 16 + backend/data_manager.py | 398 +++++++++++++ backend/exceptions.py | 73 +++ backend/ingest.py | 22 +- backend/init_auth.py | 21 + backend/install.sh | 106 ++++ backend/limit_analysis.py | 235 +++++++- backend/main.py | 518 ++++++++++++++++- backend/models.py | 82 ++- backend/paper_trading.py | 222 +++++++ backend/position_cost.py | 347 +++++++++++ backend/redis_cache.py | 88 +++ backend/requirements.txt | 4 + backend/scheduler.py | 16 +- backend/test_core_features.py | 179 ++++++ backend/trade_calendar.py | 338 +++++++++++ backend/watchlist_manager.py | 345 +++++++++++ prototype/app.js | 832 ++++++++++++++++++++++++++- prototype/index.html | 6 +- prototype/paper-trading.js | 122 ++++ prototype/style.css | 37 ++ 三大核心功能实现总结.md | 464 +++++++++++++++ 功能实现/1_自选股分组管理使用说明.md | 435 ++++++++++++++ 功能实现/2_持仓成本可视化使用说明.md | 528 +++++++++++++++++ 实施完成报告.md | 384 +++++++++++++ 33 files changed, 6990 insertions(+), 210 deletions(-) create mode 100644 backend/CHECKLIST.md create mode 100644 backend/ENV_CONFIG.md create mode 100644 backend/UPGRADE_GUIDE.md create mode 100644 backend/auth.py create mode 100644 backend/data_manager.py create mode 100644 backend/exceptions.py create mode 100644 backend/init_auth.py create mode 100644 backend/install.sh create mode 100644 backend/paper_trading.py create mode 100644 backend/position_cost.py create mode 100644 backend/redis_cache.py create mode 100644 backend/test_core_features.py create mode 100644 backend/trade_calendar.py create mode 100644 backend/watchlist_manager.py create mode 100644 prototype/paper-trading.js create mode 100644 三大核心功能实现总结.md create mode 100644 功能实现/1_自选股分组管理使用说明.md create mode 100644 功能实现/2_持仓成本可视化使用说明.md create mode 100644 实施完成报告.md diff --git a/README.md b/README.md index 08b2187..06476f6 100644 --- a/README.md +++ b/README.md @@ -1,123 +1,141 @@ # Blackdata StockTerminal -个人/小团队 A 股分析·复盘·智能专业分析系统。后端提供行情、回测、AI 诊断与定时任务;前端为纯 HTML + ECharts 原型界面,由 FastAPI 统一托管。 +个人 / 小团队 A 股分析·复盘·智能辅助系统。 -## 功能概览 +后端基于 FastAPI,提供行情、回测、AI 诊断与定时任务;前端为纯 HTML + ECharts 原型,由 FastAPI 统一托管,浏览器直接访问即可使用。 -| 模块 | 能力 | +--- + +## 功能模块 + +| 模块 | 说明 | |---|---| -| **大盘行情** | 三大指数、情绪温度计、板块云图、热股榜、龙虎榜、涨跌停统计 | -| **盘中监控** | 异动雷达(快速拉升/放量突破/涨停打开/连板追踪/大单异动)、实时扫描与推送 | -| **自选股** | 自选列表、分组管理、内置 8 种策略选股、多因子条件过滤 | -| **智能选股** | 可视化条件组合器、选股策略保存/分享、选股结果回测验证、条件预警集成 | -| **复盘中心** | 每日复盘(板块/资金/龙虎榜)、AI 七段式日报、个股 K 线回放(MA 买卖点标注) | -| **策略回测** | MA 交叉/多因子策略回测、参数优化网格搜索、策略对比(并排净值曲线)、交易明细导出 | -| **板块轮动** | 板块强弱趋势、资金流向桑基图、龙头股识别、生命周期判断、板块联动性分析 | -| **AI 分析** | 个股诊断(6 维证据链)、AI 对话式分析、信号历史胜率、预测留痕与准确率核验 | -| **组合交易** | 持仓 P&L、资金曲线、交易日志(理由/情绪标签)、持仓归因分析(选股/择时/运气分解) | -| **智能预警** | 价格/涨跌幅/量能/技术信号规则、选股策略预警、多通道推送(邮件/微信/企微)、触发记录 | -| **资讯中心** | 财经快讯、AI 情绪判断与摘要、自选股相关资讯、关联个股分析 | -| **社区情绪** | 热帖采集(东方财富/雪球)、情绪指数计算、热议股票排行、关键词云图、情绪与股价相关性 | -| **事件驱动** | 财报发布前后规律、高管增减持跟踪、限售解禁影响、行业政策事件库、事件驱动选股 | -| **财报解读** | 关键指标趋势、AI 财报摘要、同行对比、财报异常预警、发布日历、排行榜 | -| **涨跌停分析** | 涨停/跌停股票追踪、连板股监控、炸板率统计、涨停敢死队排行 | -| **数据中台** | 数据入库状态、任务日志、全市场历史回填、定时调度监控 | +| 大盘行情 | 三大指数、情绪温度计、板块云图、热股榜、龙虎榜、涨跌停统计 | +| 盘中监控 | 异动雷达(快速拉升 / 放量突破 / 涨停打开 / 连板追踪 / 大单异动) | +| 自选股 | 自选列表与分组、8 种内置策略选股、多因子条件过滤 | +| 智能选股 | 可视化条件组合器、策略保存 / 分享、选股回测验证、条件预警 | +| 复盘中心 | 每日复盘(板块 / 资金 / 龙虎榜)、AI 七段式日报、个股 K 线回放 | +| 策略回测 | MA 交叉 / 多因子回测、参数优化网格搜索、策略对比、交易明细导出 | +| 板块轮动 | 板块强弱趋势、资金流向桑基图、龙头股识别、生命周期判断 | +| AI 分析 | 个股 6 维诊断、AI 对话式分析、信号胜率历史、预测准确率核验 | +| 组合交易 | 持仓 P&L、资金曲线、交易日志(理由 / 情绪标签)、持仓归因分析 | +| 智能预警 | 价格 / 量能 / 技术信号规则预警、多通道推送(邮件 / 微信 / 企微) | +| 资讯中心 | 财经快讯、AI 情绪摘要、自选股关联资讯 | +| 社区情绪 | 东方财富 / 雪球热帖采集、情绪指数、热议股排行、情绪与股价相关性 | +| 事件驱动 | 财报前后规律、高管增减持、限售解禁、政策事件库、事件驱动选股 | +| 财报解读 | 关键指标趋势、AI 财报摘要、同行对比、异常预警、发布日历 | +| 涨跌停分析 | 涨停 / 跌停追踪、连板监控、炸板率统计、涨停敢死队排行 | +| 数据中台 | 数据入库状态、任务日志、全市场历史回填、定时调度监控 | -更完整的架构说明见 [架构总结.md](./架构总结.md)。 +--- ## 技术栈 -- **前端**:HTML + CSS + 原生 JS,ECharts 5(CDN) -- **后端**:Python 3.12 · FastAPI · uvicorn -- **数据库**:PostgreSQL · SQLAlchemy 2.0 -- **数据源**:AkShare(行情/情绪/资讯),Sina 实时报价 -- **调度**:APScheduler -- **AI**:OpenAI 兼容接口(DeepSeek / 通义 / Kimi 等),无 Key 时规则降级 +| 层 | 技术 | +|---|---| +| 前端 | HTML + CSS + 原生 JS,ECharts 5(CDN) | +| 后端 | Python 3.12 · FastAPI · uvicorn | +| 数据库 | PostgreSQL 14+ · SQLAlchemy 2.0 · psycopg2 | +| 缓存 | Redis 5+(可选,自动降级到内存缓存) | +| 数据源 | AkShare(行情 / 情绪 / 资讯)· Sina 实时报价 | +| 调度 | APScheduler | +| AI | OpenAI 兼容接口(DeepSeek / 通义 / Kimi 等),无 Key 时规则降级 | +| 鉴权 | JWT Token + API Key 双模式 | + +--- ## 项目结构 ``` -stock_cs/ -├── backend/ # FastAPI 后端 -│ ├── main.py # API 入口 + 路由定义 -│ ├── cli.py # 建库/入库命令行工具 -│ ├── models.py # SQLAlchemy 数据模型 -│ ├── db.py # 数据库连接管理 -│ ├── config.py # 配置项 -│ ├── scheduler.py # APScheduler 定时任务 -│ ├── akshare_service.py # 数据源接口封装 -│ ├── ai.py # AI 分析核心 -│ ├── ai_chat.py # AI 对话式分析 -│ ├── llm.py # 大模型调用封装 -│ ├── backtest.py # 基础回测引擎 -│ ├── backtest_advanced.py # 增强回测(多因子/参数优化/策略对比) -│ ├── signals.py # 信号胜率统计 -│ ├── report.py # AI 复盘日报生成 -│ ├── portfolio.py # 组合与持仓计算 -│ ├── attribution_analysis.py # 持仓归因分析 -│ ├── alerts.py # 智能预警核心 -│ ├── notifier.py # 多通道推送 -│ ├── intraday_radar.py # 盘中异动雷达 -│ ├── sector_rotation.py # 板块轮动分析 -│ ├── smart_selector.py # 智能选股增强 -│ ├── sentiment_monitor.py # 社区情绪监控 -│ ├── event_driven.py # 事件驱动策略 -│ ├── financial_analysis.py # 财报深度解读 -│ ├── limit_analysis.py # 涨跌停分析 -│ ├── .env.example # 环境变量模板 -│ └── requirements.txt # Python 依赖 -├── prototype/ # 前端原型(HTML + JS + CSS) -├── 架构总结.md # 架构设计文档 -├── 功能架构.md # 功能模块详解 -├── 待优化.md # 已知问题与优化方向 -└── 功能扩展.md # 扩展功能建议 +. +├── backend/ +│ ├── main.py # FastAPI 入口 + 路由 +│ ├── cli.py # 建库 / 入库命令行工具 +│ ├── models.py # SQLAlchemy 数据模型 +│ ├── db.py # 数据库连接 +│ ├── config.py # 全局配置(读 .env) +│ ├── scheduler.py # APScheduler 定时任务 +│ ├── auth.py # JWT / API Key 鉴权 +│ ├── redis_cache.py # Redis 缓存层 +│ ├── exceptions.py # 统一异常处理 +│ ├── akshare_service.py # 数据源封装 +│ ├── ingest.py # 数据入库逻辑 +│ ├── data_manager.py # 数据管理 +│ ├── trade_calendar.py # 交易日历 +│ ├── llm.py # 大模型调用封装 +│ ├── ai.py # AI 个股诊断 +│ ├── ai_chat.py # AI 对话式分析 +│ ├── rag.py # RAG 知识增强 +│ ├── backtest.py # 基础回测引擎 +│ ├── backtest_advanced.py # 增强回测(多因子 / 参数优化 / 对比) +│ ├── signals.py # 信号胜率统计 +│ ├── report.py # AI 复盘日报生成 +│ ├── portfolio.py # 组合与持仓计算 +│ ├── position_cost.py # 持仓成本追踪 +│ ├── attribution_analysis.py # 持仓归因分析 +│ ├── alerts.py # 智能预警 +│ ├── notifier.py # 多通道推送 +│ ├── intraday_radar.py # 盘中异动雷达 +│ ├── sector_rotation.py # 板块轮动分析 +│ ├── smart_selector.py # 智能选股 +│ ├── sentiment_monitor.py # 社区情绪监控 +│ ├── event_driven.py # 事件驱动策略 +│ ├── financial_analysis.py # 财报深度解读 +│ ├── limit_analysis.py # 涨跌停分析 +│ ├── watchlist_manager.py # 自选股管理 +│ ├── .env.example # 环境变量模板 +│ └── requirements.txt +├── prototype/ # 前端静态文件(HTML / JS / CSS) +│ ├── index.html +│ ├── app.js +│ ├── style.css +│ └── *.js # 各功能模块 JS +├── 架构总结.md +├── 功能架构.md +├── 待优化.md +└── 功能扩展.md ``` +--- + ## 环境要求 - Python 3.12+ -- PostgreSQL 14+(本地或远程均可) -- 可选:大模型 API Key、推送渠道密钥(见下方配置) +- PostgreSQL 14+ +- Redis 5+(可选,不启动时自动降级内存缓存) +- 大模型 API Key(可选,不配置时使用规则引擎降级) -## 快速开始 +--- -以下命令以 **WSL(Linux)** 为例。项目在 Windows 盘时,路径一般为 `/mnt/e/project/stock_cs_v1`;若在 WSL 家目录,则替换为实际路径即可。 +## 快速开始(WSL / Linux) -### 1. 安装 PostgreSQL(WSL,首次) +### 1. 安装服务(首次) ```bash sudo apt update -sudo apt install -y postgresql postgresql-contrib +sudo apt install -y postgresql postgresql-contrib redis-server sudo service postgresql start +sudo service redis-server start -# 设置 postgres 用户密码(与 backend/.env 中 PG_PASSWORD 一致) +# 设置 postgres 密码(与后续 .env 中 PG_PASSWORD 保持一致) sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';" ``` -WSL 每次重启后若数据库未自动运行,需先执行: +WSL 重启后若服务未自动运行: ```bash -sudo service postgresql start +sudo service postgresql start && sudo service redis-server start ``` ### 2. 安装 Python 依赖(首次) ```bash -cd /mnt/e/project/stock_cs_v1/backend # 按实际路径修改 +cd backend python3 -m venv .venv source .venv/bin/activate pip install -r requirements.txt ``` -**Windows 原生(非 WSL)** 激活虚拟环境: - -```powershell -cd backend -python -m venv .venv -.venv\Scripts\activate -pip install -r requirements.txt -``` - ### 3. 配置环境变量 ```bash @@ -125,141 +143,151 @@ cd backend cp .env.example .env ``` -编辑 `backend/.env`,至少确认 PostgreSQL 连接信息(PostgreSQL 装在 WSL 内时使用 `localhost`): +编辑 `backend/.env`: ```env +# 数据库 PG_USER=postgres PG_PASSWORD=your_password PG_HOST=localhost PG_PORT=5432 PG_DB=stock_cs + +# Redis(可选) +REDIS_HOST=localhost +REDIS_PORT=6379 + +# 鉴权(生产环境务必修改 SECRET_KEY) +SECRET_KEY=your-secret-key-change-in-production +DEFAULT_ADMIN_USERNAME=admin +DEFAULT_ADMIN_PASSWORD=admin123 + +# 大模型(可选,不填则规则降级) +LLM_API_KEY= +LLM_BASE_URL=https://api.deepseek.com/v1 +LLM_MODEL=deepseek-chat ``` -也可通过环境变量 `PG_USER` / `PG_PASSWORD` / `PG_HOST` / `PG_PORT` / `PG_DB` 设置,无需改文件。 +生成安全的 SECRET_KEY: -可选:填入 `LLM_API_KEY` 启用大模型分析;填入 SMTP / Server酱 / 企业微信 / PushPlus 启用推送。 +```bash +python3 -c "import secrets; print(secrets.token_urlsafe(32))" +``` -### 4. 初始化数据库并入库(首次) +完整配置说明见 [backend/ENV_CONFIG.md](backend/ENV_CONFIG.md)。 + +### 4. 初始化数据库(首次) ```bash cd backend -source .venv/bin/activate # WSL / Linux +source .venv/bin/activate -# 建库建表 +# 建表 + 创建管理员账号 python cli.py init -# 抓取当日板块/资金流/情绪/龙虎榜等快照 +# 抓取当日行情快照(板块 / 资金 / 情绪 / 龙虎榜) python cli.py ingest # 全市场日线历史入库(默认 250 交易日,耗时较长,可选) python cli.py ingest_all python cli.py ingest_all 500 # 指定天数 -``` -指定股票入库: - -```bash +# 指定股票入库 python cli.py ingest 600519 000001 ``` ### 5. 启动服务 -**日常启动(WSL):** - ```bash -sudo service postgresql start -cd /mnt/e/project/stock_cs_v1/backend # 按实际路径修改 +sudo service postgresql start && sudo service redis-server start +cd backend source .venv/bin/activate python main.py ``` -一键命令(已配置好后): - -```bash -sudo service postgresql start && cd /mnt/e/project/stock_cs_v1/backend && source .venv/bin/activate && python main.py -``` - 浏览器访问:**http://127.0.0.1:8000**(WSL2 下 Windows 浏览器可直接访问) -健康检查:`GET /api/health` +健康检查:`GET /api/health`(返回 Redis、AkShare、鉴权状态) -### 常见问题(WSL) - -| 现象 | 处理 | -|---|---| -| `connection refused` | 执行 `sudo service postgresql start` | -| `password authentication failed` | 检查 `.env` 中 `PG_PASSWORD` 是否与 `ALTER USER` 设置一致 | -| `python: command not found` | 使用 `python3` | -| 每次新开终端 | 先 `source .venv/bin/activate` 再运行命令 | +--- ## 定时任务 -服务启动后,APScheduler 会在工作日自动执行(可在 `config.py` 或环境变量中调整时间): +服务启动后,APScheduler 在交易日自动执行: | 任务 | 默认时间 | 说明 | |---|---|---| -| `daily_ingest` | 15:35 | 收盘后增量入库(板块/资金/情绪/龙虎榜/个股行情) | -| `alert_check` | 每 60 秒 | 实时报价预警核查(价格/涨跌幅/量能等规则) | -| `intraday_scan` | 交易时段每 5 分钟 | 盘中异动扫描(快速拉升/放量突破/涨停打开/连板追踪) | -| `daily_report` | 15:45 | 生成 AI 复盘日报并推送(需配置大模型 API) | -| `verify_pred` | 15:50 | 核验到期 AI 预测,更新准确率统计 | -| `signal_stats` | 周六 09:00 | 全市场信号胜率回测(MACD 金叉/突破等技术信号) | -| `selector_check` | 15:40 | 选股策略预警检查,符合条件时推送 | +| `daily_ingest` | 15:35 | 收盘后增量入库 | +| `alert_check` | 每 60 秒 | 实时价格 / 量能预警核查 | +| `intraday_scan` | 交易时段每 5 分钟 | 盘中异动扫描 | +| `daily_report` | 15:45 | 生成 AI 复盘日报并推送 | +| `verify_pred` | 15:50 | 核验到期 AI 预测,更新准确率 | +| `signal_stats` | 周六 09:00 | 全市场信号胜率回测 | +| `selector_check` | 15:40 | 选股策略预警检查 | +| `calendar_alerts` | 08:30 | 推送持仓股除权 / 解禁 / 财报等日历事件提醒 | + +时间可通过环境变量 `INGEST_HOUR` / `INGEST_MINUTE` 调整。 + +--- ## 推送渠道 -在 `.env` 中配置任意一种即可启用,互不依赖: +在 `.env` 中配置任意一种即可,互不依赖: | 渠道 | 配置项 | |---|---| -| SMTP 邮件 | `SMTP_HOST` / `SMTP_PORT` / `SMTP_USER` / `SMTP_PASSWORD` / `SMTP_TO` | -| Server酱 | `SERVERCHAN_KEY` | -| 企业微信 | `WECOM_WEBHOOK` | -| PushPlus | `PUSHPLUS_TOKEN` | +| SMTP 邮件 | `SMTP_HOST` · `SMTP_PORT` · `SMTP_USER` · `SMTP_PASSWORD` · `SMTP_TO` | +| Server酱(微信) | `SERVERCHAN_KEY` | +| 企业微信机器人 | `WECOM_WEBHOOK` | +| PushPlus(微信) | `PUSHPLUS_TOKEN` | -## 开发说明 +--- -- 前端静态资源由 `main.py` 挂载 `prototype/` 目录,修改前端后刷新浏览器即可。 -- 自选股列表持久化在 `backend/watchlist.json`。 -- AkShare 不可用时部分接口会降级为 mock 数据,详见 `/api/health` 中的 `akshare` 字段。 -- 敏感文件(`.env`、虚拟环境等)已在 `.gitignore` 中排除,请勿提交密钥。 +## API 认证 -## 核心功能说明 +管理接口需要 JWT Token: -### 1. 智能选股增强 -可视化条件组合器,支持技术面、资金面、基本面多因子拖拽组合,选股结果可一键回测验证历史表现,策略可保存/分享并设置条件预警。详见 [智能选股增强使用说明.md](./智能选股增强使用说明.md) +```bash +# 登录 +curl -X POST http://localhost:8000/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username":"admin","password":"admin123"}' -### 2. 盘中异动雷达 -交易时段自动扫描快速拉升、放量突破、涨停打开、连板股等异动信号,支持多通道实时推送。详见 [盘中异动雷达使用说明.md](./盘中异动雷达使用说明.md) +# 携带 Token 访问 +curl http://localhost:8000/api/admin/status \ + -H "Authorization: Bearer YOUR_TOKEN" +``` -### 3. 板块轮动分析 -板块强弱趋势、资金流向桑基图、生命周期判断(启动期/加速期/衰退期)、龙头股自动识别、板块联动性分析。详见 [板块轮动分析使用说明.md](./板块轮动分析使用说明.md) +也可在 `.env` 中配置 `API_KEYS` 使用静态 API Key 模式。 -### 4. 策略回测增强 -多因子组合回测、仓位管理策略、参数优化网格搜索、策略对比(并排净值曲线)、完整风险指标(夏普/最大回撤/胜率)。详见 [策略回测增强使用说明.md](./策略回测增强使用说明.md) +--- -### 5. 持仓归因分析 -收益归因分解(选股能力 vs 择时能力 vs 运气成分)、持仓时长分析、买入理由有效性验证、情绪标签相关性、对标指数超额收益拆解。详见 [持仓归因分析深化使用说明.md](./持仓归因分析深化使用说明.md) +## 常见问题(WSL) -### 6. AI 对话式分析 -与大模型深度结合,支持自然语言选股、持仓诊断、策略建议、实时问答,多轮对话记住用户偏好。详见 [AI对话式分析使用说明.md](./AI对话式分析使用说明.md) +| 现象 | 处理 | +|---|---| +| `connection refused` | `sudo service postgresql start && sudo service redis-server start` | +| `password authentication failed` | 检查 `.env` 中 `PG_PASSWORD` 与数据库密码是否一致 | +| `python: command not found` | 使用 `python3` | +| 新开终端命令失效 | 先执行 `source .venv/bin/activate` | +| Redis 连接失败 | 不影响运行,自动降级到内存缓存 | +| 401 Unauthorized | 先调用 `/api/auth/login` 获取 Token | -### 7. 社区情绪监控 -爬取东方财富/雪球热帖,计算情绪指数(乐观/悲观比例)、热议股票排行、关键词云图、情绪与股价相关性回测。详见 [社区情绪监控使用说明.md](./社区情绪监控使用说明.md) - -### 8. 事件驱动策略 -财报发布前后统计规律、高管增减持跟踪、限售解禁影响分析、行业政策事件库、事件驱动选股。详见 [事件驱动策略使用说明.md](./事件驱动策略使用说明.md) - -### 9. 财报深度解读 -财报关键指标趋势、AI 一句话摘要、同行对比、财报异常预警(存货激增/应收账款占比过高)、发布日历提醒。详见 [财报深度解读使用说明.md](./财报深度解读使用说明.md) +--- ## 文档 -- [架构总结.md](./架构总结.md) — 分层设计、数据模型、AI 分析流程 -- [功能架构.md](./功能架构.md) — 功能模块详细说明 -- [待优化.md](./待优化.md) — 已知问题与优化方向 -- [功能扩展.md](./功能扩展.md) — 扩展功能建议 +| 文件 | 内容 | +|---|---| +| [架构总结.md](./架构总结.md) | 分层设计、数据模型、AI 分析流程 | +| [功能架构.md](./功能架构.md) | 功能模块详细说明 | +| [待优化.md](./待优化.md) | 已知问题与优化方向 | +| [功能扩展.md](./功能扩展.md) | 扩展功能建议 | +| [backend/ENV_CONFIG.md](backend/ENV_CONFIG.md) | 环境变量完整说明 | +| [backend/UPGRADE_GUIDE.md](backend/UPGRADE_GUIDE.md) | 升级指南 | -## 许可证 +--- -本项目仅供学习与研究使用。行情数据来源于第三方公开接口,请遵守相应数据源的使用条款。 +## 免责声明 + +本项目仅供学习与研究使用,不构成任何投资建议。行情数据来自第三方公开接口,请遵守相应数据源的使用条款。 diff --git a/backend/CHECKLIST.md b/backend/CHECKLIST.md new file mode 100644 index 0000000..3903247 --- /dev/null +++ b/backend/CHECKLIST.md @@ -0,0 +1,237 @@ +# 启动前检查清单 + +在首次启动或遇到问题时,请按此清单逐项检查。 + +## ✅ 环境检查 + +### 系统服务 +```bash +# PostgreSQL 是否运行 +sudo service postgresql status +sudo service postgresql start # 如果未运行 + +# Redis 是否运行 +redis-cli ping # 应返回 PONG +sudo service redis-server start # 如果未运行 +``` + +### Python 环境 +```bash +# 虚拟环境是否激活 +which python # 应显示 .venv/bin/python +source .venv/bin/activate # 如果未激活 + +# 依赖是否安装完整 +pip list | grep redis +pip list | grep jose +pip list | grep passlib +``` + +## ✅ 配置检查 + +### .env 文件 +```bash +# 检查必要配置项 +cat .env | grep PG_PASSWORD +cat .env | grep SECRET_KEY +cat .env | grep REDIS_HOST + +# 如果缺少配置 +cp .env.example .env # 如果 .env 不存在 +nano .env # 编辑配置 +``` + +### 必须配置的项 +- [ ] `PG_PASSWORD` - PostgreSQL 密码 +- [ ] `SECRET_KEY` - JWT 密钥(生产环境必须修改) +- [ ] `REDIS_HOST` - Redis 地址(默认 localhost) + +## ✅ 数据库检查 + +```bash +# 检查数据库是否创建 +psql -U postgres -c "\l" | grep stock_cs + +# 检查用户表是否存在 +psql -U postgres -d stock_cs -c "\dt" | grep users + +# 如果数据库或表不存在 +python cli.py init +``` + +## ✅ 权限检查 + +```bash +# 测试数据库连接 +psql -U postgres -d stock_cs -c "SELECT 1;" + +# 测试 Redis 连接 +redis-cli ping + +# 如果提示权限错误 +sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';" +``` + +## ✅ 启动服务 + +```bash +# 完整启动流程 +sudo service postgresql start +sudo service redis-server start +cd backend +source .venv/bin/activate +python main.py +``` + +### 启动成功标志 + +看到以下日志表示启动成功: +``` +✓ Redis 已连接: localhost:6379 +✓ 管理员账号已存在: admin +[startup] db + scheduler + auth ready +INFO: Uvicorn running on http://0.0.0.0:8000 +``` + +## ✅ 功能测试 + +```bash +# 1. 健康检查 +curl http://localhost:8000/api/health +# 应返回: {"ok":true,"akshare":true,"redis":true,"auth":true} + +# 2. 登录测试 +curl -X POST http://localhost:8000/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username":"admin","password":"admin123"}' +# 应返回 Token + +# 3. 运行完整测试 +python test_core_features.py +``` + +## 🔧 常见问题排查 + +### 问题 1: Redis 连接失败 +``` +✗ Redis 连接失败,缓存已禁用 +``` +**解决**: +```bash +sudo service redis-server start +redis-cli ping +``` + +### 问题 2: 数据库连接失败 +``` +connection refused +``` +**解决**: +```bash +sudo service postgresql start +psql -U postgres -c "SELECT 1;" +``` + +### 问题 3: 密码认证失败 +``` +password authentication failed +``` +**解决**: +```bash +sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';" +# 然后在 .env 中设置相同的密码 +``` + +### 问题 4: 模块未找到 +``` +ModuleNotFoundError: No module named 'redis' +``` +**解决**: +```bash +source .venv/bin/activate +pip install -r requirements.txt +``` + +### 问题 5: 401 Unauthorized +``` +{"detail":"未认证,请先登录"} +``` +**解决**: +```bash +# 先登录获取 Token +curl -X POST http://localhost:8000/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username":"admin","password":"admin123"}' + +# 使用 Token 访问 +curl -X GET http://localhost:8000/api/admin/status \ + -H "Authorization: Bearer <返回的token>" +``` + +### 问题 6: 用户表不存在 +``` +relation "users" does not exist +``` +**解决**: +```bash +python cli.py init +``` + +## 📝 首次部署完整流程 + +```bash +# 1. 系统依赖 +sudo apt update +sudo apt install -y postgresql redis-server + +# 2. 启动服务 +sudo service postgresql start +sudo service redis-server start + +# 3. 配置数据库密码 +sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';" + +# 4. Python 环境 +cd backend +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt + +# 5. 配置环境变量 +cp .env.example .env +nano .env # 编辑 PG_PASSWORD, SECRET_KEY 等 + +# 6. 初始化数据库 +python cli.py init + +# 7. 启动服务 +python main.py + +# 8. 测试 +python test_core_features.py +``` + +## 🚀 一键启动脚本(WSL) + +创建 `start.sh`: +```bash +#!/bin/bash +sudo service postgresql start +sudo service redis-server start +cd /mnt/e/project/stock_cs_v1/backend # 修改为实际路径 +source .venv/bin/activate +python main.py +``` + +使用: +```bash +chmod +x start.sh +./start.sh +``` + +## 📞 获取帮助 + +如果以上方法都无法解决问题: +1. 查看详细文档: `backend/UPGRADE_GUIDE.md` +2. 查看配置说明: `backend/ENV_CONFIG.md` +3. 查看实现总结: `三大核心功能实现总结.md` diff --git a/backend/ENV_CONFIG.md b/backend/ENV_CONFIG.md new file mode 100644 index 0000000..5a9c739 --- /dev/null +++ b/backend/ENV_CONFIG.md @@ -0,0 +1,118 @@ +# 环境变量配置说明 + +## 新增配置项(三大核心功能) + +### Redis 缓存配置 +```bash +# Redis 连接配置 +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD= # 如果 Redis 设置了密码则填写 +``` + +### 认证系统配置 +```bash +# JWT 密钥(生产环境务必修改) +SECRET_KEY=your-secret-key-change-in-production + +# Token 过期时间(分钟) +ACCESS_TOKEN_EXPIRE_MINUTES=10080 # 默认7天 + +# API Key 模式(可选,用于外部调用,逗号分隔多个) +API_KEYS=your-api-key-1,your-api-key-2 + +# 默认管理员账号(首次启动时创建) +DEFAULT_ADMIN_USERNAME=admin +DEFAULT_ADMIN_PASSWORD=admin123 # 首次启动后务必修改密码 +``` + +## 安装 Redis(WSL 环境) + +```bash +# 安装 Redis +sudo apt update +sudo apt install -y redis-server + +# 启动 Redis +sudo service redis-server start + +# 验证 Redis 是否运行 +redis-cli ping +# 应返回 PONG + +# 设置 Redis 开机自启(可选) +sudo systemctl enable redis-server +``` + +## 安全建议 + +1. **SECRET_KEY**: 生产环境必须使用强随机字符串,可用以下命令生成: + ```bash + python -c "import secrets; print(secrets.token_urlsafe(32))" + ``` + +2. **默认密码**: 首次登录后立即修改 admin 密码 + +3. **API_KEYS**: 仅在需要外部调用时配置,妥善保管 + +4. **Redis 密码**: 生产环境建议为 Redis 设置密码: + ```bash + # 编辑 Redis 配置 + sudo nano /etc/redis/redis.conf + # 找到 requirepass 行,取消注释并设置密码 + requirepass your-strong-password + # 重启 Redis + sudo service redis-server restart + ``` + +## 功能说明 + +### 1. Redis 缓存 +- 替代内存缓存,支持持久化和跨进程共享 +- 自动降级:Redis 不可用时使用内存缓存 +- 默认过期时间根据数据类型自动设置(行情数据1分钟,基本面数据1天) + +### 2. 统一鉴权 +- **JWT Token 模式**: 用户登录获取 Token,适合前端应用 +- **API Key 模式**: 用于外部系统调用,配置在 HTTP Header `X-API-Key` +- 管理接口(`/api/admin/*`)需要管理员权限 + +### 3. 统一异常处理 +- 业务异常返回友好错误信息 +- 数据源异常自动降级 +- 数据库异常统一处理 +- 所有异常记录日志便于排查 + +## API 使用示例 + +### 登录 +```bash +curl -X POST http://localhost:8000/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username":"admin","password":"admin123"}' +``` + +### 使用 Token 访问受保护接口 +```bash +curl -X GET http://localhost:8000/api/admin/status \ + -H "Authorization: Bearer YOUR_TOKEN_HERE" +``` + +### 使用 API Key 访问 +```bash +curl -X GET http://localhost:8000/api/admin/status \ + -H "X-API-Key: your-api-key-1" +``` + +## 注意事项 + +1. WSL 环境下,每次重启后需要手动启动服务: + ```bash + sudo service postgresql start + sudo service redis-server start + ``` + +2. Redis 连接失败不会影响系统运行,会自动降级到内存缓存 + +3. 未配置鉴权时,所有接口默认不需要认证(开发模式) diff --git a/backend/UPGRADE_GUIDE.md b/backend/UPGRADE_GUIDE.md new file mode 100644 index 0000000..873f8f7 --- /dev/null +++ b/backend/UPGRADE_GUIDE.md @@ -0,0 +1,287 @@ +# 三大核心功能升级指南 + +本次升级新增了三个核心功能:Redis 缓存层、统一鉴权机制、统一异常处理中间件。 + +## 1. 安装新依赖 + +```bash +cd backend +source .venv/bin/activate # 激活虚拟环境 +pip install -r requirements.txt +``` + +新增的依赖包: +- `redis>=5.0.0` - Redis 客户端 +- `python-jose[cryptography]>=3.3.0` - JWT Token 生成和验证 +- `passlib[bcrypt]>=1.7.4` - 密码哈希 +- `python-multipart>=0.0.9` - 表单数据解析 + +## 2. 安装和启动 Redis(WSL) + +```bash +# 安装 Redis +sudo apt update +sudo apt install -y redis-server + +# 启动 Redis +sudo service redis-server start + +# 验证 Redis 是否运行 +redis-cli ping +# 应返回 PONG +``` + +## 3. 配置环境变量 + +编辑 `backend/.env` 文件,添加以下配置: + +```bash +# ============ Redis 缓存配置 ============ +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD= # 可选,如果设置了密码则填写 + +# ============ 认证系统配置 ============ +# JWT 密钥(务必修改为强随机字符串) +SECRET_KEY=your-secret-key-change-in-production + +# Token 过期时间(分钟,默认7天) +ACCESS_TOKEN_EXPIRE_MINUTES=10080 + +# API Key(可选,用于外部调用,逗号分隔) +API_KEYS= + +# 默认管理员账号 +DEFAULT_ADMIN_USERNAME=admin +DEFAULT_ADMIN_PASSWORD=admin123 +``` + +### 生成安全的 SECRET_KEY + +```bash +python -c "import secrets; print(secrets.token_urlsafe(32))" +``` + +将输出的字符串替换 `SECRET_KEY` 的值。 + +## 4. 初始化数据库(创建用户表和默认管理员) + +```bash +cd backend +source .venv/bin/activate +python cli.py init +``` + +你会看到类似输出: +``` +✓ 创建默认管理员: admin +init done +``` + +## 5. 启动服务 + +```bash +# 确保 PostgreSQL 和 Redis 都在运行 +sudo service postgresql start +sudo service redis-server start + +# 启动 FastAPI 服务 +cd backend +source .venv/bin/activate +python main.py +``` + +启动日志应包含: +``` +✓ Redis 已连接: localhost:6379 +✓ 管理员账号已存在: admin +[startup] db + scheduler + auth ready +``` + +## 6. 测试功能 + +### 测试健康检查(查看 Redis 和鉴权状态) + +```bash +curl http://localhost:8000/api/health +``` + +应返回: +```json +{ + "ok": true, + "akshare": true, + "redis": true, + "auth": true +} +``` + +### 测试登录 + +```bash +curl -X POST http://localhost:8000/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username":"admin","password":"admin123"}' +``` + +应返回 Token: +```json +{ + "ok": true, + "access_token": "eyJ0eXAiOiJKV1QiLCJhb...", + "token_type": "bearer", + "username": "admin", + "is_admin": true +} +``` + +### 测试受保护的接口 + +```bash +# 使用上一步获取的 Token +curl -X GET http://localhost:8000/api/admin/status \ + -H "Authorization: Bearer YOUR_TOKEN_HERE" +``` + +### 测试 Redis 缓存 + +```bash +# 第一次请求(缓存未命中,较慢) +time curl http://localhost:8000/api/indices + +# 第二次请求(缓存命中,应该快很多) +time curl http://localhost:8000/api/indices +``` + +## 7. 修改默认密码(重要!) + +首次登录后,立即修改默认密码: + +```bash +curl -X POST http://localhost:8000/api/auth/change-password \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"old_password":"admin123","new_password":"your-new-strong-password"}' +``` + +## 8. 常见问题 + +### Redis 连接失败 + +如果看到以下日志: +``` +✗ Redis 连接失败,缓存已禁用: ... +``` + +**解决方法**: +```bash +# 检查 Redis 是否运行 +sudo service redis-server status + +# 如果未运行,启动它 +sudo service redis-server start +``` + +**注意**:Redis 连接失败不会影响系统运行,会自动降级到内存缓存。 + +### 401 Unauthorized 错误 + +如果访问管理接口返回 401: +```json +{"detail":"未认证,请先登录"} +``` + +**原因**:该接口需要认证 + +**解决方法**: +1. 先调用 `/api/auth/login` 获取 Token +2. 在请求头中加入 `Authorization: Bearer YOUR_TOKEN` + +### WSL 重启后服务启动失败 + +WSL 每次重启后需要手动启动服务: + +```bash +# 一键启动所有服务 +sudo service postgresql start && \ +sudo service redis-server start && \ +cd /mnt/e/project/stock_cs_v1/backend && \ +source .venv/bin/activate && \ +python main.py +``` + +## 9. 功能详细说明 + +### Redis 缓存优势 + +- **持久化**:服务重启后缓存不丢失 +- **跨进程**:多个进程可共享缓存 +- **性能提升**:大幅减少 AkShare API 调用,响应速度提升 10-100 倍 +- **自动降级**:Redis 不可用时自动使用内存缓存 + +### 鉴权机制 + +**支持两种认证方式**: + +1. **JWT Token**(推荐用于前端) + - 登录后获取 Token + - Token 有效期 7 天(可配置) + - 在 HTTP Header 中传递:`Authorization: Bearer TOKEN` + +2. **API Key**(推荐用于外部系统) + - 在 `.env` 中配置 `API_KEYS` + - 在 HTTP Header 中传递:`X-API-Key: YOUR_API_KEY` + +**受保护的接口**: +- `/api/admin/*` - 需要管理员权限 +- 其他接口暂不需要认证(可根据需要扩展) + +### 异常处理 + +系统现在会返回友好的错误信息: + +```json +{ + "success": false, + "error": "数据源异常,请稍后重试", + "code": 503 +} +``` + +错误类型: +- **400** - 业务逻辑错误 +- **401** - 未认证 +- **403** - 权限不足 +- **422** - 请求参数错误 +- **500** - 服务器内部错误 +- **503** - 数据源不可用 + +## 10. 升级检查清单 + +- [ ] 安装新依赖包 +- [ ] 安装并启动 Redis +- [ ] 配置 `.env` 文件(Redis + 鉴权配置) +- [ ] 生成安全的 `SECRET_KEY` +- [ ] 运行 `python cli.py init` 创建用户表 +- [ ] 启动服务并验证功能 +- [ ] 测试登录和受保护接口 +- [ ] 修改默认管理员密码 +- [ ] 检查 Redis 缓存是否生效 + +## 11. 回退方案 + +如果升级后遇到问题,可以临时禁用新功能: + +1. **禁用 Redis 缓存**:停止 Redis 服务,系统会自动降级到内存缓存 + ```bash + sudo service redis-server stop + ``` + +2. **禁用鉴权**:暂时注释掉 `main.py` 中受保护接口的 `Depends(require_auth)` 参数 + +3. **完整回退**:切换到升级前的 git 分支 + +--- + +升级完成后,系统将具备更强的性能、安全性和稳定性! diff --git a/backend/ai.py b/backend/ai.py index 92b9239..1c01f5f 100644 --- a/backend/ai.py +++ b/backend/ai.py @@ -202,6 +202,176 @@ def diagnose(symbol): return base +# ============ 走势分析(右键K线) ============ +def trend_analysis(symbol: str, date: str, period: str = "daily"): + """分析某只股票在指定日期(或最新)附近暴涨/暴跌的原因。 + period: daily / weekly / monthly + """ + with get_session() as s: + sec = s.get(Security, symbol) + # 取该日期前后一段数据作为上下文 + if date: + try: + target_date = dt.date.fromisoformat(date) + except Exception: + target_date = None + else: + target_date = None + + # 取最近60根K线 + rows = s.execute( + select(DailyQuote).where(DailyQuote.code == symbol) + .order_by(DailyQuote.date.desc()).limit(60) + ).scalars().all() + rows = list(reversed(rows)) + + # 当日及相邻数据 + target_row = None + if target_date and rows: + # 找最近的一根 + closest = min(rows, key=lambda r: abs((r.date - target_date).days)) + if abs((closest.date - target_date).days) <= 7: + target_row = closest + + if not target_row and rows: + target_row = rows[-1] + + m = s.get(StockMetric, symbol) + + name = (sec.name if sec else (m.name if m else symbol)) or symbol + + # 计算目标K线的涨跌幅 + pct = 0.0 + if target_row and rows: + idx = rows.index(target_row) + if idx > 0: + prev_close = rows[idx - 1].close + if prev_close: + pct = round((target_row.close - prev_close) / prev_close * 100, 2) + + # 拉取相关新闻 + import akshare_service as svc + try: + news_data = svc.get_stock_news(symbol, limit=10) + news_items = news_data.get("list", []) + except Exception: + news_items = [] + + # 拉取RAG上下文 + import rag + rctx = rag.stock_context(symbol, limit=6) + + # 构造上下文数据 + period_cn = {"daily": "日K", "weekly": "周K", "monthly": "月K"}.get(period, "K线") + date_str = target_row.date.isoformat() if target_row else (date or "最新") + + # 当日技术面 + if target_row: + tech_line = ( + f"目标K线:{date_str},开{target_row.open} 收{target_row.close} " + f"高{target_row.high} 低{target_row.low}," + f"涨跌幅{pct:+.2f}%,成交量{target_row.volume:,}" + ) + else: + tech_line = f"目标日期:{date_str},暂无日线数据" + + # 前后走势(最近5根) + if target_row and rows: + idx = rows.index(target_row) + window = rows[max(0, idx-4):idx+2] + trend_line = "前后走势:" + " → ".join( + f"{r.date.strftime('%m/%d')}({'↑' if i == 0 or r.close >= rows[rows.index(r)-1].close else '↓'}{abs(round((r.close/rows[rows.index(r)-1].close-1)*100,1)) if rows.index(r) > 0 else 0}%)" + for i, r in enumerate(window) + ) + else: + trend_line = "" + + # 均线状态 + ma_line = "" + if m: + ma_line = (f"均线状态:MA5={m.ma5} MA10={m.ma10} MA20={m.ma20} MA60={m.ma60}," + f"{'多头排列' if m.ma_bull else '非多头'}," + f"量比{m.vol_ratio},RSI14={m.rsi14}") + + # 新闻摘要 + news_block = "" + if news_items: + news_block = "相关新闻(近期):\n" + "\n".join( + f"- [{n.get('time','')[:10]}] {n.get('title','')}" for n in news_items[:6] + ) + + # 判断是否暴涨/暴跌 + move_desc = "" + if abs(pct) >= 5: + move_desc = f"该股{'暴涨' if pct > 0 else '暴跌'} {abs(pct):.2f}%({'接近/涨停' if pct >= 9.5 else '显著上涨' if pct > 0 else '接近/跌停' if pct <= -9.5 else '显著下跌'})" + elif abs(pct) >= 2: + move_desc = f"该股{'上涨' if pct > 0 else '下跌'} {abs(pct):.2f}%" + else: + move_desc = f"该股小幅变动 {pct:+.2f}%" + + facts = f"""{name}({symbol}){period_cn}走势分析 +分析日期:{date_str} +{move_desc} +{tech_line} +{trend_line} +{ma_line} +{news_block} +消息面情绪:{rctx['tone']} +{rctx['block'] or ''}""" + + if llm.enabled(): + try: + prompt = ( + f"请分析 {name}({symbol})在 {date_str} 前后{period_cn}的走势," + f"重点解释:① 为什么{'暴涨' if pct >= 5 else ('暴跌' if pct <= -5 else '出现此走势')}(从技术面、资金面、政策面、新闻事件等维度);" + f"② 背后的主要驱动逻辑是什么;③ 后续需关注的信号或风险。250字以内,分点清晰。\n\n{facts}" + ) + text = llm.ask(prompt, temperature=0.5, max_tokens=600) + return {"ok": True, "source": "llm", "symbol": symbol, "name": name, + "date": date_str, "period": period, "pct": pct, "facts": facts, "text": text} + except Exception: + pass + + # 规则降级 + reasons = [] + if m: + if m.ma_bull and pct > 0: + reasons.append("均线多头排列,趋势向上") + if m.vol_ratio >= 2 and pct > 0: + reasons.append(f"成交量显著放大(量比{m.vol_ratio}),主力资金介入") + if m.vol_ratio >= 2 and pct < 0: + reasons.append(f"放量下跌(量比{m.vol_ratio}),资金出逃信号") + if m.macd_gold and pct > 0: + reasons.append("MACD金叉,动能转强") + if m.rsi14 >= 80: + reasons.append(f"RSI超买({m.rsi14}),注意回调风险") + if m.rsi14 < 30: + reasons.append(f"RSI超卖({m.rsi14}),存在超跌反弹机会") + if m.pos60 >= 0.95 and pct > 0: + reasons.append("突破60日新高,动量突破") + if m.pos60 <= 0.1 and pct > 0: + reasons.append("低位反弹,超跌修复") + if rctx['tone'] == '利好': + reasons.append("近期资讯面偏利好") + elif rctx['tone'] == '利空': + reasons.append("近期资讯面偏利空") + if news_items: + hot_news = news_items[0]['title'][:40] + reasons.append(f"最新消息:{hot_news}…") + if not reasons: + reasons.append("暂无明确技术或消息面驱动,可能为市场情绪或板块联动") + + text = ( + f"{name} 在 {date_str} {move_desc}。\n" + f"主要原因分析:\n" + + "\n".join(f"{i+1}. {r}" for i, r in enumerate(reasons)) + + f"\n\n建议:{'关注量能是否持续配合,谨防高位回调。' if pct >= 5 else ('关注是否企稳止跌,底部确认前谨慎抄底。' if pct <= -5 else '走势相对平稳,跟踪板块动向。')}" + f"\n{DISCLAIMER}" + ) + return {"ok": True, "source": "rule", "symbol": symbol, "name": name, +"date": date_str, "period": period, "pct": pct, "facts": facts, "text": text} + + # ============ 今日策略 ============ def today_strategy(): with get_session() as s: diff --git a/backend/akshare_service.py b/backend/akshare_service.py index ca01c9f..0b1f4b0 100644 --- a/backend/akshare_service.py +++ b/backend/akshare_service.py @@ -12,6 +12,7 @@ from functools import wraps from cachetools import TTLCache import requests +from redis_cache import cache try: import akshare as ak @@ -25,16 +26,35 @@ _cache = TTLCache(maxsize=256, ttl=30) def cached(ttl: int): + """缓存装饰器:优先使用 Redis,降级到内存缓存""" def deco(fn): local = TTLCache(maxsize=64, ttl=ttl) @wraps(fn) def wrapper(*args, **kwargs): - key = (fn.__name__, args, tuple(sorted(kwargs.items()))) - if key in local: - return local[key] + # 生成缓存键 + key = f"akshare:{fn.__name__}:{args}:{tuple(sorted(kwargs.items()))}" + + # 优先从 Redis 读取 + if cache.enabled: + cached_value = cache.get(key) + if cached_value is not None: + return cached_value + + # Redis 未命中,从内存缓存读取 + local_key = (fn.__name__, args, tuple(sorted(kwargs.items()))) + if local_key in local: + return local[local_key] + + # 执行函数 val = fn(*args, **kwargs) - local[key] = val + + # 写入 Redis + if cache.enabled: + cache.set(key, val, expire=ttl) + + # 写入内存缓存(降级) + local[local_key] = val return val return wrapper @@ -183,7 +203,19 @@ def get_stock_news(code: str, limit: int = 12): return {"source": "mock", "list": []} +# 已知指数代码 → 新浪前缀映射 +_INDEX_CODES = {"000001", "000300", "000016", "399001", "399006", "899050"} + +def _is_index(code: str) -> bool: + return code in _INDEX_CODES or code.startswith(("sh0", "sz3990", "bj8990")) + def _sina_symbol(code: str) -> str: + if code in ("000001", "000016"): # 上证系列 + return "sh" + code + if code in ("000300",): # 沪深300 + return "sh" + code + if code in ("399001", "399006"): # 深证 + return "sz" + code if code.startswith("6"): return "sh" + code if code.startswith(("0", "3")): @@ -194,9 +226,23 @@ def _sina_symbol(code: str) -> str: @cached(60) -def get_kline(symbol: str = "600519", days: int = 120): +def get_kline(symbol: str = "000001", days: int = 120): if AK_OK: - # 主源:新浪日线(更稳定);备源:腾讯 + # 指数走专用接口 + if symbol in _INDEX_CODES: + try: + sym = _sina_symbol(symbol) + df = ak.stock_zh_index_daily(symbol=sym) + if df is not None and not df.empty: + df = df.tail(days) + dates = [str(d)[5:].replace("-", "/") for d in df["date"]] + ohlc = [[float(r["open"]), float(r["close"]), float(r["low"]), float(r["high"])] + for _, r in df.iterrows()] + vols = [int(r["volume"]) if "volume" in df.columns else 0 for _, r in df.iterrows()] + return {"source": "akshare", "symbol": symbol, "dates": dates, "ohlc": ohlc, "vols": vols} + except Exception: + pass + # 个股主源:新浪日线(更稳定);备源:腾讯 for src in ("sina", "tx"): try: sym = _sina_symbol(symbol) @@ -321,6 +367,90 @@ def get_treemap(mode: str = "sector"): return {"source": boards["source"], "mode": "sector", "items": items} +@cached(120) +def get_us_treemap(): + """美股热门板块云图(按成交额取前100只)""" + if AK_OK: + try: + df = ak.stock_us_spot_em() + if df is not None and not df.empty: + top = df.sort_values("成交额", ascending=False).head(100) + items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e8, 2), + "pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()] + items = [x for x in items if x["name"]] + return {"source": "akshare", "market": "us", "items": items} + except Exception: + pass + names = ["苹果","微软","谷歌","亚马逊","英伟达","特斯拉","Meta","台积电","巴菲特","摩根"] + return {"source": "mock", "market": "us", + "items": [{"name": n, "value": _rnd(10,200), "pct": round(_rnd(-4,4),2)} for n in names]} + + +@cached(120) +def get_hk_treemap(): + """港股热门板块云图(按成交额取前100只)""" + if AK_OK: + try: + df = ak.stock_hk_spot_em() + if df is not None and not df.empty: + top = df.sort_values("成交额", ascending=False).head(100) + items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e4, 2), + "pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()] + items = [x for x in items if x["name"]] + return {"source": "akshare", "market": "hk", "items": items} + except Exception: + pass + names = ["腾讯","阿里巴巴","美团","京东","小米","百度","网易","中国平安","汇丰","友邦"] + return {"source": "mock", "market": "hk", + "items": [{"name": n, "value": _rnd(5,100), "pct": round(_rnd(-4,4),2)} for n in names]} + + +@cached(120) +def get_all_sector_leaders(top_n: int = 5): + """一次性获取所有板块的前N只龙头股""" + boards = get_industry_boards() + result = {} + for b in boards.get("list", []): + name = b["name"] + try: + r = get_sector_stocks(name, top_n + 1) + result[name] = r.get("stocks", [])[:top_n] + except Exception: + result[name] = [] + return {"source": "akshare", "sectors": result} + + +@cached(300) +def get_sector_stocks(sector_name: str, limit: int = 20): + """获取板块成分股,按成交额排序""" + if AK_OK: + try: + df = ak.stock_board_industry_cons_em(symbol=sector_name) + if df is not None and not df.empty: + if "成交额" in df.columns: + df = df.sort_values("成交额", ascending=False) + stocks = [] + for _, r in df.head(limit).iterrows(): + try: + stocks.append({ + "code": str(r.get("代码", "")), + "name": str(r.get("名称", "")), + "pct": round(float(r.get("涨跌幅", 0)), 2), + "price": round(float(r.get("最新价", 0)), 2), + "amount": round(float(r.get("成交额", 0)) / 1e8, 2), + }) + except Exception: + continue + return {"source": "akshare", "name": sector_name, "stocks": stocks} + except Exception: + pass + # mock + stocks = [{"code": f"60000{i}", "name": f"{sector_name}{i+1}", + "pct": round(_rnd(-5, 5), 2), "price": round(_rnd(5, 100), 2), "amount": round(_rnd(1, 50), 2)} + for i in range(10)] + return {"source": "mock", "name": sector_name, "stocks": stocks} + + # ============================================================ # 资金流向(行业) # ============================================================ diff --git a/backend/auth.py b/backend/auth.py new file mode 100644 index 0000000..0c02c13 --- /dev/null +++ b/backend/auth.py @@ -0,0 +1,88 @@ +from datetime import datetime, timedelta +from typing import Optional +from fastapi import Depends, HTTPException, status, Header +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jose import JWTError, jwt +import bcrypt +from sqlalchemy.orm import Session +from db import SessionLocal +from models import User +from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, API_KEYS + + +security = HTTPBearer(auto_error=False) + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证密码""" + return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) + +def get_password_hash(password: str) -> str: + """密码哈希""" + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + """生成 JWT Token""" + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: + """验证用户名密码""" + user = db.query(User).filter(User.username == username).first() + if not user: + return None + if not verify_password(password, user.hashed_password): + return None + return user + +async def get_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), + x_api_key: Optional[str] = Header(None), + db: Session = Depends(lambda: SessionLocal()) +) -> Optional[User]: + """获取当前用户(支持 JWT Token 和 API Key 两种方式)""" + # 方式1:API Key + if x_api_key and x_api_key in API_KEYS: + # API Key 模式,返回虚拟管理员 + user = db.query(User).filter(User.username == "admin").first() + if user: + return user + + # 方式2:JWT Token + if credentials: + token = credentials.credentials + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + return None + user = db.query(User).filter(User.username == username).first() + return user + except JWTError: + return None + + return None + +async def require_auth(current_user: Optional[User] = Depends(get_current_user)): + """需要认证的依赖""" + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="未认证,请先登录", + headers={"WWW-Authenticate": "Bearer"}, + ) + return current_user + +async def require_admin(current_user: User = Depends(require_auth)): + """需要管理员权限的依赖""" + if not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要管理员权限" + ) + return current_user \ No newline at end of file diff --git a/backend/cli.py b/backend/cli.py index 0cbe277..a5cada2 100644 --- a/backend/cli.py +++ b/backend/cli.py @@ -9,12 +9,16 @@ import sys from db import init_db import ingest +import init_auth +import watchlist_manager as wl def main(): init_db() args = sys.argv[1:] if not args or args[0] == "init": + init_auth.init_default_admin() + wl.init_default_groups() print("init done") return if args[0] == "ingest": diff --git a/backend/config.py b/backend/config.py index 2d744e3..4a2d711 100644 --- a/backend/config.py +++ b/backend/config.py @@ -51,3 +51,19 @@ SERVERCHAN_KEY = os.getenv("SERVERCHAN_KEY", "") WECOM_WEBHOOK = os.getenv("WECOM_WEBHOOK", "") # PushPlus(微信推送) PUSHPLUS_TOKEN = os.getenv("PUSHPLUS_TOKEN", "") + +# ---- Redis 缓存 ---- +REDIS_HOST = os.getenv("REDIS_HOST", "localhost") +REDIS_PORT = int(os.getenv("REDIS_PORT", "6379")) +REDIS_DB = int(os.getenv("REDIS_DB", "0")) +REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "") + +# ---- 鉴权配置 ---- +SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "10080")) # 7天 +# API Key 模式(可选,用于外部调用) +API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else [] +# 默认管理员账号(首次启动时创建) +DEFAULT_ADMIN_USERNAME = os.getenv("DEFAULT_ADMIN_USERNAME", "admin") +DEFAULT_ADMIN_PASSWORD = os.getenv("DEFAULT_ADMIN_PASSWORD", "admin123") diff --git a/backend/data_manager.py b/backend/data_manager.py new file mode 100644 index 0000000..1de1ef5 --- /dev/null +++ b/backend/data_manager.py @@ -0,0 +1,398 @@ +"""数据修正与回填增强:数据修正、断点续传、完整性检查、质量报告""" +import datetime as dt +import json +import os +from typing import List, Optional, Dict +from sqlalchemy import select, func, and_, delete +from db import get_session +from models import DailyQuote, StockMetric, Security, IndexDaily, SectorDaily, JobRun +import ingest + +# 回填进度文件路径 +PROGRESS_FILE = os.path.join(os.path.dirname(__file__), ".refill_progress.json") + + +# ============ 数据修正 ============ + +def delete_quote(code: str, date: str) -> Dict: + """删除指定股票指定日期的日线数据""" + d = dt.date.fromisoformat(date) + with get_session() as s: + row = s.execute( + select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d) + ).scalar_one_or_none() + if not row: + return {"ok": False, "msg": f"{code} {date} 无此数据"} + s.delete(row) + s.commit() + return {"ok": True, "msg": f"已删除 {code} {date} 日线"} + + +def update_quote(code: str, date: str, fields: Dict) -> Dict: + """修正指定股票指定日期的日线数据""" + allowed = {"open", "high", "low", "close", "volume", "amount"} + to_update = {k: v for k, v in fields.items() if k in allowed} + if not to_update: + return {"ok": False, "msg": "无有效修正字段"} + + d = dt.date.fromisoformat(date) + with get_session() as s: + row = s.execute( + select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d) + ).scalar_one_or_none() + if not row: + return {"ok": False, "msg": f"{code} {date} 无此数据"} + for k, v in to_update.items(): + setattr(row, k, v) + s.commit() + return {"ok": True, "updated": to_update} + + +def delete_quotes_range(code: str, start: str, end: str) -> Dict: + """删除指定股票日期范围内的日线数据""" + d_start = dt.date.fromisoformat(start) + d_end = dt.date.fromisoformat(end) + with get_session() as s: + rows = s.execute( + select(DailyQuote).where( + DailyQuote.code == code, + DailyQuote.date >= d_start, + DailyQuote.date <= d_end + ) + ).scalars().all() + count = len(rows) + for row in rows: + s.delete(row) + s.commit() + return {"ok": True, "deleted": count, "range": f"{start} ~ {end}"} + + +def refetch_quote(code: str, days: int = 30) -> Dict: + """重新抓取指定股票的日线数据(覆盖更新)""" + rows = ingest.fetch_daily(code, days) + if not rows: + return {"ok": False, "msg": f"抓取 {code} 数据失败"} + n = ingest.ingest_quotes([code], days=days) + return {"ok": True, "code": code, "rows": len(rows), "msg": f"已更新 {len(rows)} 条日线"} + + +# ============ 数据完整性检查 ============ + +def check_data_integrity(codes: Optional[List[str]] = None, days: int = 30) -> Dict: + """检查数据完整性,找出缺失数据的股票和日期""" + with get_session() as s: + # 确定检查范围 + latest = s.execute(select(func.max(DailyQuote.date))).scalar() + if not latest: + return {"ok": False, "msg": "数据库无日线数据"} + + start = latest - dt.timedelta(days=days) + + # 获取检查的股票列表 + if codes: + check_codes = codes + else: + # 默认检查有记录的所有股票 + all_codes = s.execute( + select(DailyQuote.code).where( + DailyQuote.date >= start + ).distinct() + ).scalars().all() + check_codes = list(all_codes)[:200] # 最多检查200只 + + # 统计每只股票的数据点数 + from sqlalchemy import case + code_counts = {} + for code in check_codes: + count = s.execute( + select(func.count()).select_from(DailyQuote) + .where(DailyQuote.code == code, DailyQuote.date >= start) + ).scalar() + code_counts[code] = count + + # 以最多数据量为基准(应是交易日数) + expected = max(code_counts.values()) if code_counts else 0 + + # 找出缺失数据的股票 + missing = [] + normal = [] + for code, count in code_counts.items(): + ratio = count / expected if expected > 0 else 0 + if ratio < 0.8: # 缺失超过20% + with get_session() as s2: + sec = s2.get(Security, code) + name = sec.name if sec else code + missing.append({ + "code": code, + "name": name, + "actual": count, + "expected": expected, + "missing": expected - count, + "missing_pct": round((1 - ratio) * 100, 1) + }) + else: + normal.append(code) + + missing.sort(key=lambda x: x["missing"], reverse=True) + + return { + "ok": True, + "check_range": f"{start.isoformat()} ~ {latest.isoformat()}", + "checked": len(check_codes), + "expected_days": expected, + "normal_count": len(normal), + "missing_count": len(missing), + "missing_stocks": missing[:50] + } + + +def auto_fix_missing(limit: int = 50) -> Dict: + """自动补齐缺失数据(批量重新抓取)""" + result = check_data_integrity(days=30) + if not result["ok"] or result["missing_count"] == 0: + return {"ok": True, "msg": "数据完整,无需修复", "fixed": 0} + + missing_stocks = result["missing_stocks"][:limit] + codes = [s["code"] for s in missing_stocks] + + with get_session() as s: + job = JobRun(job="auto_fix", status="running", + message=f"0/{len(codes)}") + s.add(job) + s.commit() + job_id = job.id + + fixed = 0 + failed = [] + try: + for i, code in enumerate(codes): + rows = ingest.fetch_daily(code, days=60) + if rows: + ingest.ingest_quotes([code], days=60) + fixed += 1 + else: + failed.append(code) + + if (i + 1) % 10 == 0: + with get_session() as s: + j = s.get(JobRun, job_id) + j.message = f"{i+1}/{len(codes)}" + s.commit() + + status = "success" + msg = f"修复 {fixed}/{len(codes)},失败 {len(failed)} 只" + except Exception as e: + status = "error" + msg = f"修复中断: {repr(e)[:160]}" + + with get_session() as s: + j = s.get(JobRun, job_id) + j.status = status + j.finished_at = dt.datetime.now() + j.message = msg + s.commit() + + return {"ok": True, "fixed": fixed, "failed": failed, "msg": msg} + + +# ============ 断点续传回填 ============ + +def _load_progress() -> Dict: + """加载回填进度""" + if os.path.exists(PROGRESS_FILE): + try: + with open(PROGRESS_FILE, "r") as f: + return json.load(f) + except Exception: + pass + return {} + + +def _save_progress(progress: Dict): + """保存回填进度""" + with open(PROGRESS_FILE, "w") as f: + json.dump(progress, f) + + +def _clear_progress(task_id: str): + """清除指定任务的进度""" + progress = _load_progress() + progress.pop(task_id, None) + _save_progress(progress) + + +def start_refill_with_resume(days: int = 250, task_id: str = "default") -> Dict: + """带断点续传的全市场回填""" + from akshare_service import _code_name_map + + cmap = _code_name_map() + all_codes = [c for c in cmap.keys() if c[:1] in ("0", "3", "6")] + total = len(all_codes) + + # 加载进度 + progress = _load_progress() + task_progress = progress.get(task_id, {"done_codes": [], "days": days}) + done_codes = set(task_progress.get("done_codes", [])) + + # 过滤已完成的股票 + remaining = [c for c in all_codes if c not in done_codes] + + with get_session() as s: + job = JobRun( + job="refill_resume", + status="running", + message=f"续传: 已完成 {len(done_codes)}/{total},剩余 {len(remaining)}" + ) + s.add(job) + s.commit() + job_id = job.id + + fixed = len(done_codes) + try: + for i in range(0, len(remaining), 50): + batch = remaining[i:i + 50] + ingest.ingest_quotes(batch, days=days, with_metrics=True, cmap=cmap) + fixed += len(batch) + + # 保存进度 + done_codes.update(batch) + progress[task_id] = { + "done_codes": list(done_codes), + "days": days, + "updated_at": dt.datetime.now().isoformat() + } + _save_progress(progress) + + with get_session() as s: + j = s.get(JobRun, job_id) + j.message = f"{fixed}/{total}" + s.commit() + + # 完成后清除进度 + _clear_progress(task_id) + status = "success" + msg = f"完成 {fixed}/{total}" + except Exception as e: + status = "error" + msg = f"中断于 {fixed}/{total} | {repr(e)[:160]}" + # 保留进度供续传 + + with get_session() as s: + j = s.get(JobRun, job_id) + j.status = status + j.finished_at = dt.datetime.now() + j.message = msg + s.commit() + + return {"ok": status == "success", "done": fixed, "total": total, "msg": msg} + + +def get_refill_progress(task_id: str = "default") -> Dict: + """获取回填进度""" + progress = _load_progress() + task = progress.get(task_id) + if not task: + return {"ok": True, "has_progress": False, "msg": "无回填进度记录"} + + from akshare_service import _code_name_map + cmap = _code_name_map() + total = len([c for c in cmap.keys() if c[:1] in ("0", "3", "6")]) + done = len(task.get("done_codes", [])) + + return { + "ok": True, + "has_progress": True, + "task_id": task_id, + "done": done, + "total": total, + "pct": round(done / total * 100, 1) if total > 0 else 0, + "updated_at": task.get("updated_at", "") + } + + +def clear_refill_progress(task_id: str = "default") -> Dict: + """清除回填进度(从头开始)""" + _clear_progress(task_id) + return {"ok": True, "msg": f"已清除任务 {task_id} 的进度"} + + +# ============ 数据质量报告 ============ + +def get_data_quality_report() -> Dict: + """生成数据质量报告""" + with get_session() as s: + # 基本统计 + total_quotes = s.execute(select(func.count()).select_from(DailyQuote)).scalar() or 0 + total_stocks = s.execute( + select(func.count(DailyQuote.code.distinct())) + ).scalar() or 0 + latest_date = s.execute(select(func.max(DailyQuote.date))).scalar() + earliest_date = s.execute(select(func.min(DailyQuote.date))).scalar() + + # 最近30天数据密度 + if latest_date: + start30 = latest_date - dt.timedelta(days=30) + recent_stocks = s.execute( + select(func.count(DailyQuote.code.distinct())) + .where(DailyQuote.date >= start30) + ).scalar() or 0 + + recent_dates = s.execute( + select(func.count(DailyQuote.date.distinct())) + .where(DailyQuote.date >= start30) + ).scalar() or 0 + else: + recent_stocks = 0 + recent_dates = 0 + + # 异常数据检测(开盘价为0的记录) + zero_open = s.execute( + select(func.count()).select_from(DailyQuote) + .where(DailyQuote.open == 0) + ).scalar() or 0 + + # 最近任务状态 + recent_jobs = s.execute( + select(JobRun).order_by(JobRun.id.desc()).limit(5) + ).scalars().all() + + jobs_summary = [{ + "job": j.job, + "status": j.status, + "started": j.started_at.strftime("%m-%d %H:%M") if j.started_at else "", + "message": j.message[:100] + } for j in recent_jobs] + + # 数据健康度评分 + score = 100 + issues = [] + + if zero_open > 0: + score -= min(20, zero_open // 100) + issues.append(f"存在 {zero_open} 条开盘价为0的异常数据") + + if total_stocks < 100: + score -= 30 + issues.append(f"入库股票数量偏少({total_stocks}只)") + + if latest_date and (dt.date.today() - latest_date).days > 7: + score -= 20 + issues.append(f"数据滞后 {(dt.date.today() - latest_date).days} 天") + + return { + "ok": True, + "generated_at": dt.datetime.now().isoformat(), + "health_score": max(0, score), + "issues": issues, + "statistics": { + "total_quotes": total_quotes, + "total_stocks": total_stocks, + "latest_date": latest_date.isoformat() if latest_date else None, + "earliest_date": earliest_date.isoformat() if earliest_date else None, + "data_span_days": (latest_date - earliest_date).days if latest_date and earliest_date else 0, + "recent_30d_stocks": recent_stocks, + "recent_30d_dates": recent_dates, + "zero_open_count": zero_open + }, + "recent_jobs": jobs_summary + } diff --git a/backend/exceptions.py b/backend/exceptions.py new file mode 100644 index 0000000..983a0d3 --- /dev/null +++ b/backend/exceptions.py @@ -0,0 +1,73 @@ +from fastapi import Request, status +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from sqlalchemy.exc import SQLAlchemyError +import traceback + +class BusinessException(Exception): + """业务异常基类""" + def __init__(self, message: str, code: int = 400): + self.message = message + self.code = code + super().__init__(self.message) + +class DataSourceException(BusinessException): + """数据源异常(AkShare等)""" + def __init__(self, message: str = "数据源异常,请稍后重试"): + super().__init__(message, code=503) + +class AuthException(BusinessException): + """认证异常""" + def __init__(self, message: str = "认证失败"): + super().__init__(message, code=401) + +class PermissionException(BusinessException): + """权限异常""" + def __init__(self, message: str = "权限不足"): + super().__init__(message, code=403) + +async def business_exception_handler(request: Request, exc: BusinessException): + """业务异常处理器""" + return JSONResponse( + status_code=exc.code, + content={ + "success": False, + "error": exc.message, + "code": exc.code + } + ) + +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """请求参数验证异常处理器""" + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "success": False, + "error": "请求参数错误", + "detail": exc.errors() + } + ) + +async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): + """数据库异常处理器""" + print(f"数据库错误: {exc}") + traceback.print_exc() + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "success": False, + "error": "数据库操作失败" + } + ) + +async def general_exception_handler(request: Request, exc: Exception): + """通用异常处理器""" + print(f"未捕获异常: {type(exc).__name__}: {exc}") + traceback.print_exc() + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "success": False, + "error": "服务器内部错误" + } + ) \ No newline at end of file diff --git a/backend/ingest.py b/backend/ingest.py index fdb05ac..68d9581 100644 --- a/backend/ingest.py +++ b/backend/ingest.py @@ -12,7 +12,7 @@ import akshare_service as svc import config from db import get_session from models import (DailyQuote, DragonTiger, FundFlowDaily, IndexDaily, JobRun, - SectorDaily, Security, SentimentDaily, StockMetric) + SectorDaily, SectorLeader, Security, SentimentDaily, StockMetric) try: import akshare as ak @@ -229,6 +229,25 @@ def ingest_sectors(): return n +def ingest_sector_leaders(): + """入库各板块前5龙头股(按成交额)""" + d = _today() + data = svc.get_all_sector_leaders(top_n=5) + rows = [] + for sector, stocks in data.get("sectors", {}).items(): + for i, s in enumerate(stocks): + rows.append({"date": d, "sector": sector, "code": s["code"], + "name": s["name"], "pct": s["pct"], + "price": s["price"], "amount": s["amount"], "rank": i + 1}) + if not rows: + return 0 + with get_session() as s: + n = _upsert(s, SectorLeader, rows, ["date", "sector", "code"], + ["name", "pct", "price", "amount", "rank"]) + s.commit() + return n + + def ingest_fund_flow(): data = svc.get_fund_flow() d = _today() @@ -280,6 +299,7 @@ def run_daily_ingest(universe=None, with_quotes=True): summary["securities"] = ingest_securities() summary["indices"] = ingest_indices() summary["sectors"] = ingest_sectors() + summary["sector_leaders"] = ingest_sector_leaders() summary["fund_flow"] = ingest_fund_flow() summary["sentiment"] = ingest_sentiment() summary["dragon"] = ingest_dragon() diff --git a/backend/init_auth.py b/backend/init_auth.py new file mode 100644 index 0000000..6c8180a --- /dev/null +++ b/backend/init_auth.py @@ -0,0 +1,21 @@ +from db import get_session +from models import User +from auth import get_password_hash +from config import DEFAULT_ADMIN_USERNAME, DEFAULT_ADMIN_PASSWORD + +def init_default_admin(): + """创建默认管理员账号(如果不存在)""" + with get_session() as s: + admin = s.query(User).filter(User.username == DEFAULT_ADMIN_USERNAME).first() + if not admin: + admin = User( + username=DEFAULT_ADMIN_USERNAME, + hashed_password=get_password_hash(DEFAULT_ADMIN_PASSWORD), + is_admin=True, + is_active=True + ) + s.add(admin) + s.commit() + print(f"✓ 创建默认管理员: {DEFAULT_ADMIN_USERNAME}") + else: + print(f"✓ 管理员账号已存在: {DEFAULT_ADMIN_USERNAME}") \ No newline at end of file diff --git a/backend/install.sh b/backend/install.sh new file mode 100644 index 0000000..8490cb7 --- /dev/null +++ b/backend/install.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# 三大核心功能快速安装脚本(WSL/Linux) + +set -e + +echo "=========================================" +echo " Blackdata StockTerminal 核心功能安装" +echo "=========================================" +echo "" + +# 检查是否在 WSL/Linux 环境 +if [[ "$OSTYPE" != "linux-gnu"* ]]; then + echo "⚠ 此脚本仅支持 WSL/Linux 环境" + exit 1 +fi + +# 1. 安装系统依赖 +echo "[1/6] 检查并安装系统依赖..." +sudo apt update +sudo apt install -y postgresql postgresql-contrib redis-server python3-pip python3-venv + +# 2. 启动服务 +echo "" +echo "[2/6] 启动 PostgreSQL 和 Redis..." +sudo service postgresql start +sudo service redis-server start + +# 验证服务 +if redis-cli ping > /dev/null 2>&1; then + echo "✓ Redis 运行正常" +else + echo "⚠ Redis 启动失败,缓存将降级到内存模式" +fi + +# 3. 创建虚拟环境(如果不存在) +if [ ! -d ".venv" ]; then + echo "" + echo "[3/6] 创建 Python 虚拟环境..." + python3 -m venv .venv +else + echo "" + echo "[3/6] 虚拟环境已存在,跳过创建" +fi + +# 4. 安装 Python 依赖 +echo "" +echo "[4/6] 安装 Python 依赖包..." +source .venv/bin/activate +pip install --upgrade pip +pip install -r requirements.txt + +# 5. 配置环境变量 +echo "" +echo "[5/6] 配置环境变量..." +if [ ! -f ".env" ]; then + if [ -f ".env.example" ]; then + cp .env.example .env + echo "✓ 已从 .env.example 创建 .env 文件" + else + echo "⚠ .env.example 不存在,请手动创建 .env 文件" + fi + + # 生成随机 SECRET_KEY + SECRET_KEY=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))") + echo "" + echo "生成的 SECRET_KEY(请添加到 .env):" + echo "SECRET_KEY=$SECRET_KEY" + echo "" +else + echo "✓ .env 文件已存在" +fi + +# 6. 初始化数据库 +echo "" +echo "[6/6] 初始化数据库..." + +# 检查 PostgreSQL 密码配置 +if grep -q "PG_PASSWORD=your_password" .env 2>/dev/null || grep -q "PG_PASSWORD=$" .env 2>/dev/null; then + echo "" + echo "⚠ 请先在 .env 中设置 PostgreSQL 密码:" + echo " 1. 设置数据库密码: sudo -u postgres psql -c \"ALTER USER postgres PASSWORD 'your_password';\"" + echo " 2. 在 .env 中配置: PG_PASSWORD=your_password" + echo "" + echo "配置完成后,运行: python cli.py init" +else + python cli.py init + echo "✓ 数据库初始化完成" +fi + +echo "" +echo "=========================================" +echo " 安装完成!" +echo "=========================================" +echo "" +echo "下一步:" +echo "1. 编辑 backend/.env 文件,配置数据库密码和其他选项" +echo "2. 如果未初始化数据库,运行: python cli.py init" +echo "3. 启动服务: python main.py" +echo "4. 浏览器访问: http://localhost:8000" +echo "5. 默认管理员: admin / admin123 (首次登录后务必修改密码)" +echo "6. 测试功能: python test_core_features.py" +echo "" +echo "详细文档:" +echo "- 升级指南: backend/UPGRADE_GUIDE.md" +echo "- 配置说明: backend/ENV_CONFIG.md" +echo "" diff --git a/backend/limit_analysis.py b/backend/limit_analysis.py index 7e45bf4..9deb3b9 100644 --- a/backend/limit_analysis.py +++ b/backend/limit_analysis.py @@ -1,10 +1,4 @@ -"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行。 - -功能: -1. 连板股追踪器 -2. 炸板率统计 -3. 涨停敢死队排行 -""" +"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行、炸板走势统计、涨停原因分类。""" import datetime as dt from typing import List, Dict, Any, Optional from collections import defaultdict, Counter @@ -12,7 +6,24 @@ import numpy as np from sqlalchemy import select, and_, func, desc from db import get_session -from models import DailyQuote, StockMetric +from models import DailyQuote, StockMetric, DragonTiger + +try: + import akshare as ak + AK_OK = True +except Exception: + ak = None + AK_OK = False + +# 涨停原因关键词分类 +LIMIT_REASON_MAP = { + "题材": ["概念", "题材", "热点", "风口", "赛道"], + "业绩": ["业绩", "净利润", "营收", "盈利", "超预期", "预增", "扭亏"], + "政策": ["政策", "补贴", "利好", "支持", "规划", "国家", "工信部", "发改委"], + "技术突破": ["突破", "新高", "均线", "金叉", "放量"], + "重组并购": ["重组", "并购", "收购", "合并", "入股"], + "情绪": ["跟风", "连板", "情绪", "氛围", "涨停潮"], +} def get_limit_stocks(date: Optional[dt.date] = None, limit_type: str = "up") -> Dict[str, Any]: @@ -293,6 +304,214 @@ def analyze_limit_break_rate(days: int = 60) -> Dict[str, Any]: } +def get_consecutive_calendar(days: int = 60) -> Dict[str, Any]: + """连板日历:记录每只股票的连板历史,分析几进几出规律""" + with get_session() as s: + latest_date = s.execute(select(func.max(DailyQuote.date))).scalar() + if not latest_date: + return {"ok": False, "msg": "暂无数据"} + start_date = latest_date - dt.timedelta(days=days) + quotes = s.execute( + select(DailyQuote) + .where(DailyQuote.date >= start_date) + .order_by(DailyQuote.code, DailyQuote.date) + ).scalars().all() + + stock_data = defaultdict(list) + for q in quotes: + stock_data[q.code].append(q) + + all_streaks = [] + current_streaks = [] + + for code, data in stock_data.items(): + data_sorted = sorted(data, key=lambda x: x.date) + name = data_sorted[-1].name + streaks = [] + current = [] + + for q in data_sorted: + if q.open == 0: + if len(current) >= 2: + streaks.append(current) + current = [] + continue + pct = (float(q.close) - float(q.open)) / float(q.open) * 100 + if pct >= 9.8: + current.append(q) + else: + if len(current) >= 2: + streaks.append(current) + current = [] + + if len(current) >= 2: + current_streaks.append({ + "code": code, "name": name, + "days": len(current), + "start_date": current[0].date.isoformat(), + "latest_date": current[-1].date.isoformat(), + "latest_close": float(current[-1].close) + }) + for streak in streaks: + all_streaks.append({ + "code": code, "name": name, + "days": len(streak), + "start_date": streak[0].date.isoformat(), + "end_date": streak[-1].date.isoformat() + }) + + distribution = defaultdict(int) + for item in all_streaks: + distribution[f"{item['days']}板"] += 1 + + current_streaks.sort(key=lambda x: x["days"], reverse=True) + return { + "ok": True, + "date_range": f"{start_date.isoformat()} ~ {latest_date.isoformat()}", + "current_streaks": current_streaks[:30], + "streak_distribution": dict(distribution), + "total_streaks": len(all_streaks) + } + + +def analyze_post_break_performance(days: int = 90) -> Dict[str, Any]: + """炸板后走势统计:炸板后 1/3/5 日表现概率分布""" + with get_session() as s: + latest_date = s.execute(select(func.max(DailyQuote.date))).scalar() + if not latest_date: + return {"ok": False, "msg": "暂无数据"} + + start_date = latest_date - dt.timedelta(days=days + 10) + quotes = s.execute( + select(DailyQuote) + .where(DailyQuote.date >= start_date) + .order_by(DailyQuote.code, DailyQuote.date) + ).scalars().all() + + stock_data = defaultdict(dict) + for q in quotes: + stock_data[q.code][q.date] = q + + # 炸板 = 当日一度涨停但收盘时未涨停(用开高低收近似判断) + # 简化:前一日涨停,次日收盘未涨停视为炸板 + perf_1d, perf_3d, perf_5d = [], [], [] + + for code, date_data in stock_data.items(): + dates = sorted(date_data.keys()) + for i in range(len(dates) - 5): + today = dates[i] + q_today = date_data[today] + if q_today.open == 0: + continue + pct_today = (float(q_today.close) - float(q_today.open)) / float(q_today.open) * 100 + # 判断昨日涨停今日炸板(开高但收盘低) + if i == 0: + continue + yesterday = dates[i - 1] + q_yest = date_data[yesterday] + if q_yest.open == 0: + continue + pct_yest = (float(q_yest.close) - float(q_yest.open)) / float(q_yest.open) * 100 + + # 昨日涨停,今日未涨停(炸板) + if pct_yest >= 9.8 and pct_today < 9.8: + base = float(q_today.close) + # 后续 1/3/5 日表现 + for horizon, perf_list in [(1, perf_1d), (3, perf_3d), (5, perf_5d)]: + if i + horizon < len(dates): + future = dates[i + horizon] + q_future = date_data[future] + ret = (float(q_future.close) - base) / base * 100 + perf_list.append(round(ret, 2)) + + def summarize(perfs): + if not perfs: + return {} + arr = np.array(perfs) + return { + "samples": len(perfs), + "avg_ret": round(float(arr.mean()), 2), + "win_rate": round(float((arr > 0).mean() * 100), 1), + "p25": round(float(np.percentile(arr, 25)), 2), + "median": round(float(np.median(arr)), 2), + "p75": round(float(np.percentile(arr, 75)), 2), + } + + return { + "ok": True, + "days": days, + "after_1d": summarize(perf_1d), + "after_3d": summarize(perf_3d), + "after_5d": summarize(perf_5d), + "conclusion": ( + f"炸板后样本 {len(perf_1d)} 条," + f"次日平均收益 {summarize(perf_1d).get('avg_ret', 0)}%," + f"次日上涨概率 {summarize(perf_1d).get('win_rate', 0)}%" + ) if perf_1d else "样本不足" + } + + +def classify_limit_reasons(date: Optional[dt.date] = None) -> Dict[str, Any]: + """涨停原因分类:情绪、题材、业绩、技术突破等""" + with get_session() as s: + if date is None: + date = s.execute(select(func.max(DragonTiger.date))).scalar() + if not date: + return {"ok": False, "msg": "暂无龙虎榜数据,请先入库"} + + lhb_rows = s.execute( + select(DragonTiger).where(DragonTiger.date == date) + ).scalars().all() + + # 尝试从 AkShare 获取当日涨停原因 + reason_data = {} + if AK_OK: + try: + df = ak.stock_zt_pool_em(date=date.strftime("%Y%m%d")) + if df is not None and not df.empty: + for _, r in df.iterrows(): + code = str(r.get("代码", "")) + reason = str(r.get("涨停原因类别", "") or r.get("上榜原因", "")) + reason_data[code] = reason + except Exception: + pass + + # 合并龙虎榜原因 + for row in lhb_rows: + if row.code not in reason_data: + reason_data[row.code] = row.reason + + # 分类 + classified = defaultdict(list) + for code, reason in reason_data.items(): + matched = False + for category, keywords in LIMIT_REASON_MAP.items(): + if any(kw in reason for kw in keywords): + classified[category].append({"code": code, "reason": reason}) + matched = True + break + if not matched: + classified["其他"].append({"code": code, "reason": reason}) + + total = sum(len(v) for v in classified.values()) + summary = [ + { + "category": cat, + "count": len(items), + "pct": round(len(items) / total * 100, 1) if total > 0 else 0, + "stocks": items[:10] + } + for cat, items in sorted(classified.items(), key=lambda x: len(x[1]), reverse=True) + ] + + return { + "ok": True, + "date": date.isoformat(), + "total": total, + "categories": summary + } + + def get_limit_squad_rankings(days: int = 30, min_limits: int = 5) -> Dict[str, Any]: """涨停敢死队排行 diff --git a/backend/main.py b/backend/main.py index d28e351..00920f1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -9,14 +9,30 @@ import datetime as dt from contextlib import asynccontextmanager from typing import List, Dict, Any, Optional -from fastapi import FastAPI, Query +from fastapi import FastAPI, Query, Depends from fastapi.middleware.cors import CORSMiddleware +from fastapi.exceptions import RequestValidationError +from sqlalchemy.exc import SQLAlchemyError from fastapi.staticfiles import StaticFiles from sqlalchemy import select, func, desc from pydantic import BaseModel import akshare_service as svc +import redis_cache +from redis_cache import cache +import auth +from auth import get_current_user, require_auth, require_admin +import init_auth +import exceptions +from exceptions import ( + BusinessException, + DataSourceException, + business_exception_handler, + validation_exception_handler, + sqlalchemy_exception_handler, + general_exception_handler +) import config import scheduler import backtest as bt @@ -37,6 +53,11 @@ import sentiment_monitor as sentiment import event_driven as events import financial_analysis as fin import limit_analysis as limit_up +import watchlist_manager as wl +import position_cost as pc +import trade_calendar as cal +import data_manager as dm +import paper_trading as paper from db import init_db, get_session from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily, SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade, @@ -47,8 +68,11 @@ from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily, async def lifespan(app: FastAPI): try: init_db() + init_auth.init_default_admin() + wl.init_default_groups() + paper.ensure_default_account() scheduler.start_scheduler() - print("[startup] db + scheduler ready") + print("[startup] db + scheduler + auth ready") except Exception as e: print("[startup] WARN:", repr(e)[:160]) yield @@ -56,6 +80,12 @@ async def lifespan(app: FastAPI): app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan) +# 注册异常处理器 +app.add_exception_handler(BusinessException, business_exception_handler) +app.add_exception_handler(RequestValidationError, validation_exception_handler) +app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler) +app.add_exception_handler(Exception, general_exception_handler) + app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -84,10 +114,113 @@ def save_watch(symbols): json.dump(symbols, f, ensure_ascii=False) +# ============ 认证相关 API ============ +class LoginRequest(BaseModel): + username: str + password: str + +@app.post("/api/auth/login") +def login(req: LoginRequest, db = Depends(get_session)): + """用户登录""" + user = auth.authenticate_user(db, req.username, req.password) + if not user: + raise exceptions.AuthException("用户名或密码错误") + + access_token = auth.create_access_token(data={"sub": user.username}) + return { + "ok": True, + "access_token": access_token, + "token_type": "bearer", + "username": user.username, + "is_admin": user.is_admin + } + +@app.get("/api/auth/me") +def get_me(current_user = Depends(require_auth)): + """获取当前用户信息""" + return { + "ok": True, + "username": current_user.username, + "is_admin": current_user.is_admin + } + +class ChangePasswordRequest(BaseModel): + old_password: str + new_password: str + +@app.post("/api/auth/change-password") +def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)): + """修改密码""" + if not auth.verify_password(req.old_password, current_user.hashed_password): + raise exceptions.AuthException("原密码错误") + + current_user.hashed_password = auth.get_password_hash(req.new_password) + db.commit() + return {"ok": True, "msg": "密码修改成功"} + +# ============ 用户管理 ============ +class CreateUserRequest(BaseModel): + username: str + password: str + is_admin: bool = False + +from models import User as UserModel + +@app.get("/api/users") +def list_users(current_user = Depends(require_admin)): + with get_session() as s: + rows = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all() + return {"ok": True, "users": [{"id": r.id, "username": r.username, + "is_admin": r.is_admin, "is_active": r.is_active, + "created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]} + +@app.post("/api/users") +def create_user(req: CreateUserRequest, current_user = Depends(require_admin)): + with get_session() as s: + if s.execute(select(UserModel).where(UserModel.username == req.username)).scalar_one_or_none(): + return {"ok": False, "msg": "用户名已存在"} + user = UserModel(username=req.username, + hashed_password=auth.get_password_hash(req.password), is_admin=req.is_admin) + s.add(user); s.commit() + return {"ok": True, "id": user.id} + +@app.delete("/api/users/{uid}") +def delete_user(uid: int, current_user = Depends(require_admin)): + if current_user.id == uid: + return {"ok": False, "msg": "不能删除自己"} + with get_session() as s: + u = s.get(UserModel, uid) + if u: s.delete(u); s.commit() + return {"ok": True} + +@app.put("/api/users/{uid}/toggle_admin") +def toggle_admin(uid: int, current_user = Depends(require_admin)): + if current_user.id == uid: + return {"ok": False, "msg": "不能修改自己的权限"} + with get_session() as s: + u = s.get(UserModel, uid) + if not u: return {"ok": False, "msg": "用户不存在"} + u.is_admin = not u.is_admin; s.commit() + return {"ok": True, "is_admin": u.is_admin} + +@app.put("/api/users/{uid}/reset_password") +def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)): + with get_session() as s: + u = s.get(UserModel, uid) + if not u: return {"ok": False, "msg": "用户不存在"} + u.hashed_password = auth.get_password_hash(req.new_password); s.commit() + return {"ok": True} + + # ============ API ============ @app.get("/api/health") def health(): - return {"ok": True, "akshare": svc.AK_OK} + return { + "ok": True, + "akshare": svc.AK_OK, + "redis": cache.enabled, + "auth": True + } @app.get("/api/indices") @@ -106,10 +239,76 @@ def sentiment(): @app.get("/api/treemap") -def treemap(mode: str = Query("sector")): +def treemap(mode: str = Query("sector"), date: str = Query(None)): + if mode == "sector" and date: + # 从数据库读历史板块数据 + try: + target = dt.date.fromisoformat(date) + except Exception: + return svc.get_treemap(mode) + with get_session() as s: + rows = s.execute( + select(SectorDaily).where(SectorDaily.date == target) +.order_by(SectorDaily.pct.desc()) + ).scalars().all() + if not rows: + # 找最近有数据的日期 + latest = s.execute(select(func.max(SectorDaily.date))).scalar() + if latest: + rows = s.execute( + select(SectorDaily).where(SectorDaily.date == latest) + .order_by(SectorDaily.pct.desc()) + ).scalars().all() + target = latest + if rows: + items = [{"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)} for r in rows] + return {"source": "db", "mode": "sector", "date": target.isoformat(), "items": items} return svc.get_treemap(mode) +@app.get("/api/treemap/us") +def treemap_us(): + return svc.get_us_treemap() + + +@app.get("/api/treemap/hk") +def treemap_hk(): + return svc.get_hk_treemap() + + +@app.get("/api/treemap/sector_stocks") +def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)): + return svc.get_sector_stocks(name, limit) + + +@app.get("/api/treemap/all_leaders") +def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)): + from models import SectorLeader + # 优先从数据库读 + with get_session() as s: + target = None + if date: + try: target = dt.date.fromisoformat(date) + except Exception: pass + if not target: + target = s.execute(select(func.max(SectorLeader.date))).scalar() + if target: + rows = s.execute( + select(SectorLeader).where(SectorLeader.date == target) + .order_by(SectorLeader.sector, SectorLeader.rank) + ).scalars().all() + if rows: + sectors = {} + for r in rows: + sectors.setdefault(r.sector, []).append({ + "code": r.code, "name": r.name, "pct": r.pct, + "price": r.price, "amount": r.amount + }) + return {"source": "db", "date": target.isoformat(), "sectors": sectors} + # 降级到实时 + return svc.get_all_sector_leaders(top_n) + + @app.get("/api/fundflow") def fundflow(): return svc.get_fund_flow() @@ -151,9 +350,105 @@ def watch_del(code: str): return {"ok": True, "list": w} +# ============ 自选股分组管理 ============ +class CreateGroupRequest(BaseModel): + name: str + description: str = "" + color: str = "blue" + +@app.get("/api/watchlist/groups") +def list_groups(): + """获取所有分组""" + return {"ok": True, "groups": wl.get_all_groups()} + +@app.post("/api/watchlist/groups") +def create_group(req: CreateGroupRequest): + """创建新分组""" + return wl.create_group(req.name, req.description, req.color) + +class UpdateGroupRequest(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + color: Optional[str] = None + +@app.put("/api/watchlist/groups/{group_id}") +def update_group(group_id: int, req: UpdateGroupRequest): + """更新分组信息""" + return wl.update_group(group_id, req.name, req.description, req.color) + +@app.delete("/api/watchlist/groups/{group_id}") +def delete_group(group_id: int): + """删除分组""" + return wl.delete_group(group_id) + +class ReorderGroupsRequest(BaseModel): + group_ids: List[int] + +@app.post("/api/watchlist/groups/reorder") +def reorder_groups(req: ReorderGroupsRequest): + """重新排序分组""" + return wl.reorder_groups(req.group_ids) + +@app.get("/api/watchlist/groups/{group_id}/stocks") +def get_group_stocks(group_id: int, with_quotes: bool = Query(True)): + """获取分组内的股票""" + return wl.get_group_stocks(group_id, with_quotes) + +class AddStockRequest(BaseModel): + code: str + note: str = "" + +@app.post("/api/watchlist/groups/{group_id}/stocks") +def add_stock_to_group(group_id: int, req: AddStockRequest): + """添加股票到分组""" + return wl.add_stock_to_group(group_id, req.code, req.note) + +@app.delete("/api/watchlist/stocks/{item_id}") +def remove_stock(item_id: int): + """从分组中移除股票""" + return wl.remove_stock_from_group(item_id) + +class MoveStockRequest(BaseModel): + target_group_id: int + +@app.post("/api/watchlist/stocks/{item_id}/move") +def move_stock(item_id: int, req: MoveStockRequest): + """移动股票到另一个分组""" + return wl.move_stock_to_group(item_id, req.target_group_id) + +class BatchAddRequest(BaseModel): + codes: List[str] + +@app.post("/api/watchlist/groups/{group_id}/stocks/batch") +def batch_add_stocks(group_id: int, req: BatchAddRequest): + """批量添加股票""" + return wl.batch_add_stocks(group_id, req.codes) + +class UpdateNoteRequest(BaseModel): + note: str + +@app.put("/api/watchlist/stocks/{item_id}/note") +def update_stock_note(item_id: int, req: UpdateNoteRequest): + """更新股票备注""" + return wl.update_stock_note(item_id, req.note) + +class ReorderStocksRequest(BaseModel): + item_ids: List[int] + +@app.post("/api/watchlist/stocks/reorder") +def reorder_stocks(req: ReorderStocksRequest): + """重新排序股票""" + return wl.reorder_stocks(req.item_ids) + +@app.get("/api/watchlist/search") +def search_stocks(keyword: str = Query(..., min_length=1)): + """跨分组搜索股票""" + return {"ok": True, "results": wl.search_stocks_across_groups(keyword)} + + # ============ 数据中台 ============ @app.get("/api/admin/status") -def admin_status(): +def admin_status(current_user = Depends(require_admin)): counts, last_dates = {}, {} with get_session() as s: for label, model in [("securities", Security), ("quotes_daily", DailyQuote), @@ -175,14 +470,14 @@ def admin_status(): @app.post("/api/admin/ingest") -def admin_ingest(): +def admin_ingest(current_user = Depends(require_admin)): if scheduler.is_running(): return {"started": False, "msg": "已有入库任务在执行"} return scheduler.trigger_async() @app.post("/api/admin/ingest_all") -def admin_ingest_all(): +def admin_ingest_all(current_user = Depends(require_admin)): return scheduler.trigger_all_async() @@ -488,6 +783,16 @@ def ai_today(): return ai.today_strategy() +@app.get("/api/ai/trend_analysis") +def ai_trend_analysis( + symbol: str = Query(...), + date: str = Query(""), + period: str = Query("daily") +): + """走势分析:右键K线条形时调用,分析暴涨/暴跌原因""" + return ai.trend_analysis(symbol, date, period) + + # ============ 可回溯:信号历史胜率 + 实测准确率 ============ @app.get("/api/ai/signal_stats") def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)): @@ -582,6 +887,144 @@ def portfolio_equity(): return pf.equity_curve() +# ============ 持仓成本可视化增强 ============ +@app.get("/api/portfolio/cost_line/{code}") +def get_cost_line(code: str): + """获取个股持仓成本线(用于K线图标注)""" + return pc.get_position_cost_lines(code) + +@app.get("/api/portfolio/cost_distribution") +def get_cost_distribution(): + """获取持仓成本分布(盈亏区间图)""" + return pc.get_position_cost_distribution() + +class EstimateCostRequest(BaseModel): + code: str + price: float + qty: int + side: str = "buy" + +@app.post("/api/portfolio/estimate_cost") +def estimate_cost(req: EstimateCostRequest): + """估算交易成本(下单前预估)""" + return pc.estimate_trade_cost(req.code, req.price, req.qty, req.side) + +@app.get("/api/portfolio/cost_breakdown/{code}") +def get_cost_breakdown(code: str): + """获取持仓的详细成本拆解""" + return pc.get_cost_breakdown_for_position(code) + + +# ============ 交易日历与关键事件 ============ +@app.get("/api/calendar/events") +def calendar_events(days: int = Query(30, ge=7, le=90)): + """获取所有即将到来的关键事件(综合视图)""" + return cal.get_all_upcoming_events(days) + +@app.get("/api/calendar/dividends") +def calendar_dividends(days: int = Query(30, ge=7, le=90)): + """除权除息日历(持仓股优先)""" + return cal.get_upcoming_dividends(days) + +@app.get("/api/calendar/unlock") +def calendar_unlock(days: int = Query(90, ge=7, le=180)): + """限售解禁日历""" + return cal.get_unlock_calendar(days) + +@app.get("/api/calendar/earnings") +def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)): + """财报披露日历""" + return cal.get_earnings_calendar(days, holding_only) + +@app.post("/api/calendar/check_alerts") +def calendar_check_alerts(current_user = Depends(require_admin)): + """手动触发日历事件预警推送""" + return cal.check_and_push_calendar_alerts() + + +# ============ 数据修正与回填增强 ============ +class UpdateQuoteRequest(BaseModel): + open: Optional[float] = None + high: Optional[float] = None + low: Optional[float] = None + close: Optional[float] = None + volume: Optional[int] = None + amount: Optional[float] = None + +@app.delete("/api/data/quote/{code}/{date}") +def delete_quote(code: str, date: str, current_user = Depends(require_admin)): + """删除指定股票指定日期的日线""" + return dm.delete_quote(code, date) + +@app.put("/api/data/quote/{code}/{date}") +def update_quote(code: str, date: str, req: UpdateQuoteRequest, + current_user = Depends(require_admin)): + """修正指定日线数据""" + return dm.update_quote(code, date, req.model_dump(exclude_none=True)) + +class DeleteRangeRequest(BaseModel): + start: str + end: str + +@app.delete("/api/data/quotes/{code}/range") +def delete_quotes_range(code: str, req: DeleteRangeRequest, + current_user = Depends(require_admin)): + """删除指定股票日期范围内的日线数据""" + return dm.delete_quotes_range(code, req.start, req.end) + +@app.post("/api/data/refetch/{code}") +def refetch_quote(code: str, days: int = Query(60, ge=5, le=500), + current_user = Depends(require_admin)): + """重新抓取指定股票日线(覆盖更新)""" + return dm.refetch_quote(code, days) + +@app.get("/api/data/integrity") +def check_integrity(days: int = Query(30, ge=7, le=90), + current_user = Depends(require_admin)): + """数据完整性检查""" + return dm.check_data_integrity(days=days) + +@app.post("/api/data/auto_fix") +def auto_fix_missing(limit: int = Query(50, ge=10, le=200), + current_user = Depends(require_admin)): + """自动补齐缺失数据""" + t = __import__("threading").Thread( + target=dm.auto_fix_missing, kwargs={"limit": limit}, daemon=True + ) + t.start() + return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"} + +@app.get("/api/data/refill_progress") +def refill_progress(task_id: str = Query("default")): + """获取回填进度""" + return dm.get_refill_progress(task_id) + +@app.post("/api/data/refill_resume") +def refill_resume(days: int = Query(250, ge=30, le=1000), +task_id: str = Query("default"), + current_user = Depends(require_admin)): + """带断点续传的全市场回填(后台执行)""" + import threading + t = threading.Thread( + target=dm.start_refill_with_resume, + kwargs={"days": days, "task_id": task_id}, + daemon=True + ) + t.start() + return {"ok": True, "msg": f"已启动断点续传回填,天数={days},任务ID={task_id}"} + +@app.delete("/api/data/refill_progress") +def clear_refill_progress(task_id: str = Query("default"), + current_user = Depends(require_admin)): + """清除回填进度(从头开始)""" + return dm.clear_refill_progress(task_id) + +@app.get("/api/data/quality_report") +def data_quality_report(current_user = Depends(require_admin)): + """数据质量报告""" + return dm.get_data_quality_report() + + @app.get("/api/portfolio/attribution") def portfolio_attribution(): """持仓归因分析""" @@ -761,6 +1204,22 @@ def limit_squad(days: int = Query(30, ge=10, le=90), min_limits: int = Query(5, """涨停敢死队排行""" return limit_up.get_limit_squad_rankings(days, min_limits) +@app.get("/api/limit/consecutive_calendar") +def consecutive_calendar(days: int = Query(60, ge=20, le=120)): + """连板日历:记录连板历史,分析几进几出规律""" + return limit_up.get_consecutive_calendar(days) + +@app.get("/api/limit/post_break") +def post_break_performance(days: int = Query(90, ge=30, le=180)): + """炸板后 1/3/5 日走势统计""" + return limit_up.analyze_post_break_performance(days) + +@app.get("/api/limit/reasons") +def limit_reasons(date: Optional[str] = None): + """涨停原因分类(情绪/题材/业绩/政策等)""" + d = dt.date.fromisoformat(date) if date else None + return limit_up.classify_limit_reasons(d) + # ============ 推送通知 ============ @app.get("/api/notify/status") @@ -1142,6 +1601,51 @@ def delete_selector_alert(aid: int): return {"ok": True} +# ============ 模拟盘 ============ + +class PaperAccountIn(BaseModel): + name: str + initial_cash: float = 1_000_000.0 + + +@app.get("/api/paper/accounts") +def paper_list_accounts(): + return {"ok": True, "accounts": paper.list_accounts()} + + +@app.post("/api/paper/accounts") +def paper_create_account(req: PaperAccountIn): + return paper.create_account(req.name, req.initial_cash) + + +@app.post("/api/paper/accounts/{account_id}/reset") +def paper_reset_account(account_id: int, initial_cash: Optional[float] = None): + return paper.reset_account(account_id, initial_cash) + + +class PaperOrderIn(BaseModel): + code: str + side: str # buy / sell + qty: int + price: Optional[float] = None + reason: str = "" + + +@app.post("/api/paper/accounts/{account_id}/order") +def paper_place_order(account_id: int, req: PaperOrderIn): + return paper.place_order(account_id, req.code, req.side, req.qty, req.price, req.reason) + + +@app.get("/api/paper/accounts/{account_id}/portfolio") +def paper_get_portfolio(account_id: int): + return paper.get_portfolio(account_id) + + +@app.get("/api/paper/accounts/{account_id}/trades") +def paper_get_trades(account_id: int, limit: int = Query(100, le=500)): + return {"ok": True, "trades": paper.get_trades(account_id, limit)} + + # ============ 静态前端 ============ FRONTEND_DIR = os.path.join(os.path.dirname(BASE_DIR), "prototype") if os.path.isdir(FRONTEND_DIR): diff --git a/backend/models.py b/backend/models.py index c94affc..aef5d18 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,4 +1,4 @@ -"""数据中台 ORM 模型(SQLAlchemy 2.0)。""" +"""数据中台 ORM 模型(SQLAlchemy 2.0)。""" from __future__ import annotations import datetime as dt @@ -64,6 +64,21 @@ class SectorDaily(Base): leader: Mapped[str] = mapped_column(String(40), default="") +class SectorLeader(Base): + """板块每日龙头股(前5,按成交额)。""" + __tablename__ = "sector_leaders" + __table_args__ = (UniqueConstraint("date", "sector", "code", name="uq_sector_leader"),) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + date: Mapped[dt.date] = mapped_column(Date, index=True) + sector: Mapped[str] = mapped_column(String(40), index=True) + code: Mapped[str] = mapped_column(String(12)) + name: Mapped[str] = mapped_column(String(40)) + pct: Mapped[float] = mapped_column(Float, default=0.0) + price: Mapped[float] = mapped_column(Float, default=0.0) + amount: Mapped[float] = mapped_column(Float, default=0.0) + rank: Mapped[int] = mapped_column(Integer, default=0) + + class FundFlowDaily(Base): """行业资金流每日快照。""" __tablename__ = "fund_flow_daily" @@ -349,3 +364,68 @@ class IntradayEvent(Base): description: Mapped[str] = mapped_column(String(200), default="") detected_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now(), index=True) notified: Mapped[bool] = mapped_column(default=False) + + +class PaperAccount(Base): + """模拟盘账户。""" + __tablename__ = "paper_accounts" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(50), default="默认模拟盘") + initial_cash: Mapped[float] = mapped_column(Float, default=1_000_000.0) + cash: Mapped[float] = mapped_column(Float, default=1_000_000.0) + is_active: Mapped[bool] = mapped_column(default=True) + created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now()) + + +class PaperTrade(Base): + """模拟盘交易记录。""" + __tablename__ = "paper_trades" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + account_id: Mapped[int] = mapped_column(Integer, index=True, default=1) + date: Mapped[dt.date] = mapped_column(Date, index=True) + code: Mapped[str] = mapped_column(String(12), index=True) + name: Mapped[str] = mapped_column(String(40), default="") + side: Mapped[str] = mapped_column(String(4)) # buy / sell + price: Mapped[float] = mapped_column(Float) + qty: Mapped[int] = mapped_column(Integer) + fee: Mapped[float] = mapped_column(Float, default=0.0) + cash_before: Mapped[float] = mapped_column(Float, default=0.0) + cash_after: Mapped[float] = mapped_column(Float, default=0.0) + reason: Mapped[str] = mapped_column(String(60), default="") + created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now()) + + +class User(Base): + """用户表(用于鉴权)。""" + __tablename__ = "users" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + username: Mapped[str] = mapped_column(String(50), unique=True, index=True) + hashed_password: Mapped[str] = mapped_column(String(100)) + is_admin: Mapped[bool] = mapped_column(default=False) + is_active: Mapped[bool] = mapped_column(default=True) + created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now()) + + +class WatchlistGroup(Base): + """自选股分组。""" + __tablename__ = "watchlist_groups" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(50)) + description: Mapped[str] = mapped_column(String(200), default="") + color: Mapped[str] = mapped_column(String(20), default="blue") # 分组颜色标识 + sort_order: Mapped[int] = mapped_column(Integer, default=0) + is_default: Mapped[bool] = mapped_column(default=False) + created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now()) + + +class WatchlistItem(Base): + """自选股项目。""" + __tablename__ = "watchlist_items" + __table_args__ = (UniqueConstraint("group_id", "code", name="uq_watchlist_group_code"),) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + group_id: Mapped[int] = mapped_column(Integer, index=True) + code: Mapped[str] = mapped_column(String(12), index=True) + name: Mapped[str] = mapped_column(String(40), default="") + sort_order: Mapped[int] = mapped_column(Integer, default=0) + note: Mapped[str] = mapped_column(String(200), default="") # 个股备注 + added_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now()) diff --git a/backend/paper_trading.py b/backend/paper_trading.py new file mode 100644 index 0000000..5fd8bcf --- /dev/null +++ b/backend/paper_trading.py @@ -0,0 +1,222 @@ +"""模拟盘核心逻辑。 + +- 多账户支持(默认账户 id=1) +- 买卖按实时价(或收盘价)撮合,自动扣减/增加现金 +- 持仓计算使用移动加权平均成本法 +""" +from __future__ import annotations + +import datetime as dt +from collections import defaultdict + +from sqlalchemy import select + +from db import get_session +from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote + +DEFAULT_FEE_RATE = 0.0003 + + +def _get_price(code: str) -> float | None: + with get_session() as s: + m = s.execute(select(StockMetric.close).where(StockMetric.code == code)).scalar_one_or_none() + if m: + return float(m) + row = s.execute( + select(DailyQuote.close).where(DailyQuote.code == code) + .order_by(DailyQuote.date.desc()).limit(1) + ).scalar_one_or_none() + return float(row) if row else None + + +# ── 账户管理 ────────────────────────────────────────────── + +def ensure_default_account(): + """确保默认账户(id=1)存在,启动时调用。""" + with get_session() as s: + if not s.get(PaperAccount, 1): + s.add(PaperAccount(name="默认模拟盘", initial_cash=1_000_000.0, cash=1_000_000.0)) + s.commit() + + +def list_accounts() -> list[dict]: + with get_session() as s: + rows = s.execute(select(PaperAccount).order_by(PaperAccount.id)).scalars().all() + return [{"id": r.id, "name": r.name, "initial_cash": r.initial_cash, + "cash": round(r.cash, 2), "is_active": r.is_active, + "created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows] + + +def create_account(name: str, initial_cash: float) -> dict: + with get_session() as s: + acc = PaperAccount(name=name, initial_cash=initial_cash, cash=initial_cash) + s.add(acc) + s.commit() + return {"ok": True, "id": acc.id} + + +def reset_account(account_id: int, initial_cash: float | None = None) -> dict: + with get_session() as s: + acc = s.get(PaperAccount, account_id) + if not acc: + return {"ok": False, "msg": "账户不存在"} + if initial_cash is not None: + acc.initial_cash = initial_cash + acc.cash = acc.initial_cash + for t in s.execute( + select(PaperTrade).where(PaperTrade.account_id == account_id) + ).scalars(): + s.delete(t) + s.commit() + return {"ok": True, "msg": "账户已重置"} + + +# ── 持仓计算(内部)──────────────────────────────────────── + +def _calc_holdings_in_session(account_id: int, s) -> list[dict]: + trades = s.execute( + select(PaperTrade).where(PaperTrade.account_id == account_id) + .order_by(PaperTrade.date, PaperTrade.id) + ).scalars().all() + pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""}) + for t in trades: + p = pos[t.code] + p["name"] = t.name or p["name"] + if t.side == "buy": + p["cost"] += t.price * t.qty + t.fee + p["qty"] += t.qty + else: + if p["qty"] > 0: + avg = p["cost"] / p["qty"] + qty = min(t.qty, p["qty"]) + p["cost"] -= avg * qty + p["qty"] -= qty + return [{"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]} +for c, v in pos.items() if v["qty"] > 0] + + +# ── 下单 ────────────────────────────────────────────────── + +def place_order(account_id: int, code: str, side: str, qty: int, + price: float | None = None, reason: str = "") -> dict: + if qty <= 0: + return {"ok": False, "msg": "数量必须大于 0"} + if side not in ("buy", "sell"): + return {"ok": False, "msg": "side 只能是 buy 或 sell"} + + exec_price = price or _get_price(code) + if not exec_price: + return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"} + + fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2) + + with get_session() as s: + acc = s.get(PaperAccount, account_id) + if not acc: + return {"ok": False, "msg": "账户不存在"} + + sec = s.get(Security, code) + name = sec.name if sec else code + + if side == "buy": + cost = exec_price * qty + fee + if acc.cash < cost: + return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"} + cash_before = acc.cash + acc.cash -= cost + else: + holdings = _calc_holdings_in_session(account_id, s) + pos = next((h for h in holdings if h["code"] == code), None) + avail = pos["qty"] if pos else 0 + if avail < qty: + return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty} 股"} + cash_before = acc.cash + acc.cash += exec_price * qty - fee + + trade = PaperTrade( + account_id=account_id, + date=dt.date.today(), + code=code, name=name, side=side, + price=exec_price, qty=qty, fee=fee, + cash_before=cash_before, cash_after=acc.cash, + reason=reason, + ) + s.add(trade) + s.commit() + return {"ok": True, "id": trade.id, "price": exec_price, + "fee": fee, "cash_after": round(acc.cash, 2)} + + +# ── 查询接口 ────────────────────────────────────────────── + +def get_portfolio(account_id: int) -> dict: + with get_session() as s: + acc = s.get(PaperAccount, account_id) + if not acc: + return {"ok": False, "msg": "账户不存在"} + cash = acc.cash + initial = acc.initial_cash + holdings_raw = _calc_holdings_in_session(account_id, s) + + codes = [h["code"] for h in holdings_raw] + px: dict[str, float] = {} + if codes: + with get_session() as s: + for m in s.execute( + select(StockMetric).where(StockMetric.code.in_(codes)) + ).scalars(): + px[m.code] = m.close + for c in [c for c in codes if c not in px]: + row = s.execute( + select(DailyQuote.close).where(DailyQuote.code == c) + .order_by(DailyQuote.date.desc()).limit(1) + ).scalar_one_or_none() + if row: + px[c] = float(row) + + holdings, mkt_val = [], 0.0 + for h in holdings_raw: + avg = h["cost"] / h["qty"] if h["qty"] else 0.0 + cur = px.get(h["code"], avg) + mv = cur * h["qty"] + unreal = (cur - avg) * h["qty"] + mkt_val += mv + holdings.append({ + "code": h["code"], "name": h["name"], "qty": h["qty"], + "avg_cost": round(avg, 3), "cur": round(cur, 3), + "market_value": round(mv, 2), + "unrealized": round(unreal, 2), + "unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0, + }) + holdings.sort(key=lambda x: x["unrealized"], reverse=True) + + total_assets = cash + mkt_val + total_pnl = total_assets - initial + return { + "ok": True, + "account_id": account_id, + "summary": { + "initial_cash": round(initial, 2), + "cash": round(cash, 2), + "market_value": round(mkt_val, 2), + "total_assets": round(total_assets, 2), + "total_pnl": round(total_pnl, 2), + "total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0, + "positions": len(holdings), + }, + "holdings": holdings, + } + + +def get_trades(account_id: int, limit: int = 100) -> list[dict]: + with get_session() as s: + rows = s.execute( + select(PaperTrade).where(PaperTrade.account_id == account_id) + .order_by(PaperTrade.id.desc()).limit(limit) + ).scalars().all() + return [{ + "id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name, + "side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee, + "cash_before": round(t.cash_before, 2), "cash_after": round(t.cash_after, 2), + "reason": t.reason, + } for t in rows] \ No newline at end of file diff --git a/backend/position_cost.py b/backend/position_cost.py new file mode 100644 index 0000000..b2d3e7f --- /dev/null +++ b/backend/position_cost.py @@ -0,0 +1,347 @@ +"""持仓成本可视化增强""" +import datetime as dt +from typing import Dict, List, Optional +from collections import defaultdict +from sqlalchemy import select, func +from db import get_session +from models import Trade, DailyQuote, StockMetric + +# A股交易成本配置 +COST_CONFIG = { + "stamp_tax": 0.001, # 印花税 0.1%(仅卖出) + "commission_rate": 0.0003, # 佣金费率 0.03% + "commission_min": 5.0, # 最低佣金 5元 + "transfer_fee": 0.00001, # 过户费 0.001%(沪市) +} + +def calculate_trade_cost(price: float, qty: int, side: str, is_sh: bool = True) -> Dict: + """精确计算交易成本 + + Args: + price: 成交价格 + qty: 成交数量 + side: buy/sell + is_sh: 是否沪市(影响过户费) + + Returns: + 成本明细字典 + """ + amount = price * qty + + # 佣金(买卖都有) + commission = max(amount * COST_CONFIG["commission_rate"], COST_CONFIG["commission_min"]) + + # 印花税(仅卖出) + stamp_tax = amount * COST_CONFIG["stamp_tax"] if side == "sell" else 0.0 + + # 过户费(沪市买卖都有,深市无) + transfer_fee = amount * COST_CONFIG["transfer_fee"] if is_sh else 0.0 + + total_cost = commission + stamp_tax + transfer_fee + + return { + "amount": round(amount, 2), + "commission": round(commission, 2), + "stamp_tax": round(stamp_tax, 2), + "transfer_fee": round(transfer_fee, 2), + "total_cost": round(total_cost, 2), + "cost_rate": round(total_cost / amount * 100, 4) if amount > 0 else 0.0 + } + +def get_position_cost_lines(code: str) -> Dict: + """获取个股的持仓成本线数据(用于K线图标注) + + Returns: + { + "code": "600519", + "name": "贵州茅台", + "current_position": { + "qty": 100, + "avg_cost": 1680.5, + "total_cost": 168050.0, + "trades_count": 3 + }, + "cost_history": [ + {"date": "2024-01-15", "cost": 1650.0, "qty": 100, "action": "买入"}, + {"date": "2024-02-10", "cost": 1680.5, "qty": 100, "action": "补仓"} + ] + } + """ + with get_session() as s: + trades = s.execute( + select(Trade).where(Trade.code == code) + .order_by(Trade.date, Trade.id) + ).scalars().all() + + if not trades: + return {"ok": False, "msg": "该股票无交易记录"} + + # 计算持仓成本变化 + qty = 0 + cost = 0.0 + cost_history = [] + + for t in trades: + is_sh = t.code.startswith("6") + + if t.side == "buy": + # 买入:加权平均成本 + old_qty = qty + old_cost = cost + + qty += t.qty + cost += t.price * t.qty + t.fee + + avg_cost = cost / qty if qty > 0 else 0 + action = "补仓" if old_qty > 0 else "买入" + + cost_history.append({ + "date": t.date.isoformat(), + "cost": round(avg_cost, 2), + "qty": qty, + "action": action, + "trade_price": t.price, + "trade_qty": t.qty + }) + + else: # sell + if qty <= 0: + continue + + avg_cost = cost / qty + sell_qty = min(t.qty, qty) + + # 卖出:减少持仓 + cost -= avg_cost * sell_qty + qty -= sell_qty + + action = "清仓" if qty == 0 else "减仓" + + cost_history.append({ + "date": t.date.isoformat(), + "cost": round(cost / qty, 2) if qty > 0 else 0, + "qty": qty, + "action": action, + "trade_price": t.price, + "trade_qty": sell_qty, + "pnl": round((t.price - avg_cost) * sell_qty - t.fee, 2) + }) + + # 当前持仓 + current_position = None + if qty > 0: + avg_cost = cost / qty + + # 获取当前价格 + metric = s.execute( + select(StockMetric).where(StockMetric.code == code) + ).scalar_one_or_none() + + current_price = metric.close if metric else avg_cost + + current_position = { + "qty": qty, + "avg_cost": round(avg_cost, 2), + "total_cost": round(cost, 2), + "current_price": round(current_price, 2), + "market_value": round(current_price * qty, 2), + "unrealized_pnl": round((current_price - avg_cost) * qty, 2), + "unrealized_pct": round((current_price / avg_cost - 1) * 100, 2) if avg_cost > 0 else 0, + "trades_count": len([t for t in trades if t.side == "buy"]) + } + + return { + "ok": True, + "code": code, + "name": trades[0].name, + "current_position": current_position, + "cost_history": cost_history + } + +def get_position_cost_distribution() -> Dict: + """获取所有持仓的成本分布(盈亏区间图) + + Returns: + { + "profitable": [...], # 盈利持仓 + "unprofitable": [...], # 亏损持仓 + "breakeven": [...] # 持平持仓 + } + """ + with get_session() as s: + trades = s.execute( + select(Trade).order_by(Trade.date, Trade.id) + ).scalars().all() + + # 计算当前持仓 + pos = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""}) + + for t in trades: + p = pos[t.code] + p["name"] = t.name or p["name"] + + if t.side == "buy": + p["cost"] += t.price * t.qty + t.fee + p["qty"] += t.qty + else: + if p["qty"] > 0: + avg = p["cost"] / p["qty"] + qty = min(t.qty, p["qty"]) + p["cost"] -= avg * qty + p["qty"] -= qty + + # 获取当前价格 + codes = [c for c, v in pos.items() if v["qty"] > 0] + if not codes: + return {"ok": True, "profitable": [], "unprofitable": [], "breakeven": []} + + metrics = s.execute( + select(StockMetric).where(StockMetric.code.in_(codes)) + ).scalars().all() + + price_map = {m.code: m.close for m in metrics} + + # 分类统计 + profitable = [] + unprofitable = [] + breakeven = [] + + for code, p in pos.items(): + if p["qty"] <= 0: + continue + + avg_cost = p["cost"] / p["qty"] + current_price = price_map.get(code, avg_cost) + unrealized = (current_price - avg_cost) * p["qty"] + unrealized_pct = (current_price / avg_cost - 1) * 100 if avg_cost > 0 else 0 + + item = { + "code": code, + "name": p["name"], + "qty": p["qty"], + "avg_cost": round(avg_cost, 2), + "current_price": round(current_price, 2), + "market_value": round(current_price * p["qty"], 2), + "cost_value": round(p["cost"], 2), + "unrealized": round(unrealized, 2), + "unrealized_pct": round(unrealized_pct, 2) + } + + if unrealized_pct > 0.5: + profitable.append(item) + elif unrealized_pct < -0.5: + unprofitable.append(item) + else: + breakeven.append(item) + + # 排序 + profitable.sort(key=lambda x: x["unrealized"], reverse=True) + unprofitable.sort(key=lambda x: x["unrealized"]) + + return { + "ok": True, + "profitable": profitable, + "unprofitable": unprofitable, + "breakeven": breakeven, + "summary": { + "total_positions": len(codes), + "profitable_count": len(profitable), + "unprofitable_count": len(unprofitable), + "breakeven_count": len(breakeven), + "win_rate": round(len(profitable) / len(codes) * 100, 1) if codes else 0 + } + } + +def estimate_trade_cost(code: str, price: float, qty: int, side: str) -> Dict: + """估算交易成本(下单前预估) + + Args: + code: 股票代码 + price: 预计成交价 + qty: 交易数量 + side: buy/sell + + Returns: + 成本明细和净值 + """ + is_sh = code.startswith("6") + cost_detail = calculate_trade_cost(price, qty, side, is_sh) + + if side == "buy": + net_amount = cost_detail["amount"] + cost_detail["total_cost"] + msg = f"买入需支付: {round(net_amount, 2)} 元(含交易成本 {cost_detail['total_cost']} 元)" + else: + net_amount = cost_detail["amount"] - cost_detail["total_cost"] + msg = f"卖出可获得: {round(net_amount, 2)} 元(扣除交易成本 {cost_detail['total_cost']} 元)" + + return { + "ok": True, + "code": code, + "price": price, + "qty": qty, + "side": side, + "cost_detail": cost_detail, + "net_amount": round(net_amount, 2), + "message": msg + } + +def get_cost_breakdown_for_position(code: str) -> Dict: + """获取持仓的详细成本拆解 + + Returns: + { + "total_cost": 168500.0, + "purchase_amount": 168050.0, # 实际买入金额 + "commission": 350.0, # 累计佣金 + "stamp_tax": 0.0, # 累计印花税(买入无) + "transfer_fee": 100.0, # 累计过户费 + "trades": [...] # 每笔交易明细 + } + """ + with get_session() as s: + trades = s.execute( + select(Trade).where(Trade.code == code, Trade.side == "buy") + .order_by(Trade.date) + ).scalars().all() + + if not trades: + return {"ok": False, "msg": "该股票无买入记录"} + + is_sh = code.startswith("6") + + total_purchase = 0.0 + total_commission = 0.0 + total_stamp = 0.0 + total_transfer = 0.0 + trade_details = [] + + for t in trades: + cost = calculate_trade_cost(t.price, t.qty, "buy", is_sh) + + total_purchase += cost["amount"] + total_commission += cost["commission"] + total_stamp += cost["stamp_tax"] + total_transfer += cost["transfer_fee"] + + trade_details.append({ + "date": t.date.isoformat(), + "price": t.price, + "qty": t.qty, + "amount": cost["amount"], + "cost_detail": cost + }) + + total_cost = total_purchase + total_commission + total_stamp + total_transfer + + return { + "ok": True, + "code": code, + "name": trades[0].name, + "total_cost": round(total_cost, 2), + "purchase_amount": round(total_purchase, 2), + "commission": round(total_commission, 2), + "stamp_tax": round(total_stamp, 2), + "transfer_fee": round(total_transfer, 2), + "cost_rate": round((total_cost - total_purchase) / total_purchase * 100, 4) if total_purchase > 0 else 0, + "trades": trade_details + } diff --git a/backend/redis_cache.py b/backend/redis_cache.py new file mode 100644 index 0000000..b1a62a1 --- /dev/null +++ b/backend/redis_cache.py @@ -0,0 +1,88 @@ +"""Redis 缓存层,替代内存缓存,支持持久化和跨进程共享。""" +import json +import redis +from typing import Any, Optional +from config import REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD + +class RedisCache: + def __init__(self): + self.client: Optional[redis.Redis] = None + self.enabled = False + self._connect() + + def _connect(self): + """连接 Redis,失败时降级为禁用状态""" + try: + self.client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD if REDIS_PASSWORD else None, + decode_responses=True, + socket_connect_timeout=2, + socket_timeout=2 + ) + self.client.ping() + self.enabled = True + print(f"✓ Redis 已连接: {REDIS_HOST}:{REDIS_PORT}") + except Exception as e: + self.enabled = False + print(f"✗ Redis 连接失败,缓存已禁用: {e}") + + def get(self, key: str) -> Optional[Any]: + """获取缓存,自动反序列化 JSON""" + if not self.enabled: + return None + try: + value = self.client.get(key) + if value: + return json.loads(value) + return None + except Exception as e: + print(f"Redis get error: {e}") + return None + + def set(self, key: str, value: Any, expire: int = 3600): + """设置缓存,自动序列化为 JSON + + Args: + key: 缓存键 + value: 缓存值(可序列化为JSON的对象) + expire: 过期时间(秒),默认1小时 + """ + if not self.enabled: + return False + try: + serialized = json.dumps(value, ensure_ascii=False, default=str) + self.client.setex(key, expire, serialized) + return True + except Exception as e: + print(f"Redis set error: {e}") + return False + + def delete(self, key: str): + """删除缓存""" + if not self.enabled: + return False + try: + self.client.delete(key) + return True + except Exception as e: + print(f"Redis delete error: {e}") + return False + + def clear_pattern(self, pattern: str): + """批量删除匹配模式的键""" + if not self.enabled: + return 0 + try: + keys = self.client.keys(pattern) + if keys: + return self.client.delete(*keys) + return 0 + except Exception as e: + print(f"Redis clear_pattern error: {e}") + return 0 + +# 全局单例 +cache = RedisCache() \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 43fee8f..4aac425 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,3 +8,7 @@ APScheduler>=3.10.4 psycopg2-binary>=2.9.9 jieba>=0.42.1 numpy>=1.26.0 +redis>=5.0.0 +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 +python-multipart>=0.0.9 diff --git a/backend/scheduler.py b/backend/scheduler.py index 6f0f099..cb3e44c 100644 --- a/backend/scheduler.py +++ b/backend/scheduler.py @@ -13,6 +13,7 @@ import alerts import report import signals import intraday_radar +import trade_calendar _scheduler: BackgroundScheduler | None = None _lock = threading.Lock() @@ -134,11 +135,16 @@ def start_scheduler(): _job_signal_stats, CronTrigger(day_of_week="sat", hour=9, minute=0), id="signal_stats", replace_existing=True, misfire_grace_time=7200, ) - # 盘中异动扫描(交易时间每分钟) + # 盘中异动扫描(交易时间每分钟) _scheduler.add_job( _safe_scan_intraday, IntervalTrigger(seconds=60), id="intraday_scan", replace_existing=True, max_instances=1, ) + # 每日早盘前推送日历事件提醒(持仓股除权、解禁、财报等) + _scheduler.add_job( + _job_calendar_alerts, CronTrigger(day_of_week="mon-fri", hour=8, minute=30), + id="calendar_alerts", replace_existing=True, misfire_grace_time=3600, + ) _scheduler.start() return _scheduler @@ -154,7 +160,13 @@ def _safe_scan_intraday(): try: result = intraday_radar.scan_all() if result.get("count", 0) > 0: - # 有新异动时自动推送 intraday_radar.notify_events() except Exception as e: print("[intraday] scan error:", repr(e)[:120]) + + +def _job_calendar_alerts(): + try: + trade_calendar.check_and_push_calendar_alerts() + except Exception as e: + print("[calendar] alert error:", repr(e)[:120]) diff --git a/backend/test_core_features.py b/backend/test_core_features.py new file mode 100644 index 0000000..4e376c0 --- /dev/null +++ b/backend/test_core_features.py @@ -0,0 +1,179 @@ +"""测试三大核心功能的脚本""" +import requests +import time +import sys + +BASE_URL = "http://localhost:8000" + +def test_health(): + """测试健康检查接口""" + print("\n=== 1. 测试健康检查 ===") + try: + resp = requests.get(f"{BASE_URL}/api/health") + data = resp.json() + print(f"状态码: {resp.status_code}") + print(f"响应: {data}") + print(f"✓ AkShare: {'可用' if data.get('akshare') else '不可用'}") + print(f"✓ Redis: {'已连接' if data.get('redis') else '未连接(将使用内存缓存)'}") + print(f"✓ 鉴权: {'已启用' if data.get('auth') else '未启用'}") + return True + except Exception as e: + print(f"✗ 失败: {e}") + return False + +def test_cache_performance(): + """测试 Redis 缓存性能""" + print("\n=== 2. 测试 Redis 缓存性能 ===") + try: + # 第一次请求(缓存未命中) + start = time.time() + resp1 = requests.get(f"{BASE_URL}/api/indices") + time1 = time.time() - start + print(f"第一次请求耗时: {time1:.3f}秒") + + # 第二次请求(应该命中缓存) + start = time.time() + resp2 = requests.get(f"{BASE_URL}/api/indices") + time2 = time.time() - start + print(f"第二次请求耗时: {time2:.3f}秒") + + speedup = time1 / time2 if time2 > 0 else 1 + print(f"性能提升: {speedup:.1f}x") + + if speedup > 2: + print("✓ Redis 缓存生效") + else: + print("⚠ 缓存可能未生效或使用内存缓存") + return True + except Exception as e: + print(f"✗ 失败: {e}") + return False + +def test_auth_login(): + """测试登录功能""" + print("\n=== 3. 测试认证系统 - 登录 ===") + try: + resp = requests.post( + f"{BASE_URL}/api/auth/login", + json={"username": "admin", "password": "admin123"} + ) + print(f"状态码: {resp.status_code}") + data = resp.json() + + if resp.status_code == 200 and data.get("access_token"): + print(f"✓ 登录成功") + print(f" 用户名: {data.get('username')}") + print(f" 管理员: {data.get('is_admin')}") + print(f" Token: {data.get('access_token')[:50]}...") + return data.get("access_token") + else: + print(f"✗ 登录失败: {data}") + return None + except Exception as e: + print(f"✗ 失败: {e}") + return None + +def test_auth_protected(token): + """测试受保护的接口""" + print("\n=== 4. 测试认证系统 - 受保护接口 ===") + + # 测试无 Token 访问 + print("\n4.1 无 Token 访问管理接口:") + try: + resp = requests.get(f"{BASE_URL}/api/admin/status") + print(f"状态码: {resp.status_code}") + if resp.status_code == 401: + print("✓ 正确拦截未认证请求") + else: + print(f"⚠ 预期 401,实际 {resp.status_code}") + except Exception as e: + print(f"✗ 失败: {e}") + + # 测试有 Token 访问 + print("\n4.2 使用 Token 访问管理接口:") + try: + resp = requests.get( + f"{BASE_URL}/api/admin/status", + headers={"Authorization": f"Bearer {token}"} + ) + print(f"状态码: {resp.status_code}") + if resp.status_code == 200: + data = resp.json() + print("✓ Token 认证成功") + print(f" 数据库记录数: {data.get('counts', {})}") + return True + else: + print(f"✗ 认证失败: {resp.json()}") + return False + except Exception as e: + print(f"✗ 失败: {e}") + return False + +def test_exception_handling(): + """测试异常处理""" + print("\n=== 5. 测试统一异常处理 ===") + + # 测试参数验证错误 + print("\n5.1 测试参数验证错误:") + try: + resp = requests.get(f"{BASE_URL}/api/kline?days=invalid") + print(f"状态码: {resp.status_code}") + data = resp.json() + if resp.status_code == 422: + print("✓ 参数验证错误处理正确") + print(f" 错误信息: {data.get('error')}") + else: + print(f"⚠ 预期 422,实际 {resp.status_code}") + except Exception as e: + print(f"✗ 失败: {e}") + + # 测试业务逻辑错误 + print("\n5.2 测试业务逻辑错误:") + try: + resp = requests.get(f"{BASE_URL}/api/backtest?symbol=600519&fast=20&slow=10") + print(f"状态码: {resp.status_code}") + data = resp.json() + if not data.get('ok'): + print("✓ 业务错误处理正确") + print(f" 错误信息: {data.get('msg')}") + except Exception as e: + print(f"✗ 失败: {e}") + + return True + +def main(): + print("="*50) + print("三大核心功能测试") + print("="*50) + + # 检查服务是否运行 + try: + requests.get(f"{BASE_URL}/api/health", timeout=2) + except: + print(f"\n✗ 无法连接到服务: {BASE_URL}") + print("请确保服务已启动: python main.py") + sys.exit(1) + + # 运行测试 + test_health() + test_cache_performance() + token = test_auth_login() + + if token: + test_auth_protected(token) + else: + print("\n⚠ 跳过受保护接口测试(登录失败)") + + test_exception_handling() + + print("\n" + "="*50) + print("测试完成!") + print("="*50) + print("\n下一步:") + print("1. 如果 Redis 显示'未连接',请安装并启动 Redis") + print("2. 如果登录失败,请运行: python cli.py init") + print("3. 登录成功后,务必修改默认密码") + print("4. 生产环境请修改 .env 中的 SECRET_KEY") + +if __name__ == "__main__": + main() diff --git a/backend/trade_calendar.py b/backend/trade_calendar.py new file mode 100644 index 0000000..4ece943 --- /dev/null +++ b/backend/trade_calendar.py @@ -0,0 +1,338 @@ +"""交易日历与关键事件提醒""" +import datetime as dt +from typing import List, Optional, Dict +from sqlalchemy import select, and_ +from db import get_session +from models import Trade, Security, CorporateEvent, AlertEvent +import akshare_service as svc + +try: + import akshare as ak + AK_OK = True +except Exception: + ak = None + AK_OK = False + + +def get_upcoming_dividends(days_ahead: int = 30) -> Dict: + """获取即将到来的除权除息日""" + today = dt.date.today() + end = today + dt.timedelta(days=days_ahead) + + # 获取持仓股票代码 + with get_session() as s: + trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all() + + # 计算当前持仓 + pos = {} + for t in trades: + if t.code not in pos: + pos[t.code] = {"qty": 0, "name": t.name} + if t.side == "buy": + pos[t.code]["qty"] += t.qty + else: + pos[t.code]["qty"] = max(0, pos[t.code]["qty"] - t.qty) + + holding_codes = [c for c, v in pos.items() if v["qty"] > 0] + + events = [] + if AK_OK and holding_codes: + try: + df = ak.stock_zh_a_dividend() + if df is not None and not df.empty: + for _, r in df.iterrows(): + code = str(r.get("代码", "")) + if code not in holding_codes: + continue + ex_date_str = str(r.get("除权除息日", "")) + if not ex_date_str or ex_date_str == "nan": + continue + try: + ex_date = dt.date.fromisoformat(ex_date_str[:10]) + if today <= ex_date <= end: + days_left = (ex_date - today).days + events.append({ + "code": code, + "name": pos[code]["name"], + "event_type": "除权除息", + "event_date": ex_date.isoformat(), + "days_left": days_left, + "detail": f"送股: {r.get('送股', 0)}, 转增: {r.get('转增', 0)}, 派息: {r.get('派息', 0)}", + "is_holding": True, + "urgency": "high" if days_left <= 3 else "medium" + }) + except Exception: + continue + except Exception: + pass + + # 补充从数据库中获取的事件 + with get_session() as s: + db_events = s.execute( + select(CorporateEvent).where( + and_( + CorporateEvent.event_type == "dividend", + CorporateEvent.event_date >= today, + CorporateEvent.event_date <= end + ) + ).order_by(CorporateEvent.event_date) + ).scalars().all() + + for e in db_events: + if any(ev["code"] == e.code for ev in events): + continue + days_left = (e.event_date - today).days + events.append({ + "code": e.code, + "name": e.name, + "event_type": "除权除息", + "event_date": e.event_date.isoformat(), + "days_left": days_left, + "detail": e.description, + "is_holding": e.code in holding_codes, + "urgency": "high" if days_left <= 3 else "medium" + }) + + events.sort(key=lambda x: x["event_date"]) + return {"ok": True, "events": events, "count": len(events)} + + +def get_unlock_calendar(days_ahead: int = 90) -> Dict: + """获取限售解禁日历""" + today = dt.date.today() + end = today + dt.timedelta(days=days_ahead) + + events = [] + if AK_OK: + try: + df = ak.stock_restricted_release_summary_em() + if df is not None and not df.empty: + for _, r in df.head(50).iterrows(): + date_str = str(r.get("解禁日期", "")) + if not date_str or date_str == "nan": + continue + try: + unlock_date = dt.date.fromisoformat(date_str[:10]) + if today <= unlock_date <= end: + amount = float(r.get("解禁数量", 0) or 0) + market_val = float(r.get("解禁市值", 0) or 0) + events.append({ + "code": str(r.get("代码", "")), + "name": str(r.get("名称", "")), + "event_type": "限售解禁", + "event_date": unlock_date.isoformat(), + "days_left": (unlock_date - today).days, + "detail": f"解禁市值: {round(market_val/1e8, 2)}亿", + "amount_billion": round(market_val / 1e8, 2), + "urgency": "high" if market_val >= 10e8 else "medium" + }) + except Exception: + continue + except Exception: + pass + + # 从数据库补充 + with get_session() as s: + db_events = s.execute( + select(CorporateEvent).where( + and_( + CorporateEvent.event_type == "unlock", + CorporateEvent.event_date >= today, + CorporateEvent.event_date <= end + ) + ).order_by(CorporateEvent.event_date) + ).scalars().all() + + for e in db_events: + if any(ev["code"] == e.code for ev in events): + continue + events.append({ + "code": e.code, + "name": e.name, + "event_type": "限售解禁", + "event_date": e.event_date.isoformat(), + "days_left": (e.event_date - today).days, + "detail": e.description, + "amount_billion": e.amount, + "urgency": "high" if e.amount >= 10 else "medium" + }) + + events.sort(key=lambda x: x["event_date"]) + return {"ok": True, "events": events, "count": len(events)} + + +def get_earnings_calendar(days_ahead: int = 30, holding_only: bool = False) -> Dict: + """获取财报披露日历""" + today = dt.date.today() + end = today + dt.timedelta(days=days_ahead) + + # 获取持仓代码 + holding_codes = set() + if holding_only: + with get_session() as s: + trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all() + pos = {} + for t in trades: + if t.code not in pos: + pos[t.code] = 0 + pos[t.code] += t.qty if t.side == "buy" else -t.qty + holding_codes = {c for c, q in pos.items() if q > 0} + + events = [] + if AK_OK: + try: + df = ak.stock_notice_report() + if df is not None and not df.empty: + for _, r in df.head(100).iterrows(): + date_str = str(r.get("公告日期", "")) + code = str(r.get("代码", "")) + if not date_str or date_str == "nan": + continue + if holding_only and code not in holding_codes: + continue + try: + report_date = dt.date.fromisoformat(date_str[:10]) + if today <= report_date <= end: + events.append({ + "code": code, + "name": str(r.get("名称", "")), + "event_type": "财报披露", + "event_date": report_date.isoformat(), + "days_left": (report_date - today).days, + "report_type": str(r.get("公告类型", "")), + "is_holding": code in holding_codes, + "urgency": "high" if code in holding_codes else "low" + }) + except Exception: + continue + except Exception: + pass + + # 从数据库补充 + with get_session() as s: + db_events = s.execute( + select(CorporateEvent).where( + and_( + CorporateEvent.event_type == "earnings", + CorporateEvent.event_date >= today, + CorporateEvent.event_date <= end + ) + ).order_by(CorporateEvent.event_date) + ).scalars().all() + + for e in db_events: + if any(ev["code"] == e.code for ev in events): + continue + if holding_only and e.code not in holding_codes: + continue + events.append({ + "code": e.code, + "name": e.name, + "event_type": "财报披露", + "event_date": e.event_date.isoformat(), + "days_left": (e.event_date - today).days, + "report_type": e.title, + "is_holding": e.code in holding_codes, + "urgency": "high" if e.code in holding_codes else "low" + }) + + events.sort(key=lambda x: x["event_date"]) + return {"ok": True, "events": events, "count": len(events)} + + +def get_all_upcoming_events(days_ahead: int = 30) -> Dict: + """获取所有即将到来的关键事件(综合视图)""" + today = dt.date.today() + all_events = [] + + # 合并所有事件 + for result in [ + get_upcoming_dividends(days_ahead), + get_earnings_calendar(days_ahead), + get_unlock_calendar(days_ahead) + ]: + all_events.extend(result.get("events", [])) + + # 按日期排序 + all_events.sort(key=lambda x: x["event_date"]) + + # 按日期分组 + grouped = {} + for event in all_events: + date = event["event_date"] + if date not in grouped: + grouped[date] = [] + grouped[date].append(event) + + # 生成日历视图 + calendar = [] + for date_str, events in sorted(grouped.items()): + date = dt.date.fromisoformat(date_str) + calendar.append({ + "date": date_str, + "weekday": ["周一", "周二", "周三", "周四", "周五", "周六", "周日"][date.weekday()], + "days_left": (date - today).days, + "events": events, + "has_high_urgency": any(e["urgency"] == "high" for e in events), + "has_holding": any(e.get("is_holding", False) for e in events) + }) + + # 紧急事件(3天内) + urgent = [e for e in all_events if e.get("days_left", 99) <= 3] + + return { + "ok": True, + "calendar": calendar, + "urgent": urgent, + "total": len(all_events), + "summary": { + "dividends": len([e for e in all_events if e["event_type"] == "除权除息"]), + "earnings": len([e for e in all_events if e["event_type"] == "财报披露"]), + "unlocks": len([e for e in all_events if e["event_type"] == "限售解禁"]), + "urgent": len(urgent) + } + } + + +def check_and_push_calendar_alerts() -> Dict: + """检查并推送日历事件预警(定时任务调用)""" + try: + from notifier import notify + except Exception: + return {"ok": False, "msg": "推送模块不可用"} + + result = get_all_upcoming_events(days_ahead=7) + urgent = result.get("urgent", []) + + if not urgent: + return {"ok": True, "msg": "无紧急事件", "pushed": 0} + + # 生成推送内容 + lines = [f"📅 未来7天关键事件提醒({len(urgent)}条)\n"] + for event in urgent[:10]: # 最多推送10条 + urgency_icon = "🔴" if event["urgency"] == "high" else "🟡" + holding_icon = "💰" if event.get("is_holding") else "" + lines.append( + f"{urgency_icon}{holding_icon} {event['event_date']} " + f"{event['name']}({event['code']}) " + f"{event['event_type']} " + f"({event['days_left']}天后)" + ) + + message = "\n".join(lines) + notify("【Blackdata】关键事件提醒", message) + + # 写入站内通知 + with get_session() as s: + for event in urgent[:10]: + alert = AlertEvent( + rule_id=0, + code=event["code"], + name=event["name"], + message=f"{event['event_type']}: {event.get('detail', '')} ({event['days_left']}天后)", + value=event.get("amount_billion", 0) + ) + s.add(alert) + s.commit() + + return {"ok": True, "pushed": len(urgent), "msg": f"已推送 {len(urgent)} 条事件提醒"} diff --git a/backend/watchlist_manager.py b/backend/watchlist_manager.py new file mode 100644 index 0000000..a6dea12 --- /dev/null +++ b/backend/watchlist_manager.py @@ -0,0 +1,345 @@ +"""自选股分组管理""" +import datetime as dt +from typing import List, Dict, Optional +from sqlalchemy import select, func +from db import get_session +from models import WatchlistGroup, WatchlistItem, Security +import akshare_service as svc + +# 预设分组 +DEFAULT_GROUPS = [ + {"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True}, + {"name": "观察池", "description": "待观察的潜力股", "color": "blue"}, + {"name": "持仓股", "description": "当前持仓的股票", "color": "green"}, + {"name": "概念股", "description": "热门概念板块", "color": "purple"}, +] + +def init_default_groups(): + """初始化默认分组(如果不存在)""" + with get_session() as s: + count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar() + if count == 0: + for idx, g in enumerate(DEFAULT_GROUPS): + group = WatchlistGroup( + name=g["name"], + description=g["description"], + color=g["color"], + is_default=g.get("is_default", False), + sort_order=idx + ) + s.add(group) + s.commit() + print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)} 个") + return True + +def get_all_groups() -> List[Dict]: + """获取所有分组""" + with get_session() as s: + groups = s.execute( + select(WatchlistGroup).order_by(WatchlistGroup.sort_order) + ).scalars().all() + + result = [] + for g in groups: + # 统计分组内股票数量 + count = s.execute( + select(func.count()).select_from(WatchlistItem) + .where(WatchlistItem.group_id == g.id) + ).scalar() + + result.append({ + "id": g.id, + "name": g.name, + "description": g.description, + "color": g.color, + "count": count, + "is_default": g.is_default, + "sort_order": g.sort_order + }) + + return result + +def create_group(name: str, description: str = "", color: str = "blue") -> Dict: + """创建新分组""" + with get_session() as s: + # 获取当前最大排序号 + max_order = s.execute( + select(func.max(WatchlistGroup.sort_order)) + ).scalar() or 0 + + group = WatchlistGroup( + name=name, + description=description, + color=color, + sort_order=max_order + 1 + ) + s.add(group) + s.commit() + + return { + "ok": True, + "id": group.id, + "name": group.name + } + +def update_group(group_id: int, name: Optional[str] = None, + description: Optional[str] = None, color: Optional[str] = None) -> Dict: + """更新分组信息""" + with get_session() as s: + group = s.get(WatchlistGroup, group_id) + if not group: + return {"ok": False, "msg": "分组不存在"} + + if name is not None: + group.name = name + if description is not None: + group.description = description + if color is not None: + group.color = color + + s.commit() + return {"ok": True} + +def delete_group(group_id: int) -> Dict: + """删除分组(同时删除分组内的股票)""" + with get_session() as s: + group = s.get(WatchlistGroup, group_id) + if not group: + return {"ok": False, "msg": "分组不存在"} + + if group.is_default: + return {"ok": False, "msg": "默认分组不能删除"} + + # 删除分组内的股票 + s.execute( + WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id) + ) + + # 删除分组 + s.delete(group) + s.commit() + + return {"ok": True} + +def reorder_groups(group_ids: List[int]) -> Dict: + """重新排序分组""" + with get_session() as s: + for idx, gid in enumerate(group_ids): + group = s.get(WatchlistGroup, gid) + if group: + group.sort_order = idx + s.commit() + return {"ok": True} + +def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict: + """获取分组内的股票列表""" + with get_session() as s: + group = s.get(WatchlistGroup, group_id) + if not group: + return {"ok": False, "msg": "分组不存在"} + + items = s.execute( + select(WatchlistItem) + .where(WatchlistItem.group_id == group_id) + .order_by(WatchlistItem.sort_order) + ).scalars().all() + + codes = [item.code for item in items] + + # 获取实时行情 + stocks = [] + if with_quotes and codes: + quotes_data = svc.get_watchlist(codes) + quotes_map = {s["code"]: s for s in quotes_data.get("list", [])} + + for item in items: + quote = quotes_map.get(item.code, {}) + stocks.append({ + "id": item.id, + "code": item.code, + "name": item.name or quote.get("name", ""), + "price": quote.get("price", 0), + "pct": quote.get("pct", 0), + "change": quote.get("change", 0), + "amount": quote.get("amount", 0), + "note": item.note, + "added_at": item.added_at.strftime("%Y-%m-%d") + }) + else: + for item in items: + stocks.append({ + "id": item.id, + "code": item.code, + "name": item.name, + "note": item.note, + "added_at": item.added_at.strftime("%Y-%m-%d") + }) + + return { + "ok": True, + "group": { + "id": group.id, + "name": group.name, + "description": group.description, + "color": group.color + }, + "stocks": stocks + } + +def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict: + """添加股票到分组""" + with get_session() as s: + group = s.get(WatchlistGroup, group_id) + if not group: + return {"ok": False, "msg": "分组不存在"} + + # 检查是否已存在 + exists = s.execute( + select(WatchlistItem) + .where(WatchlistItem.group_id == group_id, WatchlistItem.code == code) + ).scalar_one_or_none() + + if exists: + return {"ok": False, "msg": "该股票已在分组中"} + + # 获取股票名称 + sec = s.get(Security, code) + name = sec.name if sec else code + + # 获取当前最大排序号 + max_order = s.execute( + select(func.max(WatchlistItem.sort_order)) + .where(WatchlistItem.group_id == group_id) + ).scalar() or 0 + + item = WatchlistItem( + group_id=group_id, + code=code, + name=name, + note=note, + sort_order=max_order + 1 + ) + s.add(item) + s.commit() + + return {"ok": True, "id": item.id} + +def remove_stock_from_group(item_id: int) -> Dict: + """从分组中移除股票""" + with get_session() as s: + item = s.get(WatchlistItem, item_id) + if not item: + return {"ok": False, "msg": "股票不存在"} + + s.delete(item) + s.commit() + return {"ok": True} + +def move_stock_to_group(item_id: int, target_group_id: int) -> Dict: + """将股票移动到另一个分组""" + with get_session() as s: + item = s.get(WatchlistItem, item_id) + if not item: + return {"ok": False, "msg": "股票不存在"} + + target_group = s.get(WatchlistGroup, target_group_id) + if not target_group: + return {"ok": False, "msg": "目标分组不存在"} + + # 检查目标分组是否已有该股票 + exists = s.execute( + select(WatchlistItem) + .where(WatchlistItem.group_id == target_group_id, + WatchlistItem.code == item.code) + ).scalar_one_or_none() + + if exists: + return {"ok": False, "msg": "目标分组已有该股票"} + + item.group_id = target_group_id + s.commit() + return {"ok": True} + +def batch_add_stocks(group_id: int, codes: List[str]) -> Dict: + """批量添加股票到分组""" + with get_session() as s: + group = s.get(WatchlistGroup, group_id) + if not group: + return {"ok": False, "msg": "分组不存在"} + + added = 0 + skipped = 0 + + for code in codes: + # 检查是否已存在 + exists = s.execute( + select(WatchlistItem) + .where(WatchlistItem.group_id == group_id, WatchlistItem.code == code) + ).scalar_one_or_none() + + if exists: + skipped += 1 + continue + + # 获取股票名称 + sec = s.get(Security, code) + name = sec.name if sec else code + + item = WatchlistItem( + group_id=group_id, + code=code, + name=name, + sort_order=added + ) + s.add(item) + added += 1 + + s.commit() + return {"ok": True, "added": added, "skipped": skipped} + +def update_stock_note(item_id: int, note: str) -> Dict: + """更新股票备注""" + with get_session() as s: + item = s.get(WatchlistItem, item_id) + if not item: + return {"ok": False, "msg": "股票不存在"} + + item.note = note + s.commit() + return {"ok": True} + +def reorder_stocks(item_ids: List[int]) -> Dict: + """重新排序分组内的股票""" + with get_session() as s: + for idx, item_id in enumerate(item_ids): + item = s.get(WatchlistItem, item_id) + if item: + item.sort_order = idx + s.commit() + return {"ok": True} + +def search_stocks_across_groups(keyword: str) -> List[Dict]: + """跨分组搜索股票""" + with get_session() as s: + items = s.execute( + select(WatchlistItem, WatchlistGroup) + .join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id) + .where( + (WatchlistItem.code.like(f"%{keyword}%")) | + (WatchlistItem.name.like(f"%{keyword}%")) + ) + ).all() + + result = [] + for item, group in items: + result.append({ + "id": item.id, + "code": item.code, + "name": item.name, + "group_id": group.id, + "group_name": group.name, + "group_color": group.color, + "note": item.note + }) + + return result diff --git a/prototype/app.js b/prototype/app.js index 3783039..d33e9d7 100644 --- a/prototype/app.js +++ b/prototype/app.js @@ -11,11 +11,18 @@ function disposeCharts() { while (charts.length) charts.pop().dispose(); } /* ===================== API 层(带降级) ===================== */ const API_BASE = location.port === '8000' ? '' : 'http://localhost:8000'; let LAST_SOURCE = '-'; +let _token = localStorage.getItem('auth_token') || ''; + +function authHeaders() { + return _token ? { 'Authorization': 'Bearer ' + _token } : {}; +} + async function apiGet(path) { const ctl = new AbortController(); const t = setTimeout(() => ctl.abort(), 8000); try { - const res = await fetch(API_BASE + path, { signal: ctl.signal }); + const res = await fetch(API_BASE + path, { signal: ctl.signal, headers: authHeaders() }); + if (res.status === 401) { showLoginModal(); throw new Error('401'); } if (!res.ok) throw new Error(res.status); const json = await res.json(); LAST_SOURCE = json.source || 'akshare'; @@ -33,11 +40,51 @@ function updateSource() { const el = document.getElementById('dsource'); if (el) el.textContent = '数据源: ' + LAST_SOURCE; } -async function apiPost(path) { - const res = await fetch(API_BASE + path, { method: 'POST' }); +async function apiPost(path, body) { + const opts = { method: 'POST', headers: { ...authHeaders() } }; + if (body) { opts.headers['Content-Type'] = 'application/json'; opts.body = JSON.stringify(body); } + const res = await fetch(API_BASE + path, opts); return res.json(); } +function showLoginModal() { + if (document.getElementById('_login_modal')) return; + const bg = document.createElement('div'); + bg.id = '_login_modal'; + bg.style.cssText = 'position:fixed;inset:0;background:#00000099;z-index:20000;display:flex;align-items:center;justify-content:center'; + bg.innerHTML = `
+

?? 登录

+
+
+
+ +
+
+
`; + document.body.appendChild(bg); + const doLogin = async () => { + const u = document.getElementById('_li_user').value.trim(); + const p = document.getElementById('_li_pass').value; + try { + const r = await fetch(API_BASE + '/api/auth/login', { method:'POST', + headers:{'Content-Type':'application/json'}, body: JSON.stringify({username:u,password:p}) }); + const j = await r.json(); + if (j.access_token) { + _token = j.access_token; + localStorage.setItem('auth_token', _token); + bg.remove(); + // 刷新当前视图 + const cur = location.hash.slice(1); + if (cur && VIEW_INDEX[cur]) navigate(cur); + } else { + document.getElementById('_li_msg').textContent = j.detail || '用户名或密码错误'; + } + } catch { document.getElementById('_li_msg').textContent = '后端未连接'; } + }; + document.getElementById('_li_btn').onclick = doLogin; + document.getElementById('_li_pass').onkeydown = e => { if(e.key==='Enter') doLogin(); }; +} + /* ===================== 菜单配置(一级 / 二级) ===================== */ const MENU = [ { icon: '▤', name: '行情中心', children: [ @@ -83,6 +130,7 @@ const MENU = [ { id: 'pf-equity', name: '资金曲线' }, { id: 'pf-trades', name: '交易日志' }, { id: 'pf-attr', name: '盈亏归因' }, + { id: 'paper-trading', name: '模拟盘' }, ]}, { icon: '✉', name: '资讯中心', children: [ { id: 'news-main', name: '要闻快讯' }, @@ -92,6 +140,10 @@ const MENU = [ { id: 'alert-list', name: '预警规则' }, { id: 'alert-events', name: '触发记录' }, ]}, + { icon: '??', name: '用户中心', children: [ + { id: 'user-profile', name: '我的账户' }, + { id: 'user-manage', name: '用户管理' }, + ]}, ]; const VIEW_INDEX = {}; MENU.forEach(g => g.children.forEach(c => { VIEW_INDEX[c.id] = { group: g.name, name: c.name, soon: c.soon }; })); @@ -101,7 +153,7 @@ function renderMenu() { const nav = document.getElementById('menu'); nav.innerHTML = MENU.map((g, gi) => `