功能细节优化
This commit is contained in:
237
backend/CHECKLIST.md
Normal file
237
backend/CHECKLIST.md
Normal file
@@ -0,0 +1,237 @@
|
||||
# 启动前检查清单
|
||||
|
||||
在首次启动或遇到问题时,请按此清单逐项检查。
|
||||
|
||||
## ✅ 环境检查
|
||||
|
||||
### 系统服务
|
||||
```bash
|
||||
# PostgreSQL 是否运行
|
||||
sudo service postgresql status
|
||||
sudo service postgresql start # 如果未运行
|
||||
|
||||
# Redis 是否运行
|
||||
redis-cli ping # 应返回 PONG
|
||||
sudo service redis-server start # 如果未运行
|
||||
```
|
||||
|
||||
### Python 环境
|
||||
```bash
|
||||
# 虚拟环境是否激活
|
||||
which python # 应显示 .venv/bin/python
|
||||
source .venv/bin/activate # 如果未激活
|
||||
|
||||
# 依赖是否安装完整
|
||||
pip list | grep redis
|
||||
pip list | grep jose
|
||||
pip list | grep passlib
|
||||
```
|
||||
|
||||
## ✅ 配置检查
|
||||
|
||||
### .env 文件
|
||||
```bash
|
||||
# 检查必要配置项
|
||||
cat .env | grep PG_PASSWORD
|
||||
cat .env | grep SECRET_KEY
|
||||
cat .env | grep REDIS_HOST
|
||||
|
||||
# 如果缺少配置
|
||||
cp .env.example .env # 如果 .env 不存在
|
||||
nano .env # 编辑配置
|
||||
```
|
||||
|
||||
### 必须配置的项
|
||||
- [ ] `PG_PASSWORD` - PostgreSQL 密码
|
||||
- [ ] `SECRET_KEY` - JWT 密钥(生产环境必须修改)
|
||||
- [ ] `REDIS_HOST` - Redis 地址(默认 localhost)
|
||||
|
||||
## ✅ 数据库检查
|
||||
|
||||
```bash
|
||||
# 检查数据库是否创建
|
||||
psql -U postgres -c "\l" | grep stock_cs
|
||||
|
||||
# 检查用户表是否存在
|
||||
psql -U postgres -d stock_cs -c "\dt" | grep users
|
||||
|
||||
# 如果数据库或表不存在
|
||||
python cli.py init
|
||||
```
|
||||
|
||||
## ✅ 权限检查
|
||||
|
||||
```bash
|
||||
# 测试数据库连接
|
||||
psql -U postgres -d stock_cs -c "SELECT 1;"
|
||||
|
||||
# 测试 Redis 连接
|
||||
redis-cli ping
|
||||
|
||||
# 如果提示权限错误
|
||||
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
|
||||
```
|
||||
|
||||
## ✅ 启动服务
|
||||
|
||||
```bash
|
||||
# 完整启动流程
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
cd backend
|
||||
source .venv/bin/activate
|
||||
python main.py
|
||||
```
|
||||
|
||||
### 启动成功标志
|
||||
|
||||
看到以下日志表示启动成功:
|
||||
```
|
||||
✓ Redis 已连接: localhost:6379
|
||||
✓ 管理员账号已存在: admin
|
||||
[startup] db + scheduler + auth ready
|
||||
INFO: Uvicorn running on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
## ✅ 功能测试
|
||||
|
||||
```bash
|
||||
# 1. 健康检查
|
||||
curl http://localhost:8000/api/health
|
||||
# 应返回: {"ok":true,"akshare":true,"redis":true,"auth":true}
|
||||
|
||||
# 2. 登录测试
|
||||
curl -X POST http://localhost:8000/api/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username":"admin","password":"admin123"}'
|
||||
# 应返回 Token
|
||||
|
||||
# 3. 运行完整测试
|
||||
python test_core_features.py
|
||||
```
|
||||
|
||||
## 🔧 常见问题排查
|
||||
|
||||
### 问题 1: Redis 连接失败
|
||||
```
|
||||
✗ Redis 连接失败,缓存已禁用
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
sudo service redis-server start
|
||||
redis-cli ping
|
||||
```
|
||||
|
||||
### 问题 2: 数据库连接失败
|
||||
```
|
||||
connection refused
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
sudo service postgresql start
|
||||
psql -U postgres -c "SELECT 1;"
|
||||
```
|
||||
|
||||
### 问题 3: 密码认证失败
|
||||
```
|
||||
password authentication failed
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
|
||||
# 然后在 .env 中设置相同的密码
|
||||
```
|
||||
|
||||
### 问题 4: 模块未找到
|
||||
```
|
||||
ModuleNotFoundError: No module named 'redis'
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 问题 5: 401 Unauthorized
|
||||
```
|
||||
{"detail":"未认证,请先登录"}
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
# 先登录获取 Token
|
||||
curl -X POST http://localhost:8000/api/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username":"admin","password":"admin123"}'
|
||||
|
||||
# 使用 Token 访问
|
||||
curl -X GET http://localhost:8000/api/admin/status \
|
||||
-H "Authorization: Bearer <返回的token>"
|
||||
```
|
||||
|
||||
### 问题 6: 用户表不存在
|
||||
```
|
||||
relation "users" does not exist
|
||||
```
|
||||
**解决**:
|
||||
```bash
|
||||
python cli.py init
|
||||
```
|
||||
|
||||
## 📝 首次部署完整流程
|
||||
|
||||
```bash
|
||||
# 1. 系统依赖
|
||||
sudo apt update
|
||||
sudo apt install -y postgresql redis-server
|
||||
|
||||
# 2. 启动服务
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
|
||||
# 3. 配置数据库密码
|
||||
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
|
||||
|
||||
# 4. Python 环境
|
||||
cd backend
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 5. 配置环境变量
|
||||
cp .env.example .env
|
||||
nano .env # 编辑 PG_PASSWORD, SECRET_KEY 等
|
||||
|
||||
# 6. 初始化数据库
|
||||
python cli.py init
|
||||
|
||||
# 7. 启动服务
|
||||
python main.py
|
||||
|
||||
# 8. 测试
|
||||
python test_core_features.py
|
||||
```
|
||||
|
||||
## 🚀 一键启动脚本(WSL)
|
||||
|
||||
创建 `start.sh`:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
cd /mnt/e/project/stock_cs_v1/backend # 修改为实际路径
|
||||
source .venv/bin/activate
|
||||
python main.py
|
||||
```
|
||||
|
||||
使用:
|
||||
```bash
|
||||
chmod +x start.sh
|
||||
./start.sh
|
||||
```
|
||||
|
||||
## 📞 获取帮助
|
||||
|
||||
如果以上方法都无法解决问题:
|
||||
1. 查看详细文档: `backend/UPGRADE_GUIDE.md`
|
||||
2. 查看配置说明: `backend/ENV_CONFIG.md`
|
||||
3. 查看实现总结: `三大核心功能实现总结.md`
|
||||
118
backend/ENV_CONFIG.md
Normal file
118
backend/ENV_CONFIG.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# 环境变量配置说明
|
||||
|
||||
## 新增配置项(三大核心功能)
|
||||
|
||||
### Redis 缓存配置
|
||||
```bash
|
||||
# Redis 连接配置
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
REDIS_PASSWORD= # 如果 Redis 设置了密码则填写
|
||||
```
|
||||
|
||||
### 认证系统配置
|
||||
```bash
|
||||
# JWT 密钥(生产环境务必修改)
|
||||
SECRET_KEY=your-secret-key-change-in-production
|
||||
|
||||
# Token 过期时间(分钟)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080 # 默认7天
|
||||
|
||||
# API Key 模式(可选,用于外部调用,逗号分隔多个)
|
||||
API_KEYS=your-api-key-1,your-api-key-2
|
||||
|
||||
# 默认管理员账号(首次启动时创建)
|
||||
DEFAULT_ADMIN_USERNAME=admin
|
||||
DEFAULT_ADMIN_PASSWORD=admin123 # 首次启动后务必修改密码
|
||||
```
|
||||
|
||||
## 安装 Redis(WSL 环境)
|
||||
|
||||
```bash
|
||||
# 安装 Redis
|
||||
sudo apt update
|
||||
sudo apt install -y redis-server
|
||||
|
||||
# 启动 Redis
|
||||
sudo service redis-server start
|
||||
|
||||
# 验证 Redis 是否运行
|
||||
redis-cli ping
|
||||
# 应返回 PONG
|
||||
|
||||
# 设置 Redis 开机自启(可选)
|
||||
sudo systemctl enable redis-server
|
||||
```
|
||||
|
||||
## 安全建议
|
||||
|
||||
1. **SECRET_KEY**: 生产环境必须使用强随机字符串,可用以下命令生成:
|
||||
```bash
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
```
|
||||
|
||||
2. **默认密码**: 首次登录后立即修改 admin 密码
|
||||
|
||||
3. **API_KEYS**: 仅在需要外部调用时配置,妥善保管
|
||||
|
||||
4. **Redis 密码**: 生产环境建议为 Redis 设置密码:
|
||||
```bash
|
||||
# 编辑 Redis 配置
|
||||
sudo nano /etc/redis/redis.conf
|
||||
# 找到 requirepass 行,取消注释并设置密码
|
||||
requirepass your-strong-password
|
||||
# 重启 Redis
|
||||
sudo service redis-server restart
|
||||
```
|
||||
|
||||
## 功能说明
|
||||
|
||||
### 1. Redis 缓存
|
||||
- 替代内存缓存,支持持久化和跨进程共享
|
||||
- 自动降级:Redis 不可用时使用内存缓存
|
||||
- 默认过期时间根据数据类型自动设置(行情数据1分钟,基本面数据1天)
|
||||
|
||||
### 2. 统一鉴权
|
||||
- **JWT Token 模式**: 用户登录获取 Token,适合前端应用
|
||||
- **API Key 模式**: 用于外部系统调用,配置在 HTTP Header `X-API-Key`
|
||||
- 管理接口(`/api/admin/*`)需要管理员权限
|
||||
|
||||
### 3. 统一异常处理
|
||||
- 业务异常返回友好错误信息
|
||||
- 数据源异常自动降级
|
||||
- 数据库异常统一处理
|
||||
- 所有异常记录日志便于排查
|
||||
|
||||
## API 使用示例
|
||||
|
||||
### 登录
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/api/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username":"admin","password":"admin123"}'
|
||||
```
|
||||
|
||||
### 使用 Token 访问受保护接口
|
||||
```bash
|
||||
curl -X GET http://localhost:8000/api/admin/status \
|
||||
-H "Authorization: Bearer YOUR_TOKEN_HERE"
|
||||
```
|
||||
|
||||
### 使用 API Key 访问
|
||||
```bash
|
||||
curl -X GET http://localhost:8000/api/admin/status \
|
||||
-H "X-API-Key: your-api-key-1"
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. WSL 环境下,每次重启后需要手动启动服务:
|
||||
```bash
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
```
|
||||
|
||||
2. Redis 连接失败不会影响系统运行,会自动降级到内存缓存
|
||||
|
||||
3. 未配置鉴权时,所有接口默认不需要认证(开发模式)
|
||||
287
backend/UPGRADE_GUIDE.md
Normal file
287
backend/UPGRADE_GUIDE.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# 三大核心功能升级指南
|
||||
|
||||
本次升级新增了三个核心功能:Redis 缓存层、统一鉴权机制、统一异常处理中间件。
|
||||
|
||||
## 1. 安装新依赖
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
source .venv/bin/activate # 激活虚拟环境
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
新增的依赖包:
|
||||
- `redis>=5.0.0` - Redis 客户端
|
||||
- `python-jose[cryptography]>=3.3.0` - JWT Token 生成和验证
|
||||
- `passlib[bcrypt]>=1.7.4` - 密码哈希
|
||||
- `python-multipart>=0.0.9` - 表单数据解析
|
||||
|
||||
## 2. 安装和启动 Redis(WSL)
|
||||
|
||||
```bash
|
||||
# 安装 Redis
|
||||
sudo apt update
|
||||
sudo apt install -y redis-server
|
||||
|
||||
# 启动 Redis
|
||||
sudo service redis-server start
|
||||
|
||||
# 验证 Redis 是否运行
|
||||
redis-cli ping
|
||||
# 应返回 PONG
|
||||
```
|
||||
|
||||
## 3. 配置环境变量
|
||||
|
||||
编辑 `backend/.env` 文件,添加以下配置:
|
||||
|
||||
```bash
|
||||
# ============ Redis 缓存配置 ============
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
REDIS_PASSWORD= # 可选,如果设置了密码则填写
|
||||
|
||||
# ============ 认证系统配置 ============
|
||||
# JWT 密钥(务必修改为强随机字符串)
|
||||
SECRET_KEY=your-secret-key-change-in-production
|
||||
|
||||
# Token 过期时间(分钟,默认7天)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080
|
||||
|
||||
# API Key(可选,用于外部调用,逗号分隔)
|
||||
API_KEYS=
|
||||
|
||||
# 默认管理员账号
|
||||
DEFAULT_ADMIN_USERNAME=admin
|
||||
DEFAULT_ADMIN_PASSWORD=admin123
|
||||
```
|
||||
|
||||
### 生成安全的 SECRET_KEY
|
||||
|
||||
```bash
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
```
|
||||
|
||||
将输出的字符串替换 `SECRET_KEY` 的值。
|
||||
|
||||
## 4. 初始化数据库(创建用户表和默认管理员)
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
source .venv/bin/activate
|
||||
python cli.py init
|
||||
```
|
||||
|
||||
你会看到类似输出:
|
||||
```
|
||||
✓ 创建默认管理员: admin
|
||||
init done
|
||||
```
|
||||
|
||||
## 5. 启动服务
|
||||
|
||||
```bash
|
||||
# 确保 PostgreSQL 和 Redis 都在运行
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
|
||||
# 启动 FastAPI 服务
|
||||
cd backend
|
||||
source .venv/bin/activate
|
||||
python main.py
|
||||
```
|
||||
|
||||
启动日志应包含:
|
||||
```
|
||||
✓ Redis 已连接: localhost:6379
|
||||
✓ 管理员账号已存在: admin
|
||||
[startup] db + scheduler + auth ready
|
||||
```
|
||||
|
||||
## 6. 测试功能
|
||||
|
||||
### 测试健康检查(查看 Redis 和鉴权状态)
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/api/health
|
||||
```
|
||||
|
||||
应返回:
|
||||
```json
|
||||
{
|
||||
"ok": true,
|
||||
"akshare": true,
|
||||
"redis": true,
|
||||
"auth": true
|
||||
}
|
||||
```
|
||||
|
||||
### 测试登录
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/api/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username":"admin","password":"admin123"}'
|
||||
```
|
||||
|
||||
应返回 Token:
|
||||
```json
|
||||
{
|
||||
"ok": true,
|
||||
"access_token": "eyJ0eXAiOiJKV1QiLCJhb...",
|
||||
"token_type": "bearer",
|
||||
"username": "admin",
|
||||
"is_admin": true
|
||||
}
|
||||
```
|
||||
|
||||
### 测试受保护的接口
|
||||
|
||||
```bash
|
||||
# 使用上一步获取的 Token
|
||||
curl -X GET http://localhost:8000/api/admin/status \
|
||||
-H "Authorization: Bearer YOUR_TOKEN_HERE"
|
||||
```
|
||||
|
||||
### 测试 Redis 缓存
|
||||
|
||||
```bash
|
||||
# 第一次请求(缓存未命中,较慢)
|
||||
time curl http://localhost:8000/api/indices
|
||||
|
||||
# 第二次请求(缓存命中,应该快很多)
|
||||
time curl http://localhost:8000/api/indices
|
||||
```
|
||||
|
||||
## 7. 修改默认密码(重要!)
|
||||
|
||||
首次登录后,立即修改默认密码:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/api/auth/change-password \
|
||||
-H "Authorization: Bearer YOUR_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"old_password":"admin123","new_password":"your-new-strong-password"}'
|
||||
```
|
||||
|
||||
## 8. 常见问题
|
||||
|
||||
### Redis 连接失败
|
||||
|
||||
如果看到以下日志:
|
||||
```
|
||||
✗ Redis 连接失败,缓存已禁用: ...
|
||||
```
|
||||
|
||||
**解决方法**:
|
||||
```bash
|
||||
# 检查 Redis 是否运行
|
||||
sudo service redis-server status
|
||||
|
||||
# 如果未运行,启动它
|
||||
sudo service redis-server start
|
||||
```
|
||||
|
||||
**注意**:Redis 连接失败不会影响系统运行,会自动降级到内存缓存。
|
||||
|
||||
### 401 Unauthorized 错误
|
||||
|
||||
如果访问管理接口返回 401:
|
||||
```json
|
||||
{"detail":"未认证,请先登录"}
|
||||
```
|
||||
|
||||
**原因**:该接口需要认证
|
||||
|
||||
**解决方法**:
|
||||
1. 先调用 `/api/auth/login` 获取 Token
|
||||
2. 在请求头中加入 `Authorization: Bearer YOUR_TOKEN`
|
||||
|
||||
### WSL 重启后服务启动失败
|
||||
|
||||
WSL 每次重启后需要手动启动服务:
|
||||
|
||||
```bash
|
||||
# 一键启动所有服务
|
||||
sudo service postgresql start && \
|
||||
sudo service redis-server start && \
|
||||
cd /mnt/e/project/stock_cs_v1/backend && \
|
||||
source .venv/bin/activate && \
|
||||
python main.py
|
||||
```
|
||||
|
||||
## 9. 功能详细说明
|
||||
|
||||
### Redis 缓存优势
|
||||
|
||||
- **持久化**:服务重启后缓存不丢失
|
||||
- **跨进程**:多个进程可共享缓存
|
||||
- **性能提升**:大幅减少 AkShare API 调用,响应速度提升 10-100 倍
|
||||
- **自动降级**:Redis 不可用时自动使用内存缓存
|
||||
|
||||
### 鉴权机制
|
||||
|
||||
**支持两种认证方式**:
|
||||
|
||||
1. **JWT Token**(推荐用于前端)
|
||||
- 登录后获取 Token
|
||||
- Token 有效期 7 天(可配置)
|
||||
- 在 HTTP Header 中传递:`Authorization: Bearer TOKEN`
|
||||
|
||||
2. **API Key**(推荐用于外部系统)
|
||||
- 在 `.env` 中配置 `API_KEYS`
|
||||
- 在 HTTP Header 中传递:`X-API-Key: YOUR_API_KEY`
|
||||
|
||||
**受保护的接口**:
|
||||
- `/api/admin/*` - 需要管理员权限
|
||||
- 其他接口暂不需要认证(可根据需要扩展)
|
||||
|
||||
### 异常处理
|
||||
|
||||
系统现在会返回友好的错误信息:
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": "数据源异常,请稍后重试",
|
||||
"code": 503
|
||||
}
|
||||
```
|
||||
|
||||
错误类型:
|
||||
- **400** - 业务逻辑错误
|
||||
- **401** - 未认证
|
||||
- **403** - 权限不足
|
||||
- **422** - 请求参数错误
|
||||
- **500** - 服务器内部错误
|
||||
- **503** - 数据源不可用
|
||||
|
||||
## 10. 升级检查清单
|
||||
|
||||
- [ ] 安装新依赖包
|
||||
- [ ] 安装并启动 Redis
|
||||
- [ ] 配置 `.env` 文件(Redis + 鉴权配置)
|
||||
- [ ] 生成安全的 `SECRET_KEY`
|
||||
- [ ] 运行 `python cli.py init` 创建用户表
|
||||
- [ ] 启动服务并验证功能
|
||||
- [ ] 测试登录和受保护接口
|
||||
- [ ] 修改默认管理员密码
|
||||
- [ ] 检查 Redis 缓存是否生效
|
||||
|
||||
## 11. 回退方案
|
||||
|
||||
如果升级后遇到问题,可以临时禁用新功能:
|
||||
|
||||
1. **禁用 Redis 缓存**:停止 Redis 服务,系统会自动降级到内存缓存
|
||||
```bash
|
||||
sudo service redis-server stop
|
||||
```
|
||||
|
||||
2. **禁用鉴权**:暂时注释掉 `main.py` 中受保护接口的 `Depends(require_auth)` 参数
|
||||
|
||||
3. **完整回退**:切换到升级前的 git 分支
|
||||
|
||||
---
|
||||
|
||||
升级完成后,系统将具备更强的性能、安全性和稳定性!
|
||||
170
backend/ai.py
170
backend/ai.py
@@ -202,6 +202,176 @@ def diagnose(symbol):
|
||||
return base
|
||||
|
||||
|
||||
# ============ 走势分析(右键K线) ============
|
||||
def trend_analysis(symbol: str, date: str, period: str = "daily"):
|
||||
"""分析某只股票在指定日期(或最新)附近暴涨/暴跌的原因。
|
||||
period: daily / weekly / monthly
|
||||
"""
|
||||
with get_session() as s:
|
||||
sec = s.get(Security, symbol)
|
||||
# 取该日期前后一段数据作为上下文
|
||||
if date:
|
||||
try:
|
||||
target_date = dt.date.fromisoformat(date)
|
||||
except Exception:
|
||||
target_date = None
|
||||
else:
|
||||
target_date = None
|
||||
|
||||
# 取最近60根K线
|
||||
rows = s.execute(
|
||||
select(DailyQuote).where(DailyQuote.code == symbol)
|
||||
.order_by(DailyQuote.date.desc()).limit(60)
|
||||
).scalars().all()
|
||||
rows = list(reversed(rows))
|
||||
|
||||
# 当日及相邻数据
|
||||
target_row = None
|
||||
if target_date and rows:
|
||||
# 找最近的一根
|
||||
closest = min(rows, key=lambda r: abs((r.date - target_date).days))
|
||||
if abs((closest.date - target_date).days) <= 7:
|
||||
target_row = closest
|
||||
|
||||
if not target_row and rows:
|
||||
target_row = rows[-1]
|
||||
|
||||
m = s.get(StockMetric, symbol)
|
||||
|
||||
name = (sec.name if sec else (m.name if m else symbol)) or symbol
|
||||
|
||||
# 计算目标K线的涨跌幅
|
||||
pct = 0.0
|
||||
if target_row and rows:
|
||||
idx = rows.index(target_row)
|
||||
if idx > 0:
|
||||
prev_close = rows[idx - 1].close
|
||||
if prev_close:
|
||||
pct = round((target_row.close - prev_close) / prev_close * 100, 2)
|
||||
|
||||
# 拉取相关新闻
|
||||
import akshare_service as svc
|
||||
try:
|
||||
news_data = svc.get_stock_news(symbol, limit=10)
|
||||
news_items = news_data.get("list", [])
|
||||
except Exception:
|
||||
news_items = []
|
||||
|
||||
# 拉取RAG上下文
|
||||
import rag
|
||||
rctx = rag.stock_context(symbol, limit=6)
|
||||
|
||||
# 构造上下文数据
|
||||
period_cn = {"daily": "日K", "weekly": "周K", "monthly": "月K"}.get(period, "K线")
|
||||
date_str = target_row.date.isoformat() if target_row else (date or "最新")
|
||||
|
||||
# 当日技术面
|
||||
if target_row:
|
||||
tech_line = (
|
||||
f"目标K线:{date_str},开{target_row.open} 收{target_row.close} "
|
||||
f"高{target_row.high} 低{target_row.low},"
|
||||
f"涨跌幅{pct:+.2f}%,成交量{target_row.volume:,}"
|
||||
)
|
||||
else:
|
||||
tech_line = f"目标日期:{date_str},暂无日线数据"
|
||||
|
||||
# 前后走势(最近5根)
|
||||
if target_row and rows:
|
||||
idx = rows.index(target_row)
|
||||
window = rows[max(0, idx-4):idx+2]
|
||||
trend_line = "前后走势:" + " → ".join(
|
||||
f"{r.date.strftime('%m/%d')}({'↑' if i == 0 or r.close >= rows[rows.index(r)-1].close else '↓'}{abs(round((r.close/rows[rows.index(r)-1].close-1)*100,1)) if rows.index(r) > 0 else 0}%)"
|
||||
for i, r in enumerate(window)
|
||||
)
|
||||
else:
|
||||
trend_line = ""
|
||||
|
||||
# 均线状态
|
||||
ma_line = ""
|
||||
if m:
|
||||
ma_line = (f"均线状态:MA5={m.ma5} MA10={m.ma10} MA20={m.ma20} MA60={m.ma60},"
|
||||
f"{'多头排列' if m.ma_bull else '非多头'},"
|
||||
f"量比{m.vol_ratio},RSI14={m.rsi14}")
|
||||
|
||||
# 新闻摘要
|
||||
news_block = ""
|
||||
if news_items:
|
||||
news_block = "相关新闻(近期):\n" + "\n".join(
|
||||
f"- [{n.get('time','')[:10]}] {n.get('title','')}" for n in news_items[:6]
|
||||
)
|
||||
|
||||
# 判断是否暴涨/暴跌
|
||||
move_desc = ""
|
||||
if abs(pct) >= 5:
|
||||
move_desc = f"该股{'暴涨' if pct > 0 else '暴跌'} {abs(pct):.2f}%({'接近/涨停' if pct >= 9.5 else '显著上涨' if pct > 0 else '接近/跌停' if pct <= -9.5 else '显著下跌'})"
|
||||
elif abs(pct) >= 2:
|
||||
move_desc = f"该股{'上涨' if pct > 0 else '下跌'} {abs(pct):.2f}%"
|
||||
else:
|
||||
move_desc = f"该股小幅变动 {pct:+.2f}%"
|
||||
|
||||
facts = f"""{name}({symbol}){period_cn}走势分析
|
||||
分析日期:{date_str}
|
||||
{move_desc}
|
||||
{tech_line}
|
||||
{trend_line}
|
||||
{ma_line}
|
||||
{news_block}
|
||||
消息面情绪:{rctx['tone']}
|
||||
{rctx['block'] or ''}"""
|
||||
|
||||
if llm.enabled():
|
||||
try:
|
||||
prompt = (
|
||||
f"请分析 {name}({symbol})在 {date_str} 前后{period_cn}的走势,"
|
||||
f"重点解释:① 为什么{'暴涨' if pct >= 5 else ('暴跌' if pct <= -5 else '出现此走势')}(从技术面、资金面、政策面、新闻事件等维度);"
|
||||
f"② 背后的主要驱动逻辑是什么;③ 后续需关注的信号或风险。250字以内,分点清晰。\n\n{facts}"
|
||||
)
|
||||
text = llm.ask(prompt, temperature=0.5, max_tokens=600)
|
||||
return {"ok": True, "source": "llm", "symbol": symbol, "name": name,
|
||||
"date": date_str, "period": period, "pct": pct, "facts": facts, "text": text}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 规则降级
|
||||
reasons = []
|
||||
if m:
|
||||
if m.ma_bull and pct > 0:
|
||||
reasons.append("均线多头排列,趋势向上")
|
||||
if m.vol_ratio >= 2 and pct > 0:
|
||||
reasons.append(f"成交量显著放大(量比{m.vol_ratio}),主力资金介入")
|
||||
if m.vol_ratio >= 2 and pct < 0:
|
||||
reasons.append(f"放量下跌(量比{m.vol_ratio}),资金出逃信号")
|
||||
if m.macd_gold and pct > 0:
|
||||
reasons.append("MACD金叉,动能转强")
|
||||
if m.rsi14 >= 80:
|
||||
reasons.append(f"RSI超买({m.rsi14}),注意回调风险")
|
||||
if m.rsi14 < 30:
|
||||
reasons.append(f"RSI超卖({m.rsi14}),存在超跌反弹机会")
|
||||
if m.pos60 >= 0.95 and pct > 0:
|
||||
reasons.append("突破60日新高,动量突破")
|
||||
if m.pos60 <= 0.1 and pct > 0:
|
||||
reasons.append("低位反弹,超跌修复")
|
||||
if rctx['tone'] == '利好':
|
||||
reasons.append("近期资讯面偏利好")
|
||||
elif rctx['tone'] == '利空':
|
||||
reasons.append("近期资讯面偏利空")
|
||||
if news_items:
|
||||
hot_news = news_items[0]['title'][:40]
|
||||
reasons.append(f"最新消息:{hot_news}…")
|
||||
if not reasons:
|
||||
reasons.append("暂无明确技术或消息面驱动,可能为市场情绪或板块联动")
|
||||
|
||||
text = (
|
||||
f"{name} 在 {date_str} {move_desc}。\n"
|
||||
f"主要原因分析:\n" +
|
||||
"\n".join(f"{i+1}. {r}" for i, r in enumerate(reasons)) +
|
||||
f"\n\n建议:{'关注量能是否持续配合,谨防高位回调。' if pct >= 5 else ('关注是否企稳止跌,底部确认前谨慎抄底。' if pct <= -5 else '走势相对平稳,跟踪板块动向。')}"
|
||||
f"\n{DISCLAIMER}"
|
||||
)
|
||||
return {"ok": True, "source": "rule", "symbol": symbol, "name": name,
|
||||
"date": date_str, "period": period, "pct": pct, "facts": facts, "text": text}
|
||||
|
||||
|
||||
# ============ 今日策略 ============
|
||||
def today_strategy():
|
||||
with get_session() as s:
|
||||
|
||||
@@ -12,6 +12,7 @@ from functools import wraps
|
||||
from cachetools import TTLCache
|
||||
|
||||
import requests
|
||||
from redis_cache import cache
|
||||
|
||||
try:
|
||||
import akshare as ak
|
||||
@@ -25,16 +26,35 @@ _cache = TTLCache(maxsize=256, ttl=30)
|
||||
|
||||
|
||||
def cached(ttl: int):
|
||||
"""缓存装饰器:优先使用 Redis,降级到内存缓存"""
|
||||
def deco(fn):
|
||||
local = TTLCache(maxsize=64, ttl=ttl)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
key = (fn.__name__, args, tuple(sorted(kwargs.items())))
|
||||
if key in local:
|
||||
return local[key]
|
||||
# 生成缓存键
|
||||
key = f"akshare:{fn.__name__}:{args}:{tuple(sorted(kwargs.items()))}"
|
||||
|
||||
# 优先从 Redis 读取
|
||||
if cache.enabled:
|
||||
cached_value = cache.get(key)
|
||||
if cached_value is not None:
|
||||
return cached_value
|
||||
|
||||
# Redis 未命中,从内存缓存读取
|
||||
local_key = (fn.__name__, args, tuple(sorted(kwargs.items())))
|
||||
if local_key in local:
|
||||
return local[local_key]
|
||||
|
||||
# 执行函数
|
||||
val = fn(*args, **kwargs)
|
||||
local[key] = val
|
||||
|
||||
# 写入 Redis
|
||||
if cache.enabled:
|
||||
cache.set(key, val, expire=ttl)
|
||||
|
||||
# 写入内存缓存(降级)
|
||||
local[local_key] = val
|
||||
return val
|
||||
|
||||
return wrapper
|
||||
@@ -183,7 +203,19 @@ def get_stock_news(code: str, limit: int = 12):
|
||||
return {"source": "mock", "list": []}
|
||||
|
||||
|
||||
# 已知指数代码 → 新浪前缀映射
|
||||
_INDEX_CODES = {"000001", "000300", "000016", "399001", "399006", "899050"}
|
||||
|
||||
def _is_index(code: str) -> bool:
|
||||
return code in _INDEX_CODES or code.startswith(("sh0", "sz3990", "bj8990"))
|
||||
|
||||
def _sina_symbol(code: str) -> str:
|
||||
if code in ("000001", "000016"): # 上证系列
|
||||
return "sh" + code
|
||||
if code in ("000300",): # 沪深300
|
||||
return "sh" + code
|
||||
if code in ("399001", "399006"): # 深证
|
||||
return "sz" + code
|
||||
if code.startswith("6"):
|
||||
return "sh" + code
|
||||
if code.startswith(("0", "3")):
|
||||
@@ -194,9 +226,23 @@ def _sina_symbol(code: str) -> str:
|
||||
|
||||
|
||||
@cached(60)
|
||||
def get_kline(symbol: str = "600519", days: int = 120):
|
||||
def get_kline(symbol: str = "000001", days: int = 120):
|
||||
if AK_OK:
|
||||
# 主源:新浪日线(更稳定);备源:腾讯
|
||||
# 指数走专用接口
|
||||
if symbol in _INDEX_CODES:
|
||||
try:
|
||||
sym = _sina_symbol(symbol)
|
||||
df = ak.stock_zh_index_daily(symbol=sym)
|
||||
if df is not None and not df.empty:
|
||||
df = df.tail(days)
|
||||
dates = [str(d)[5:].replace("-", "/") for d in df["date"]]
|
||||
ohlc = [[float(r["open"]), float(r["close"]), float(r["low"]), float(r["high"])]
|
||||
for _, r in df.iterrows()]
|
||||
vols = [int(r["volume"]) if "volume" in df.columns else 0 for _, r in df.iterrows()]
|
||||
return {"source": "akshare", "symbol": symbol, "dates": dates, "ohlc": ohlc, "vols": vols}
|
||||
except Exception:
|
||||
pass
|
||||
# 个股主源:新浪日线(更稳定);备源:腾讯
|
||||
for src in ("sina", "tx"):
|
||||
try:
|
||||
sym = _sina_symbol(symbol)
|
||||
@@ -321,6 +367,90 @@ def get_treemap(mode: str = "sector"):
|
||||
return {"source": boards["source"], "mode": "sector", "items": items}
|
||||
|
||||
|
||||
@cached(120)
|
||||
def get_us_treemap():
|
||||
"""美股热门板块云图(按成交额取前100只)"""
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_us_spot_em()
|
||||
if df is not None and not df.empty:
|
||||
top = df.sort_values("成交额", ascending=False).head(100)
|
||||
items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e8, 2),
|
||||
"pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()]
|
||||
items = [x for x in items if x["name"]]
|
||||
return {"source": "akshare", "market": "us", "items": items}
|
||||
except Exception:
|
||||
pass
|
||||
names = ["苹果","微软","谷歌","亚马逊","英伟达","特斯拉","Meta","台积电","巴菲特","摩根"]
|
||||
return {"source": "mock", "market": "us",
|
||||
"items": [{"name": n, "value": _rnd(10,200), "pct": round(_rnd(-4,4),2)} for n in names]}
|
||||
|
||||
|
||||
@cached(120)
|
||||
def get_hk_treemap():
|
||||
"""港股热门板块云图(按成交额取前100只)"""
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_hk_spot_em()
|
||||
if df is not None and not df.empty:
|
||||
top = df.sort_values("成交额", ascending=False).head(100)
|
||||
items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e4, 2),
|
||||
"pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()]
|
||||
items = [x for x in items if x["name"]]
|
||||
return {"source": "akshare", "market": "hk", "items": items}
|
||||
except Exception:
|
||||
pass
|
||||
names = ["腾讯","阿里巴巴","美团","京东","小米","百度","网易","中国平安","汇丰","友邦"]
|
||||
return {"source": "mock", "market": "hk",
|
||||
"items": [{"name": n, "value": _rnd(5,100), "pct": round(_rnd(-4,4),2)} for n in names]}
|
||||
|
||||
|
||||
@cached(120)
|
||||
def get_all_sector_leaders(top_n: int = 5):
|
||||
"""一次性获取所有板块的前N只龙头股"""
|
||||
boards = get_industry_boards()
|
||||
result = {}
|
||||
for b in boards.get("list", []):
|
||||
name = b["name"]
|
||||
try:
|
||||
r = get_sector_stocks(name, top_n + 1)
|
||||
result[name] = r.get("stocks", [])[:top_n]
|
||||
except Exception:
|
||||
result[name] = []
|
||||
return {"source": "akshare", "sectors": result}
|
||||
|
||||
|
||||
@cached(300)
|
||||
def get_sector_stocks(sector_name: str, limit: int = 20):
|
||||
"""获取板块成分股,按成交额排序"""
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_board_industry_cons_em(symbol=sector_name)
|
||||
if df is not None and not df.empty:
|
||||
if "成交额" in df.columns:
|
||||
df = df.sort_values("成交额", ascending=False)
|
||||
stocks = []
|
||||
for _, r in df.head(limit).iterrows():
|
||||
try:
|
||||
stocks.append({
|
||||
"code": str(r.get("代码", "")),
|
||||
"name": str(r.get("名称", "")),
|
||||
"pct": round(float(r.get("涨跌幅", 0)), 2),
|
||||
"price": round(float(r.get("最新价", 0)), 2),
|
||||
"amount": round(float(r.get("成交额", 0)) / 1e8, 2),
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
return {"source": "akshare", "name": sector_name, "stocks": stocks}
|
||||
except Exception:
|
||||
pass
|
||||
# mock
|
||||
stocks = [{"code": f"60000{i}", "name": f"{sector_name}{i+1}",
|
||||
"pct": round(_rnd(-5, 5), 2), "price": round(_rnd(5, 100), 2), "amount": round(_rnd(1, 50), 2)}
|
||||
for i in range(10)]
|
||||
return {"source": "mock", "name": sector_name, "stocks": stocks}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 资金流向(行业)
|
||||
# ============================================================
|
||||
|
||||
88
backend/auth.py
Normal file
88
backend/auth.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from fastapi import Depends, HTTPException, status, Header
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError, jwt
|
||||
import bcrypt
|
||||
from sqlalchemy.orm import Session
|
||||
from db import SessionLocal
|
||||
from models import User
|
||||
from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, API_KEYS
|
||||
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
"""生成 JWT Token"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
|
||||
"""验证用户名密码"""
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
x_api_key: Optional[str] = Header(None),
|
||||
db: Session = Depends(lambda: SessionLocal())
|
||||
) -> Optional[User]:
|
||||
"""获取当前用户(支持 JWT Token 和 API Key 两种方式)"""
|
||||
# 方式1:API Key
|
||||
if x_api_key and x_api_key in API_KEYS:
|
||||
# API Key 模式,返回虚拟管理员
|
||||
user = db.query(User).filter(User.username == "admin").first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
# 方式2:JWT Token
|
||||
if credentials:
|
||||
token = credentials.credentials
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
return None
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
return user
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def require_auth(current_user: Optional[User] = Depends(get_current_user)):
|
||||
"""需要认证的依赖"""
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未认证,请先登录",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def require_admin(current_user: User = Depends(require_auth)):
|
||||
"""需要管理员权限的依赖"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="需要管理员权限"
|
||||
)
|
||||
return current_user
|
||||
@@ -9,12 +9,16 @@ import sys
|
||||
|
||||
from db import init_db
|
||||
import ingest
|
||||
import init_auth
|
||||
import watchlist_manager as wl
|
||||
|
||||
|
||||
def main():
|
||||
init_db()
|
||||
args = sys.argv[1:]
|
||||
if not args or args[0] == "init":
|
||||
init_auth.init_default_admin()
|
||||
wl.init_default_groups()
|
||||
print("init done")
|
||||
return
|
||||
if args[0] == "ingest":
|
||||
|
||||
@@ -51,3 +51,19 @@ SERVERCHAN_KEY = os.getenv("SERVERCHAN_KEY", "")
|
||||
WECOM_WEBHOOK = os.getenv("WECOM_WEBHOOK", "")
|
||||
# PushPlus(微信推送)
|
||||
PUSHPLUS_TOKEN = os.getenv("PUSHPLUS_TOKEN", "")
|
||||
|
||||
# ---- Redis 缓存 ----
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB = int(os.getenv("REDIS_DB", "0"))
|
||||
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
# ---- 鉴权配置 ----
|
||||
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "10080")) # 7天
|
||||
# API Key 模式(可选,用于外部调用)
|
||||
API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []
|
||||
# 默认管理员账号(首次启动时创建)
|
||||
DEFAULT_ADMIN_USERNAME = os.getenv("DEFAULT_ADMIN_USERNAME", "admin")
|
||||
DEFAULT_ADMIN_PASSWORD = os.getenv("DEFAULT_ADMIN_PASSWORD", "admin123")
|
||||
|
||||
398
backend/data_manager.py
Normal file
398
backend/data_manager.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""数据修正与回填增强:数据修正、断点续传、完整性检查、质量报告"""
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Dict
|
||||
from sqlalchemy import select, func, and_, delete
|
||||
from db import get_session
|
||||
from models import DailyQuote, StockMetric, Security, IndexDaily, SectorDaily, JobRun
|
||||
import ingest
|
||||
|
||||
# 回填进度文件路径
|
||||
PROGRESS_FILE = os.path.join(os.path.dirname(__file__), ".refill_progress.json")
|
||||
|
||||
|
||||
# ============ 数据修正 ============
|
||||
|
||||
def delete_quote(code: str, date: str) -> Dict:
|
||||
"""删除指定股票指定日期的日线数据"""
|
||||
d = dt.date.fromisoformat(date)
|
||||
with get_session() as s:
|
||||
row = s.execute(
|
||||
select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d)
|
||||
).scalar_one_or_none()
|
||||
if not row:
|
||||
return {"ok": False, "msg": f"{code} {date} 无此数据"}
|
||||
s.delete(row)
|
||||
s.commit()
|
||||
return {"ok": True, "msg": f"已删除 {code} {date} 日线"}
|
||||
|
||||
|
||||
def update_quote(code: str, date: str, fields: Dict) -> Dict:
|
||||
"""修正指定股票指定日期的日线数据"""
|
||||
allowed = {"open", "high", "low", "close", "volume", "amount"}
|
||||
to_update = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not to_update:
|
||||
return {"ok": False, "msg": "无有效修正字段"}
|
||||
|
||||
d = dt.date.fromisoformat(date)
|
||||
with get_session() as s:
|
||||
row = s.execute(
|
||||
select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d)
|
||||
).scalar_one_or_none()
|
||||
if not row:
|
||||
return {"ok": False, "msg": f"{code} {date} 无此数据"}
|
||||
for k, v in to_update.items():
|
||||
setattr(row, k, v)
|
||||
s.commit()
|
||||
return {"ok": True, "updated": to_update}
|
||||
|
||||
|
||||
def delete_quotes_range(code: str, start: str, end: str) -> Dict:
|
||||
"""删除指定股票日期范围内的日线数据"""
|
||||
d_start = dt.date.fromisoformat(start)
|
||||
d_end = dt.date.fromisoformat(end)
|
||||
with get_session() as s:
|
||||
rows = s.execute(
|
||||
select(DailyQuote).where(
|
||||
DailyQuote.code == code,
|
||||
DailyQuote.date >= d_start,
|
||||
DailyQuote.date <= d_end
|
||||
)
|
||||
).scalars().all()
|
||||
count = len(rows)
|
||||
for row in rows:
|
||||
s.delete(row)
|
||||
s.commit()
|
||||
return {"ok": True, "deleted": count, "range": f"{start} ~ {end}"}
|
||||
|
||||
|
||||
def refetch_quote(code: str, days: int = 30) -> Dict:
|
||||
"""重新抓取指定股票的日线数据(覆盖更新)"""
|
||||
rows = ingest.fetch_daily(code, days)
|
||||
if not rows:
|
||||
return {"ok": False, "msg": f"抓取 {code} 数据失败"}
|
||||
n = ingest.ingest_quotes([code], days=days)
|
||||
return {"ok": True, "code": code, "rows": len(rows), "msg": f"已更新 {len(rows)} 条日线"}
|
||||
|
||||
|
||||
# ============ 数据完整性检查 ============
|
||||
|
||||
def check_data_integrity(codes: Optional[List[str]] = None, days: int = 30) -> Dict:
|
||||
"""检查数据完整性,找出缺失数据的股票和日期"""
|
||||
with get_session() as s:
|
||||
# 确定检查范围
|
||||
latest = s.execute(select(func.max(DailyQuote.date))).scalar()
|
||||
if not latest:
|
||||
return {"ok": False, "msg": "数据库无日线数据"}
|
||||
|
||||
start = latest - dt.timedelta(days=days)
|
||||
|
||||
# 获取检查的股票列表
|
||||
if codes:
|
||||
check_codes = codes
|
||||
else:
|
||||
# 默认检查有记录的所有股票
|
||||
all_codes = s.execute(
|
||||
select(DailyQuote.code).where(
|
||||
DailyQuote.date >= start
|
||||
).distinct()
|
||||
).scalars().all()
|
||||
check_codes = list(all_codes)[:200] # 最多检查200只
|
||||
|
||||
# 统计每只股票的数据点数
|
||||
from sqlalchemy import case
|
||||
code_counts = {}
|
||||
for code in check_codes:
|
||||
count = s.execute(
|
||||
select(func.count()).select_from(DailyQuote)
|
||||
.where(DailyQuote.code == code, DailyQuote.date >= start)
|
||||
).scalar()
|
||||
code_counts[code] = count
|
||||
|
||||
# 以最多数据量为基准(应是交易日数)
|
||||
expected = max(code_counts.values()) if code_counts else 0
|
||||
|
||||
# 找出缺失数据的股票
|
||||
missing = []
|
||||
normal = []
|
||||
for code, count in code_counts.items():
|
||||
ratio = count / expected if expected > 0 else 0
|
||||
if ratio < 0.8: # 缺失超过20%
|
||||
with get_session() as s2:
|
||||
sec = s2.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
missing.append({
|
||||
"code": code,
|
||||
"name": name,
|
||||
"actual": count,
|
||||
"expected": expected,
|
||||
"missing": expected - count,
|
||||
"missing_pct": round((1 - ratio) * 100, 1)
|
||||
})
|
||||
else:
|
||||
normal.append(code)
|
||||
|
||||
missing.sort(key=lambda x: x["missing"], reverse=True)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"check_range": f"{start.isoformat()} ~ {latest.isoformat()}",
|
||||
"checked": len(check_codes),
|
||||
"expected_days": expected,
|
||||
"normal_count": len(normal),
|
||||
"missing_count": len(missing),
|
||||
"missing_stocks": missing[:50]
|
||||
}
|
||||
|
||||
|
||||
def auto_fix_missing(limit: int = 50) -> Dict:
|
||||
"""自动补齐缺失数据(批量重新抓取)"""
|
||||
result = check_data_integrity(days=30)
|
||||
if not result["ok"] or result["missing_count"] == 0:
|
||||
return {"ok": True, "msg": "数据完整,无需修复", "fixed": 0}
|
||||
|
||||
missing_stocks = result["missing_stocks"][:limit]
|
||||
codes = [s["code"] for s in missing_stocks]
|
||||
|
||||
with get_session() as s:
|
||||
job = JobRun(job="auto_fix", status="running",
|
||||
message=f"0/{len(codes)}")
|
||||
s.add(job)
|
||||
s.commit()
|
||||
job_id = job.id
|
||||
|
||||
fixed = 0
|
||||
failed = []
|
||||
try:
|
||||
for i, code in enumerate(codes):
|
||||
rows = ingest.fetch_daily(code, days=60)
|
||||
if rows:
|
||||
ingest.ingest_quotes([code], days=60)
|
||||
fixed += 1
|
||||
else:
|
||||
failed.append(code)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
with get_session() as s:
|
||||
j = s.get(JobRun, job_id)
|
||||
j.message = f"{i+1}/{len(codes)}"
|
||||
s.commit()
|
||||
|
||||
status = "success"
|
||||
msg = f"修复 {fixed}/{len(codes)},失败 {len(failed)} 只"
|
||||
except Exception as e:
|
||||
status = "error"
|
||||
msg = f"修复中断: {repr(e)[:160]}"
|
||||
|
||||
with get_session() as s:
|
||||
j = s.get(JobRun, job_id)
|
||||
j.status = status
|
||||
j.finished_at = dt.datetime.now()
|
||||
j.message = msg
|
||||
s.commit()
|
||||
|
||||
return {"ok": True, "fixed": fixed, "failed": failed, "msg": msg}
|
||||
|
||||
|
||||
# ============ 断点续传回填 ============
|
||||
|
||||
def _load_progress() -> Dict:
|
||||
"""加载回填进度"""
|
||||
if os.path.exists(PROGRESS_FILE):
|
||||
try:
|
||||
with open(PROGRESS_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_progress(progress: Dict):
|
||||
"""保存回填进度"""
|
||||
with open(PROGRESS_FILE, "w") as f:
|
||||
json.dump(progress, f)
|
||||
|
||||
|
||||
def _clear_progress(task_id: str):
|
||||
"""清除指定任务的进度"""
|
||||
progress = _load_progress()
|
||||
progress.pop(task_id, None)
|
||||
_save_progress(progress)
|
||||
|
||||
|
||||
def start_refill_with_resume(days: int = 250, task_id: str = "default") -> Dict:
|
||||
"""带断点续传的全市场回填"""
|
||||
from akshare_service import _code_name_map
|
||||
|
||||
cmap = _code_name_map()
|
||||
all_codes = [c for c in cmap.keys() if c[:1] in ("0", "3", "6")]
|
||||
total = len(all_codes)
|
||||
|
||||
# 加载进度
|
||||
progress = _load_progress()
|
||||
task_progress = progress.get(task_id, {"done_codes": [], "days": days})
|
||||
done_codes = set(task_progress.get("done_codes", []))
|
||||
|
||||
# 过滤已完成的股票
|
||||
remaining = [c for c in all_codes if c not in done_codes]
|
||||
|
||||
with get_session() as s:
|
||||
job = JobRun(
|
||||
job="refill_resume",
|
||||
status="running",
|
||||
message=f"续传: 已完成 {len(done_codes)}/{total},剩余 {len(remaining)}"
|
||||
)
|
||||
s.add(job)
|
||||
s.commit()
|
||||
job_id = job.id
|
||||
|
||||
fixed = len(done_codes)
|
||||
try:
|
||||
for i in range(0, len(remaining), 50):
|
||||
batch = remaining[i:i + 50]
|
||||
ingest.ingest_quotes(batch, days=days, with_metrics=True, cmap=cmap)
|
||||
fixed += len(batch)
|
||||
|
||||
# 保存进度
|
||||
done_codes.update(batch)
|
||||
progress[task_id] = {
|
||||
"done_codes": list(done_codes),
|
||||
"days": days,
|
||||
"updated_at": dt.datetime.now().isoformat()
|
||||
}
|
||||
_save_progress(progress)
|
||||
|
||||
with get_session() as s:
|
||||
j = s.get(JobRun, job_id)
|
||||
j.message = f"{fixed}/{total}"
|
||||
s.commit()
|
||||
|
||||
# 完成后清除进度
|
||||
_clear_progress(task_id)
|
||||
status = "success"
|
||||
msg = f"完成 {fixed}/{total}"
|
||||
except Exception as e:
|
||||
status = "error"
|
||||
msg = f"中断于 {fixed}/{total} | {repr(e)[:160]}"
|
||||
# 保留进度供续传
|
||||
|
||||
with get_session() as s:
|
||||
j = s.get(JobRun, job_id)
|
||||
j.status = status
|
||||
j.finished_at = dt.datetime.now()
|
||||
j.message = msg
|
||||
s.commit()
|
||||
|
||||
return {"ok": status == "success", "done": fixed, "total": total, "msg": msg}
|
||||
|
||||
|
||||
def get_refill_progress(task_id: str = "default") -> Dict:
|
||||
"""获取回填进度"""
|
||||
progress = _load_progress()
|
||||
task = progress.get(task_id)
|
||||
if not task:
|
||||
return {"ok": True, "has_progress": False, "msg": "无回填进度记录"}
|
||||
|
||||
from akshare_service import _code_name_map
|
||||
cmap = _code_name_map()
|
||||
total = len([c for c in cmap.keys() if c[:1] in ("0", "3", "6")])
|
||||
done = len(task.get("done_codes", []))
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"has_progress": True,
|
||||
"task_id": task_id,
|
||||
"done": done,
|
||||
"total": total,
|
||||
"pct": round(done / total * 100, 1) if total > 0 else 0,
|
||||
"updated_at": task.get("updated_at", "")
|
||||
}
|
||||
|
||||
|
||||
def clear_refill_progress(task_id: str = "default") -> Dict:
|
||||
"""清除回填进度(从头开始)"""
|
||||
_clear_progress(task_id)
|
||||
return {"ok": True, "msg": f"已清除任务 {task_id} 的进度"}
|
||||
|
||||
|
||||
# ============ 数据质量报告 ============
|
||||
|
||||
def get_data_quality_report() -> Dict:
|
||||
"""生成数据质量报告"""
|
||||
with get_session() as s:
|
||||
# 基本统计
|
||||
total_quotes = s.execute(select(func.count()).select_from(DailyQuote)).scalar() or 0
|
||||
total_stocks = s.execute(
|
||||
select(func.count(DailyQuote.code.distinct()))
|
||||
).scalar() or 0
|
||||
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
|
||||
earliest_date = s.execute(select(func.min(DailyQuote.date))).scalar()
|
||||
|
||||
# 最近30天数据密度
|
||||
if latest_date:
|
||||
start30 = latest_date - dt.timedelta(days=30)
|
||||
recent_stocks = s.execute(
|
||||
select(func.count(DailyQuote.code.distinct()))
|
||||
.where(DailyQuote.date >= start30)
|
||||
).scalar() or 0
|
||||
|
||||
recent_dates = s.execute(
|
||||
select(func.count(DailyQuote.date.distinct()))
|
||||
.where(DailyQuote.date >= start30)
|
||||
).scalar() or 0
|
||||
else:
|
||||
recent_stocks = 0
|
||||
recent_dates = 0
|
||||
|
||||
# 异常数据检测(开盘价为0的记录)
|
||||
zero_open = s.execute(
|
||||
select(func.count()).select_from(DailyQuote)
|
||||
.where(DailyQuote.open == 0)
|
||||
).scalar() or 0
|
||||
|
||||
# 最近任务状态
|
||||
recent_jobs = s.execute(
|
||||
select(JobRun).order_by(JobRun.id.desc()).limit(5)
|
||||
).scalars().all()
|
||||
|
||||
jobs_summary = [{
|
||||
"job": j.job,
|
||||
"status": j.status,
|
||||
"started": j.started_at.strftime("%m-%d %H:%M") if j.started_at else "",
|
||||
"message": j.message[:100]
|
||||
} for j in recent_jobs]
|
||||
|
||||
# 数据健康度评分
|
||||
score = 100
|
||||
issues = []
|
||||
|
||||
if zero_open > 0:
|
||||
score -= min(20, zero_open // 100)
|
||||
issues.append(f"存在 {zero_open} 条开盘价为0的异常数据")
|
||||
|
||||
if total_stocks < 100:
|
||||
score -= 30
|
||||
issues.append(f"入库股票数量偏少({total_stocks}只)")
|
||||
|
||||
if latest_date and (dt.date.today() - latest_date).days > 7:
|
||||
score -= 20
|
||||
issues.append(f"数据滞后 {(dt.date.today() - latest_date).days} 天")
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"generated_at": dt.datetime.now().isoformat(),
|
||||
"health_score": max(0, score),
|
||||
"issues": issues,
|
||||
"statistics": {
|
||||
"total_quotes": total_quotes,
|
||||
"total_stocks": total_stocks,
|
||||
"latest_date": latest_date.isoformat() if latest_date else None,
|
||||
"earliest_date": earliest_date.isoformat() if earliest_date else None,
|
||||
"data_span_days": (latest_date - earliest_date).days if latest_date and earliest_date else 0,
|
||||
"recent_30d_stocks": recent_stocks,
|
||||
"recent_30d_dates": recent_dates,
|
||||
"zero_open_count": zero_open
|
||||
},
|
||||
"recent_jobs": jobs_summary
|
||||
}
|
||||
73
backend/exceptions.py
Normal file
73
backend/exceptions.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import traceback
|
||||
|
||||
class BusinessException(Exception):
|
||||
"""业务异常基类"""
|
||||
def __init__(self, message: str, code: int = 400):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(self.message)
|
||||
|
||||
class DataSourceException(BusinessException):
|
||||
"""数据源异常(AkShare等)"""
|
||||
def __init__(self, message: str = "数据源异常,请稍后重试"):
|
||||
super().__init__(message, code=503)
|
||||
|
||||
class AuthException(BusinessException):
|
||||
"""认证异常"""
|
||||
def __init__(self, message: str = "认证失败"):
|
||||
super().__init__(message, code=401)
|
||||
|
||||
class PermissionException(BusinessException):
|
||||
"""权限异常"""
|
||||
def __init__(self, message: str = "权限不足"):
|
||||
super().__init__(message, code=403)
|
||||
|
||||
async def business_exception_handler(request: Request, exc: BusinessException):
|
||||
"""业务异常处理器"""
|
||||
return JSONResponse(
|
||||
status_code=exc.code,
|
||||
content={
|
||||
"success": False,
|
||||
"error": exc.message,
|
||||
"code": exc.code
|
||||
}
|
||||
)
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""请求参数验证异常处理器"""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "请求参数错误",
|
||||
"detail": exc.errors()
|
||||
}
|
||||
)
|
||||
|
||||
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
|
||||
"""数据库异常处理器"""
|
||||
print(f"数据库错误: {exc}")
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "数据库操作失败"
|
||||
}
|
||||
)
|
||||
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""通用异常处理器"""
|
||||
print(f"未捕获异常: {type(exc).__name__}: {exc}")
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "服务器内部错误"
|
||||
}
|
||||
)
|
||||
@@ -12,7 +12,7 @@ import akshare_service as svc
|
||||
import config
|
||||
from db import get_session
|
||||
from models import (DailyQuote, DragonTiger, FundFlowDaily, IndexDaily, JobRun,
|
||||
SectorDaily, Security, SentimentDaily, StockMetric)
|
||||
SectorDaily, SectorLeader, Security, SentimentDaily, StockMetric)
|
||||
|
||||
try:
|
||||
import akshare as ak
|
||||
@@ -229,6 +229,25 @@ def ingest_sectors():
|
||||
return n
|
||||
|
||||
|
||||
def ingest_sector_leaders():
|
||||
"""入库各板块前5龙头股(按成交额)"""
|
||||
d = _today()
|
||||
data = svc.get_all_sector_leaders(top_n=5)
|
||||
rows = []
|
||||
for sector, stocks in data.get("sectors", {}).items():
|
||||
for i, s in enumerate(stocks):
|
||||
rows.append({"date": d, "sector": sector, "code": s["code"],
|
||||
"name": s["name"], "pct": s["pct"],
|
||||
"price": s["price"], "amount": s["amount"], "rank": i + 1})
|
||||
if not rows:
|
||||
return 0
|
||||
with get_session() as s:
|
||||
n = _upsert(s, SectorLeader, rows, ["date", "sector", "code"],
|
||||
["name", "pct", "price", "amount", "rank"])
|
||||
s.commit()
|
||||
return n
|
||||
|
||||
|
||||
def ingest_fund_flow():
|
||||
data = svc.get_fund_flow()
|
||||
d = _today()
|
||||
@@ -280,6 +299,7 @@ def run_daily_ingest(universe=None, with_quotes=True):
|
||||
summary["securities"] = ingest_securities()
|
||||
summary["indices"] = ingest_indices()
|
||||
summary["sectors"] = ingest_sectors()
|
||||
summary["sector_leaders"] = ingest_sector_leaders()
|
||||
summary["fund_flow"] = ingest_fund_flow()
|
||||
summary["sentiment"] = ingest_sentiment()
|
||||
summary["dragon"] = ingest_dragon()
|
||||
|
||||
21
backend/init_auth.py
Normal file
21
backend/init_auth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from db import get_session
|
||||
from models import User
|
||||
from auth import get_password_hash
|
||||
from config import DEFAULT_ADMIN_USERNAME, DEFAULT_ADMIN_PASSWORD
|
||||
|
||||
def init_default_admin():
|
||||
"""创建默认管理员账号(如果不存在)"""
|
||||
with get_session() as s:
|
||||
admin = s.query(User).filter(User.username == DEFAULT_ADMIN_USERNAME).first()
|
||||
if not admin:
|
||||
admin = User(
|
||||
username=DEFAULT_ADMIN_USERNAME,
|
||||
hashed_password=get_password_hash(DEFAULT_ADMIN_PASSWORD),
|
||||
is_admin=True,
|
||||
is_active=True
|
||||
)
|
||||
s.add(admin)
|
||||
s.commit()
|
||||
print(f"✓ 创建默认管理员: {DEFAULT_ADMIN_USERNAME}")
|
||||
else:
|
||||
print(f"✓ 管理员账号已存在: {DEFAULT_ADMIN_USERNAME}")
|
||||
106
backend/install.sh
Normal file
106
backend/install.sh
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/bin/bash
|
||||
# 三大核心功能快速安装脚本(WSL/Linux)
|
||||
|
||||
set -e
|
||||
|
||||
echo "========================================="
|
||||
echo " Blackdata StockTerminal 核心功能安装"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
# 检查是否在 WSL/Linux 环境
|
||||
if [[ "$OSTYPE" != "linux-gnu"* ]]; then
|
||||
echo "⚠ 此脚本仅支持 WSL/Linux 环境"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 1. 安装系统依赖
|
||||
echo "[1/6] 检查并安装系统依赖..."
|
||||
sudo apt update
|
||||
sudo apt install -y postgresql postgresql-contrib redis-server python3-pip python3-venv
|
||||
|
||||
# 2. 启动服务
|
||||
echo ""
|
||||
echo "[2/6] 启动 PostgreSQL 和 Redis..."
|
||||
sudo service postgresql start
|
||||
sudo service redis-server start
|
||||
|
||||
# 验证服务
|
||||
if redis-cli ping > /dev/null 2>&1; then
|
||||
echo "✓ Redis 运行正常"
|
||||
else
|
||||
echo "⚠ Redis 启动失败,缓存将降级到内存模式"
|
||||
fi
|
||||
|
||||
# 3. 创建虚拟环境(如果不存在)
|
||||
if [ ! -d ".venv" ]; then
|
||||
echo ""
|
||||
echo "[3/6] 创建 Python 虚拟环境..."
|
||||
python3 -m venv .venv
|
||||
else
|
||||
echo ""
|
||||
echo "[3/6] 虚拟环境已存在,跳过创建"
|
||||
fi
|
||||
|
||||
# 4. 安装 Python 依赖
|
||||
echo ""
|
||||
echo "[4/6] 安装 Python 依赖包..."
|
||||
source .venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 5. 配置环境变量
|
||||
echo ""
|
||||
echo "[5/6] 配置环境变量..."
|
||||
if [ ! -f ".env" ]; then
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo "✓ 已从 .env.example 创建 .env 文件"
|
||||
else
|
||||
echo "⚠ .env.example 不存在,请手动创建 .env 文件"
|
||||
fi
|
||||
|
||||
# 生成随机 SECRET_KEY
|
||||
SECRET_KEY=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))")
|
||||
echo ""
|
||||
echo "生成的 SECRET_KEY(请添加到 .env):"
|
||||
echo "SECRET_KEY=$SECRET_KEY"
|
||||
echo ""
|
||||
else
|
||||
echo "✓ .env 文件已存在"
|
||||
fi
|
||||
|
||||
# 6. 初始化数据库
|
||||
echo ""
|
||||
echo "[6/6] 初始化数据库..."
|
||||
|
||||
# 检查 PostgreSQL 密码配置
|
||||
if grep -q "PG_PASSWORD=your_password" .env 2>/dev/null || grep -q "PG_PASSWORD=$" .env 2>/dev/null; then
|
||||
echo ""
|
||||
echo "⚠ 请先在 .env 中设置 PostgreSQL 密码:"
|
||||
echo " 1. 设置数据库密码: sudo -u postgres psql -c \"ALTER USER postgres PASSWORD 'your_password';\""
|
||||
echo " 2. 在 .env 中配置: PG_PASSWORD=your_password"
|
||||
echo ""
|
||||
echo "配置完成后,运行: python cli.py init"
|
||||
else
|
||||
python cli.py init
|
||||
echo "✓ 数据库初始化完成"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo " 安装完成!"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
echo "下一步:"
|
||||
echo "1. 编辑 backend/.env 文件,配置数据库密码和其他选项"
|
||||
echo "2. 如果未初始化数据库,运行: python cli.py init"
|
||||
echo "3. 启动服务: python main.py"
|
||||
echo "4. 浏览器访问: http://localhost:8000"
|
||||
echo "5. 默认管理员: admin / admin123 (首次登录后务必修改密码)"
|
||||
echo "6. 测试功能: python test_core_features.py"
|
||||
echo ""
|
||||
echo "详细文档:"
|
||||
echo "- 升级指南: backend/UPGRADE_GUIDE.md"
|
||||
echo "- 配置说明: backend/ENV_CONFIG.md"
|
||||
echo ""
|
||||
@@ -1,10 +1,4 @@
|
||||
"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行。
|
||||
|
||||
功能:
|
||||
1. 连板股追踪器
|
||||
2. 炸板率统计
|
||||
3. 涨停敢死队排行
|
||||
"""
|
||||
"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行、炸板走势统计、涨停原因分类。"""
|
||||
import datetime as dt
|
||||
from typing import List, Dict, Any, Optional
|
||||
from collections import defaultdict, Counter
|
||||
@@ -12,7 +6,24 @@ import numpy as np
|
||||
from sqlalchemy import select, and_, func, desc
|
||||
|
||||
from db import get_session
|
||||
from models import DailyQuote, StockMetric
|
||||
from models import DailyQuote, StockMetric, DragonTiger
|
||||
|
||||
try:
|
||||
import akshare as ak
|
||||
AK_OK = True
|
||||
except Exception:
|
||||
ak = None
|
||||
AK_OK = False
|
||||
|
||||
# 涨停原因关键词分类
|
||||
LIMIT_REASON_MAP = {
|
||||
"题材": ["概念", "题材", "热点", "风口", "赛道"],
|
||||
"业绩": ["业绩", "净利润", "营收", "盈利", "超预期", "预增", "扭亏"],
|
||||
"政策": ["政策", "补贴", "利好", "支持", "规划", "国家", "工信部", "发改委"],
|
||||
"技术突破": ["突破", "新高", "均线", "金叉", "放量"],
|
||||
"重组并购": ["重组", "并购", "收购", "合并", "入股"],
|
||||
"情绪": ["跟风", "连板", "情绪", "氛围", "涨停潮"],
|
||||
}
|
||||
|
||||
|
||||
def get_limit_stocks(date: Optional[dt.date] = None, limit_type: str = "up") -> Dict[str, Any]:
|
||||
@@ -293,6 +304,214 @@ def analyze_limit_break_rate(days: int = 60) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_consecutive_calendar(days: int = 60) -> Dict[str, Any]:
|
||||
"""连板日历:记录每只股票的连板历史,分析几进几出规律"""
|
||||
with get_session() as s:
|
||||
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
|
||||
if not latest_date:
|
||||
return {"ok": False, "msg": "暂无数据"}
|
||||
start_date = latest_date - dt.timedelta(days=days)
|
||||
quotes = s.execute(
|
||||
select(DailyQuote)
|
||||
.where(DailyQuote.date >= start_date)
|
||||
.order_by(DailyQuote.code, DailyQuote.date)
|
||||
).scalars().all()
|
||||
|
||||
stock_data = defaultdict(list)
|
||||
for q in quotes:
|
||||
stock_data[q.code].append(q)
|
||||
|
||||
all_streaks = []
|
||||
current_streaks = []
|
||||
|
||||
for code, data in stock_data.items():
|
||||
data_sorted = sorted(data, key=lambda x: x.date)
|
||||
name = data_sorted[-1].name
|
||||
streaks = []
|
||||
current = []
|
||||
|
||||
for q in data_sorted:
|
||||
if q.open == 0:
|
||||
if len(current) >= 2:
|
||||
streaks.append(current)
|
||||
current = []
|
||||
continue
|
||||
pct = (float(q.close) - float(q.open)) / float(q.open) * 100
|
||||
if pct >= 9.8:
|
||||
current.append(q)
|
||||
else:
|
||||
if len(current) >= 2:
|
||||
streaks.append(current)
|
||||
current = []
|
||||
|
||||
if len(current) >= 2:
|
||||
current_streaks.append({
|
||||
"code": code, "name": name,
|
||||
"days": len(current),
|
||||
"start_date": current[0].date.isoformat(),
|
||||
"latest_date": current[-1].date.isoformat(),
|
||||
"latest_close": float(current[-1].close)
|
||||
})
|
||||
for streak in streaks:
|
||||
all_streaks.append({
|
||||
"code": code, "name": name,
|
||||
"days": len(streak),
|
||||
"start_date": streak[0].date.isoformat(),
|
||||
"end_date": streak[-1].date.isoformat()
|
||||
})
|
||||
|
||||
distribution = defaultdict(int)
|
||||
for item in all_streaks:
|
||||
distribution[f"{item['days']}板"] += 1
|
||||
|
||||
current_streaks.sort(key=lambda x: x["days"], reverse=True)
|
||||
return {
|
||||
"ok": True,
|
||||
"date_range": f"{start_date.isoformat()} ~ {latest_date.isoformat()}",
|
||||
"current_streaks": current_streaks[:30],
|
||||
"streak_distribution": dict(distribution),
|
||||
"total_streaks": len(all_streaks)
|
||||
}
|
||||
|
||||
|
||||
def analyze_post_break_performance(days: int = 90) -> Dict[str, Any]:
|
||||
"""炸板后走势统计:炸板后 1/3/5 日表现概率分布"""
|
||||
with get_session() as s:
|
||||
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
|
||||
if not latest_date:
|
||||
return {"ok": False, "msg": "暂无数据"}
|
||||
|
||||
start_date = latest_date - dt.timedelta(days=days + 10)
|
||||
quotes = s.execute(
|
||||
select(DailyQuote)
|
||||
.where(DailyQuote.date >= start_date)
|
||||
.order_by(DailyQuote.code, DailyQuote.date)
|
||||
).scalars().all()
|
||||
|
||||
stock_data = defaultdict(dict)
|
||||
for q in quotes:
|
||||
stock_data[q.code][q.date] = q
|
||||
|
||||
# 炸板 = 当日一度涨停但收盘时未涨停(用开高低收近似判断)
|
||||
# 简化:前一日涨停,次日收盘未涨停视为炸板
|
||||
perf_1d, perf_3d, perf_5d = [], [], []
|
||||
|
||||
for code, date_data in stock_data.items():
|
||||
dates = sorted(date_data.keys())
|
||||
for i in range(len(dates) - 5):
|
||||
today = dates[i]
|
||||
q_today = date_data[today]
|
||||
if q_today.open == 0:
|
||||
continue
|
||||
pct_today = (float(q_today.close) - float(q_today.open)) / float(q_today.open) * 100
|
||||
# 判断昨日涨停今日炸板(开高但收盘低)
|
||||
if i == 0:
|
||||
continue
|
||||
yesterday = dates[i - 1]
|
||||
q_yest = date_data[yesterday]
|
||||
if q_yest.open == 0:
|
||||
continue
|
||||
pct_yest = (float(q_yest.close) - float(q_yest.open)) / float(q_yest.open) * 100
|
||||
|
||||
# 昨日涨停,今日未涨停(炸板)
|
||||
if pct_yest >= 9.8 and pct_today < 9.8:
|
||||
base = float(q_today.close)
|
||||
# 后续 1/3/5 日表现
|
||||
for horizon, perf_list in [(1, perf_1d), (3, perf_3d), (5, perf_5d)]:
|
||||
if i + horizon < len(dates):
|
||||
future = dates[i + horizon]
|
||||
q_future = date_data[future]
|
||||
ret = (float(q_future.close) - base) / base * 100
|
||||
perf_list.append(round(ret, 2))
|
||||
|
||||
def summarize(perfs):
|
||||
if not perfs:
|
||||
return {}
|
||||
arr = np.array(perfs)
|
||||
return {
|
||||
"samples": len(perfs),
|
||||
"avg_ret": round(float(arr.mean()), 2),
|
||||
"win_rate": round(float((arr > 0).mean() * 100), 1),
|
||||
"p25": round(float(np.percentile(arr, 25)), 2),
|
||||
"median": round(float(np.median(arr)), 2),
|
||||
"p75": round(float(np.percentile(arr, 75)), 2),
|
||||
}
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"days": days,
|
||||
"after_1d": summarize(perf_1d),
|
||||
"after_3d": summarize(perf_3d),
|
||||
"after_5d": summarize(perf_5d),
|
||||
"conclusion": (
|
||||
f"炸板后样本 {len(perf_1d)} 条,"
|
||||
f"次日平均收益 {summarize(perf_1d).get('avg_ret', 0)}%,"
|
||||
f"次日上涨概率 {summarize(perf_1d).get('win_rate', 0)}%"
|
||||
) if perf_1d else "样本不足"
|
||||
}
|
||||
|
||||
|
||||
def classify_limit_reasons(date: Optional[dt.date] = None) -> Dict[str, Any]:
|
||||
"""涨停原因分类:情绪、题材、业绩、技术突破等"""
|
||||
with get_session() as s:
|
||||
if date is None:
|
||||
date = s.execute(select(func.max(DragonTiger.date))).scalar()
|
||||
if not date:
|
||||
return {"ok": False, "msg": "暂无龙虎榜数据,请先入库"}
|
||||
|
||||
lhb_rows = s.execute(
|
||||
select(DragonTiger).where(DragonTiger.date == date)
|
||||
).scalars().all()
|
||||
|
||||
# 尝试从 AkShare 获取当日涨停原因
|
||||
reason_data = {}
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_zt_pool_em(date=date.strftime("%Y%m%d"))
|
||||
if df is not None and not df.empty:
|
||||
for _, r in df.iterrows():
|
||||
code = str(r.get("代码", ""))
|
||||
reason = str(r.get("涨停原因类别", "") or r.get("上榜原因", ""))
|
||||
reason_data[code] = reason
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 合并龙虎榜原因
|
||||
for row in lhb_rows:
|
||||
if row.code not in reason_data:
|
||||
reason_data[row.code] = row.reason
|
||||
|
||||
# 分类
|
||||
classified = defaultdict(list)
|
||||
for code, reason in reason_data.items():
|
||||
matched = False
|
||||
for category, keywords in LIMIT_REASON_MAP.items():
|
||||
if any(kw in reason for kw in keywords):
|
||||
classified[category].append({"code": code, "reason": reason})
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
classified["其他"].append({"code": code, "reason": reason})
|
||||
|
||||
total = sum(len(v) for v in classified.values())
|
||||
summary = [
|
||||
{
|
||||
"category": cat,
|
||||
"count": len(items),
|
||||
"pct": round(len(items) / total * 100, 1) if total > 0 else 0,
|
||||
"stocks": items[:10]
|
||||
}
|
||||
for cat, items in sorted(classified.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
]
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"date": date.isoformat(),
|
||||
"total": total,
|
||||
"categories": summary
|
||||
}
|
||||
|
||||
|
||||
def get_limit_squad_rankings(days: int = 30, min_limits: int = 5) -> Dict[str, Any]:
|
||||
"""涨停敢死队排行
|
||||
|
||||
|
||||
518
backend/main.py
518
backend/main.py
@@ -9,14 +9,30 @@ import datetime as dt
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi import FastAPI, Query, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import select, func, desc
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import akshare_service as svc
|
||||
import redis_cache
|
||||
from redis_cache import cache
|
||||
import auth
|
||||
from auth import get_current_user, require_auth, require_admin
|
||||
import init_auth
|
||||
import exceptions
|
||||
from exceptions import (
|
||||
BusinessException,
|
||||
DataSourceException,
|
||||
business_exception_handler,
|
||||
validation_exception_handler,
|
||||
sqlalchemy_exception_handler,
|
||||
general_exception_handler
|
||||
)
|
||||
import config
|
||||
import scheduler
|
||||
import backtest as bt
|
||||
@@ -37,6 +53,11 @@ import sentiment_monitor as sentiment
|
||||
import event_driven as events
|
||||
import financial_analysis as fin
|
||||
import limit_analysis as limit_up
|
||||
import watchlist_manager as wl
|
||||
import position_cost as pc
|
||||
import trade_calendar as cal
|
||||
import data_manager as dm
|
||||
import paper_trading as paper
|
||||
from db import init_db, get_session
|
||||
from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
|
||||
SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade,
|
||||
@@ -47,8 +68,11 @@ from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
init_db()
|
||||
init_auth.init_default_admin()
|
||||
wl.init_default_groups()
|
||||
paper.ensure_default_account()
|
||||
scheduler.start_scheduler()
|
||||
print("[startup] db + scheduler ready")
|
||||
print("[startup] db + scheduler + auth ready")
|
||||
except Exception as e:
|
||||
print("[startup] WARN:", repr(e)[:160])
|
||||
yield
|
||||
@@ -56,6 +80,12 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
# 注册异常处理器
|
||||
app.add_exception_handler(BusinessException, business_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler)
|
||||
app.add_exception_handler(Exception, general_exception_handler)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -84,10 +114,113 @@ def save_watch(symbols):
|
||||
json.dump(symbols, f, ensure_ascii=False)
|
||||
|
||||
|
||||
# ============ 认证相关 API ============
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@app.post("/api/auth/login")
|
||||
def login(req: LoginRequest, db = Depends(get_session)):
|
||||
"""用户登录"""
|
||||
user = auth.authenticate_user(db, req.username, req.password)
|
||||
if not user:
|
||||
raise exceptions.AuthException("用户名或密码错误")
|
||||
|
||||
access_token = auth.create_access_token(data={"sub": user.username})
|
||||
return {
|
||||
"ok": True,
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"is_admin": user.is_admin
|
||||
}
|
||||
|
||||
@app.get("/api/auth/me")
|
||||
def get_me(current_user = Depends(require_auth)):
|
||||
"""获取当前用户信息"""
|
||||
return {
|
||||
"ok": True,
|
||||
"username": current_user.username,
|
||||
"is_admin": current_user.is_admin
|
||||
}
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
|
||||
@app.post("/api/auth/change-password")
|
||||
def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)):
|
||||
"""修改密码"""
|
||||
if not auth.verify_password(req.old_password, current_user.hashed_password):
|
||||
raise exceptions.AuthException("原密码错误")
|
||||
|
||||
current_user.hashed_password = auth.get_password_hash(req.new_password)
|
||||
db.commit()
|
||||
return {"ok": True, "msg": "密码修改成功"}
|
||||
|
||||
# ============ 用户管理 ============
|
||||
class CreateUserRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
is_admin: bool = False
|
||||
|
||||
from models import User as UserModel
|
||||
|
||||
@app.get("/api/users")
|
||||
def list_users(current_user = Depends(require_admin)):
|
||||
with get_session() as s:
|
||||
rows = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all()
|
||||
return {"ok": True, "users": [{"id": r.id, "username": r.username,
|
||||
"is_admin": r.is_admin, "is_active": r.is_active,
|
||||
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]}
|
||||
|
||||
@app.post("/api/users")
|
||||
def create_user(req: CreateUserRequest, current_user = Depends(require_admin)):
|
||||
with get_session() as s:
|
||||
if s.execute(select(UserModel).where(UserModel.username == req.username)).scalar_one_or_none():
|
||||
return {"ok": False, "msg": "用户名已存在"}
|
||||
user = UserModel(username=req.username,
|
||||
hashed_password=auth.get_password_hash(req.password), is_admin=req.is_admin)
|
||||
s.add(user); s.commit()
|
||||
return {"ok": True, "id": user.id}
|
||||
|
||||
@app.delete("/api/users/{uid}")
|
||||
def delete_user(uid: int, current_user = Depends(require_admin)):
|
||||
if current_user.id == uid:
|
||||
return {"ok": False, "msg": "不能删除自己"}
|
||||
with get_session() as s:
|
||||
u = s.get(UserModel, uid)
|
||||
if u: s.delete(u); s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@app.put("/api/users/{uid}/toggle_admin")
|
||||
def toggle_admin(uid: int, current_user = Depends(require_admin)):
|
||||
if current_user.id == uid:
|
||||
return {"ok": False, "msg": "不能修改自己的权限"}
|
||||
with get_session() as s:
|
||||
u = s.get(UserModel, uid)
|
||||
if not u: return {"ok": False, "msg": "用户不存在"}
|
||||
u.is_admin = not u.is_admin; s.commit()
|
||||
return {"ok": True, "is_admin": u.is_admin}
|
||||
|
||||
@app.put("/api/users/{uid}/reset_password")
|
||||
def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)):
|
||||
with get_session() as s:
|
||||
u = s.get(UserModel, uid)
|
||||
if not u: return {"ok": False, "msg": "用户不存在"}
|
||||
u.hashed_password = auth.get_password_hash(req.new_password); s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ============ API ============
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {"ok": True, "akshare": svc.AK_OK}
|
||||
return {
|
||||
"ok": True,
|
||||
"akshare": svc.AK_OK,
|
||||
"redis": cache.enabled,
|
||||
"auth": True
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/indices")
|
||||
@@ -106,10 +239,76 @@ def sentiment():
|
||||
|
||||
|
||||
@app.get("/api/treemap")
|
||||
def treemap(mode: str = Query("sector")):
|
||||
def treemap(mode: str = Query("sector"), date: str = Query(None)):
|
||||
if mode == "sector" and date:
|
||||
# 从数据库读历史板块数据
|
||||
try:
|
||||
target = dt.date.fromisoformat(date)
|
||||
except Exception:
|
||||
return svc.get_treemap(mode)
|
||||
with get_session() as s:
|
||||
rows = s.execute(
|
||||
select(SectorDaily).where(SectorDaily.date == target)
|
||||
.order_by(SectorDaily.pct.desc())
|
||||
).scalars().all()
|
||||
if not rows:
|
||||
# 找最近有数据的日期
|
||||
latest = s.execute(select(func.max(SectorDaily.date))).scalar()
|
||||
if latest:
|
||||
rows = s.execute(
|
||||
select(SectorDaily).where(SectorDaily.date == latest)
|
||||
.order_by(SectorDaily.pct.desc())
|
||||
).scalars().all()
|
||||
target = latest
|
||||
if rows:
|
||||
items = [{"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)} for r in rows]
|
||||
return {"source": "db", "mode": "sector", "date": target.isoformat(), "items": items}
|
||||
return svc.get_treemap(mode)
|
||||
|
||||
|
||||
@app.get("/api/treemap/us")
|
||||
def treemap_us():
|
||||
return svc.get_us_treemap()
|
||||
|
||||
|
||||
@app.get("/api/treemap/hk")
|
||||
def treemap_hk():
|
||||
return svc.get_hk_treemap()
|
||||
|
||||
|
||||
@app.get("/api/treemap/sector_stocks")
|
||||
def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)):
|
||||
return svc.get_sector_stocks(name, limit)
|
||||
|
||||
|
||||
@app.get("/api/treemap/all_leaders")
|
||||
def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)):
|
||||
from models import SectorLeader
|
||||
# 优先从数据库读
|
||||
with get_session() as s:
|
||||
target = None
|
||||
if date:
|
||||
try: target = dt.date.fromisoformat(date)
|
||||
except Exception: pass
|
||||
if not target:
|
||||
target = s.execute(select(func.max(SectorLeader.date))).scalar()
|
||||
if target:
|
||||
rows = s.execute(
|
||||
select(SectorLeader).where(SectorLeader.date == target)
|
||||
.order_by(SectorLeader.sector, SectorLeader.rank)
|
||||
).scalars().all()
|
||||
if rows:
|
||||
sectors = {}
|
||||
for r in rows:
|
||||
sectors.setdefault(r.sector, []).append({
|
||||
"code": r.code, "name": r.name, "pct": r.pct,
|
||||
"price": r.price, "amount": r.amount
|
||||
})
|
||||
return {"source": "db", "date": target.isoformat(), "sectors": sectors}
|
||||
# 降级到实时
|
||||
return svc.get_all_sector_leaders(top_n)
|
||||
|
||||
|
||||
@app.get("/api/fundflow")
|
||||
def fundflow():
|
||||
return svc.get_fund_flow()
|
||||
@@ -151,9 +350,105 @@ def watch_del(code: str):
|
||||
return {"ok": True, "list": w}
|
||||
|
||||
|
||||
# ============ 自选股分组管理 ============
|
||||
class CreateGroupRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
color: str = "blue"
|
||||
|
||||
@app.get("/api/watchlist/groups")
|
||||
def list_groups():
|
||||
"""获取所有分组"""
|
||||
return {"ok": True, "groups": wl.get_all_groups()}
|
||||
|
||||
@app.post("/api/watchlist/groups")
|
||||
def create_group(req: CreateGroupRequest):
|
||||
"""创建新分组"""
|
||||
return wl.create_group(req.name, req.description, req.color)
|
||||
|
||||
class UpdateGroupRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
@app.put("/api/watchlist/groups/{group_id}")
|
||||
def update_group(group_id: int, req: UpdateGroupRequest):
|
||||
"""更新分组信息"""
|
||||
return wl.update_group(group_id, req.name, req.description, req.color)
|
||||
|
||||
@app.delete("/api/watchlist/groups/{group_id}")
|
||||
def delete_group(group_id: int):
|
||||
"""删除分组"""
|
||||
return wl.delete_group(group_id)
|
||||
|
||||
class ReorderGroupsRequest(BaseModel):
|
||||
group_ids: List[int]
|
||||
|
||||
@app.post("/api/watchlist/groups/reorder")
|
||||
def reorder_groups(req: ReorderGroupsRequest):
|
||||
"""重新排序分组"""
|
||||
return wl.reorder_groups(req.group_ids)
|
||||
|
||||
@app.get("/api/watchlist/groups/{group_id}/stocks")
|
||||
def get_group_stocks(group_id: int, with_quotes: bool = Query(True)):
|
||||
"""获取分组内的股票"""
|
||||
return wl.get_group_stocks(group_id, with_quotes)
|
||||
|
||||
class AddStockRequest(BaseModel):
|
||||
code: str
|
||||
note: str = ""
|
||||
|
||||
@app.post("/api/watchlist/groups/{group_id}/stocks")
|
||||
def add_stock_to_group(group_id: int, req: AddStockRequest):
|
||||
"""添加股票到分组"""
|
||||
return wl.add_stock_to_group(group_id, req.code, req.note)
|
||||
|
||||
@app.delete("/api/watchlist/stocks/{item_id}")
|
||||
def remove_stock(item_id: int):
|
||||
"""从分组中移除股票"""
|
||||
return wl.remove_stock_from_group(item_id)
|
||||
|
||||
class MoveStockRequest(BaseModel):
|
||||
target_group_id: int
|
||||
|
||||
@app.post("/api/watchlist/stocks/{item_id}/move")
|
||||
def move_stock(item_id: int, req: MoveStockRequest):
|
||||
"""移动股票到另一个分组"""
|
||||
return wl.move_stock_to_group(item_id, req.target_group_id)
|
||||
|
||||
class BatchAddRequest(BaseModel):
|
||||
codes: List[str]
|
||||
|
||||
@app.post("/api/watchlist/groups/{group_id}/stocks/batch")
|
||||
def batch_add_stocks(group_id: int, req: BatchAddRequest):
|
||||
"""批量添加股票"""
|
||||
return wl.batch_add_stocks(group_id, req.codes)
|
||||
|
||||
class UpdateNoteRequest(BaseModel):
|
||||
note: str
|
||||
|
||||
@app.put("/api/watchlist/stocks/{item_id}/note")
|
||||
def update_stock_note(item_id: int, req: UpdateNoteRequest):
|
||||
"""更新股票备注"""
|
||||
return wl.update_stock_note(item_id, req.note)
|
||||
|
||||
class ReorderStocksRequest(BaseModel):
|
||||
item_ids: List[int]
|
||||
|
||||
@app.post("/api/watchlist/stocks/reorder")
|
||||
def reorder_stocks(req: ReorderStocksRequest):
|
||||
"""重新排序股票"""
|
||||
return wl.reorder_stocks(req.item_ids)
|
||||
|
||||
@app.get("/api/watchlist/search")
|
||||
def search_stocks(keyword: str = Query(..., min_length=1)):
|
||||
"""跨分组搜索股票"""
|
||||
return {"ok": True, "results": wl.search_stocks_across_groups(keyword)}
|
||||
|
||||
|
||||
# ============ 数据中台 ============
|
||||
@app.get("/api/admin/status")
|
||||
def admin_status():
|
||||
def admin_status(current_user = Depends(require_admin)):
|
||||
counts, last_dates = {}, {}
|
||||
with get_session() as s:
|
||||
for label, model in [("securities", Security), ("quotes_daily", DailyQuote),
|
||||
@@ -175,14 +470,14 @@ def admin_status():
|
||||
|
||||
|
||||
@app.post("/api/admin/ingest")
|
||||
def admin_ingest():
|
||||
def admin_ingest(current_user = Depends(require_admin)):
|
||||
if scheduler.is_running():
|
||||
return {"started": False, "msg": "已有入库任务在执行"}
|
||||
return scheduler.trigger_async()
|
||||
|
||||
|
||||
@app.post("/api/admin/ingest_all")
|
||||
def admin_ingest_all():
|
||||
def admin_ingest_all(current_user = Depends(require_admin)):
|
||||
return scheduler.trigger_all_async()
|
||||
|
||||
|
||||
@@ -488,6 +783,16 @@ def ai_today():
|
||||
return ai.today_strategy()
|
||||
|
||||
|
||||
@app.get("/api/ai/trend_analysis")
|
||||
def ai_trend_analysis(
|
||||
symbol: str = Query(...),
|
||||
date: str = Query(""),
|
||||
period: str = Query("daily")
|
||||
):
|
||||
"""走势分析:右键K线条形时调用,分析暴涨/暴跌原因"""
|
||||
return ai.trend_analysis(symbol, date, period)
|
||||
|
||||
|
||||
# ============ 可回溯:信号历史胜率 + 实测准确率 ============
|
||||
@app.get("/api/ai/signal_stats")
|
||||
def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)):
|
||||
@@ -582,6 +887,144 @@ def portfolio_equity():
|
||||
return pf.equity_curve()
|
||||
|
||||
|
||||
# ============ 持仓成本可视化增强 ============
|
||||
@app.get("/api/portfolio/cost_line/{code}")
|
||||
def get_cost_line(code: str):
|
||||
"""获取个股持仓成本线(用于K线图标注)"""
|
||||
return pc.get_position_cost_lines(code)
|
||||
|
||||
@app.get("/api/portfolio/cost_distribution")
|
||||
def get_cost_distribution():
|
||||
"""获取持仓成本分布(盈亏区间图)"""
|
||||
return pc.get_position_cost_distribution()
|
||||
|
||||
class EstimateCostRequest(BaseModel):
|
||||
code: str
|
||||
price: float
|
||||
qty: int
|
||||
side: str = "buy"
|
||||
|
||||
@app.post("/api/portfolio/estimate_cost")
|
||||
def estimate_cost(req: EstimateCostRequest):
|
||||
"""估算交易成本(下单前预估)"""
|
||||
return pc.estimate_trade_cost(req.code, req.price, req.qty, req.side)
|
||||
|
||||
@app.get("/api/portfolio/cost_breakdown/{code}")
|
||||
def get_cost_breakdown(code: str):
|
||||
"""获取持仓的详细成本拆解"""
|
||||
return pc.get_cost_breakdown_for_position(code)
|
||||
|
||||
|
||||
# ============ 交易日历与关键事件 ============
|
||||
@app.get("/api/calendar/events")
|
||||
def calendar_events(days: int = Query(30, ge=7, le=90)):
|
||||
"""获取所有即将到来的关键事件(综合视图)"""
|
||||
return cal.get_all_upcoming_events(days)
|
||||
|
||||
@app.get("/api/calendar/dividends")
|
||||
def calendar_dividends(days: int = Query(30, ge=7, le=90)):
|
||||
"""除权除息日历(持仓股优先)"""
|
||||
return cal.get_upcoming_dividends(days)
|
||||
|
||||
@app.get("/api/calendar/unlock")
|
||||
def calendar_unlock(days: int = Query(90, ge=7, le=180)):
|
||||
"""限售解禁日历"""
|
||||
return cal.get_unlock_calendar(days)
|
||||
|
||||
@app.get("/api/calendar/earnings")
|
||||
def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)):
|
||||
"""财报披露日历"""
|
||||
return cal.get_earnings_calendar(days, holding_only)
|
||||
|
||||
@app.post("/api/calendar/check_alerts")
|
||||
def calendar_check_alerts(current_user = Depends(require_admin)):
|
||||
"""手动触发日历事件预警推送"""
|
||||
return cal.check_and_push_calendar_alerts()
|
||||
|
||||
|
||||
# ============ 数据修正与回填增强 ============
|
||||
class UpdateQuoteRequest(BaseModel):
|
||||
open: Optional[float] = None
|
||||
high: Optional[float] = None
|
||||
low: Optional[float] = None
|
||||
close: Optional[float] = None
|
||||
volume: Optional[int] = None
|
||||
amount: Optional[float] = None
|
||||
|
||||
@app.delete("/api/data/quote/{code}/{date}")
|
||||
def delete_quote(code: str, date: str, current_user = Depends(require_admin)):
|
||||
"""删除指定股票指定日期的日线"""
|
||||
return dm.delete_quote(code, date)
|
||||
|
||||
@app.put("/api/data/quote/{code}/{date}")
|
||||
def update_quote(code: str, date: str, req: UpdateQuoteRequest,
|
||||
current_user = Depends(require_admin)):
|
||||
"""修正指定日线数据"""
|
||||
return dm.update_quote(code, date, req.model_dump(exclude_none=True))
|
||||
|
||||
class DeleteRangeRequest(BaseModel):
|
||||
start: str
|
||||
end: str
|
||||
|
||||
@app.delete("/api/data/quotes/{code}/range")
|
||||
def delete_quotes_range(code: str, req: DeleteRangeRequest,
|
||||
current_user = Depends(require_admin)):
|
||||
"""删除指定股票日期范围内的日线数据"""
|
||||
return dm.delete_quotes_range(code, req.start, req.end)
|
||||
|
||||
@app.post("/api/data/refetch/{code}")
|
||||
def refetch_quote(code: str, days: int = Query(60, ge=5, le=500),
|
||||
current_user = Depends(require_admin)):
|
||||
"""重新抓取指定股票日线(覆盖更新)"""
|
||||
return dm.refetch_quote(code, days)
|
||||
|
||||
@app.get("/api/data/integrity")
|
||||
def check_integrity(days: int = Query(30, ge=7, le=90),
|
||||
current_user = Depends(require_admin)):
|
||||
"""数据完整性检查"""
|
||||
return dm.check_data_integrity(days=days)
|
||||
|
||||
@app.post("/api/data/auto_fix")
|
||||
def auto_fix_missing(limit: int = Query(50, ge=10, le=200),
|
||||
current_user = Depends(require_admin)):
|
||||
"""自动补齐缺失数据"""
|
||||
t = __import__("threading").Thread(
|
||||
target=dm.auto_fix_missing, kwargs={"limit": limit}, daemon=True
|
||||
)
|
||||
t.start()
|
||||
return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"}
|
||||
|
||||
@app.get("/api/data/refill_progress")
|
||||
def refill_progress(task_id: str = Query("default")):
|
||||
"""获取回填进度"""
|
||||
return dm.get_refill_progress(task_id)
|
||||
|
||||
@app.post("/api/data/refill_resume")
|
||||
def refill_resume(days: int = Query(250, ge=30, le=1000),
|
||||
task_id: str = Query("default"),
|
||||
current_user = Depends(require_admin)):
|
||||
"""带断点续传的全市场回填(后台执行)"""
|
||||
import threading
|
||||
t = threading.Thread(
|
||||
target=dm.start_refill_with_resume,
|
||||
kwargs={"days": days, "task_id": task_id},
|
||||
daemon=True
|
||||
)
|
||||
t.start()
|
||||
return {"ok": True, "msg": f"已启动断点续传回填,天数={days},任务ID={task_id}"}
|
||||
|
||||
@app.delete("/api/data/refill_progress")
|
||||
def clear_refill_progress(task_id: str = Query("default"),
|
||||
current_user = Depends(require_admin)):
|
||||
"""清除回填进度(从头开始)"""
|
||||
return dm.clear_refill_progress(task_id)
|
||||
|
||||
@app.get("/api/data/quality_report")
|
||||
def data_quality_report(current_user = Depends(require_admin)):
|
||||
"""数据质量报告"""
|
||||
return dm.get_data_quality_report()
|
||||
|
||||
|
||||
@app.get("/api/portfolio/attribution")
|
||||
def portfolio_attribution():
|
||||
"""持仓归因分析"""
|
||||
@@ -761,6 +1204,22 @@ def limit_squad(days: int = Query(30, ge=10, le=90), min_limits: int = Query(5,
|
||||
"""涨停敢死队排行"""
|
||||
return limit_up.get_limit_squad_rankings(days, min_limits)
|
||||
|
||||
@app.get("/api/limit/consecutive_calendar")
|
||||
def consecutive_calendar(days: int = Query(60, ge=20, le=120)):
|
||||
"""连板日历:记录连板历史,分析几进几出规律"""
|
||||
return limit_up.get_consecutive_calendar(days)
|
||||
|
||||
@app.get("/api/limit/post_break")
|
||||
def post_break_performance(days: int = Query(90, ge=30, le=180)):
|
||||
"""炸板后 1/3/5 日走势统计"""
|
||||
return limit_up.analyze_post_break_performance(days)
|
||||
|
||||
@app.get("/api/limit/reasons")
|
||||
def limit_reasons(date: Optional[str] = None):
|
||||
"""涨停原因分类(情绪/题材/业绩/政策等)"""
|
||||
d = dt.date.fromisoformat(date) if date else None
|
||||
return limit_up.classify_limit_reasons(d)
|
||||
|
||||
|
||||
# ============ 推送通知 ============
|
||||
@app.get("/api/notify/status")
|
||||
@@ -1142,6 +1601,51 @@ def delete_selector_alert(aid: int):
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ============ 模拟盘 ============
|
||||
|
||||
class PaperAccountIn(BaseModel):
|
||||
name: str
|
||||
initial_cash: float = 1_000_000.0
|
||||
|
||||
|
||||
@app.get("/api/paper/accounts")
|
||||
def paper_list_accounts():
|
||||
return {"ok": True, "accounts": paper.list_accounts()}
|
||||
|
||||
|
||||
@app.post("/api/paper/accounts")
|
||||
def paper_create_account(req: PaperAccountIn):
|
||||
return paper.create_account(req.name, req.initial_cash)
|
||||
|
||||
|
||||
@app.post("/api/paper/accounts/{account_id}/reset")
|
||||
def paper_reset_account(account_id: int, initial_cash: Optional[float] = None):
|
||||
return paper.reset_account(account_id, initial_cash)
|
||||
|
||||
|
||||
class PaperOrderIn(BaseModel):
|
||||
code: str
|
||||
side: str # buy / sell
|
||||
qty: int
|
||||
price: Optional[float] = None
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@app.post("/api/paper/accounts/{account_id}/order")
|
||||
def paper_place_order(account_id: int, req: PaperOrderIn):
|
||||
return paper.place_order(account_id, req.code, req.side, req.qty, req.price, req.reason)
|
||||
|
||||
|
||||
@app.get("/api/paper/accounts/{account_id}/portfolio")
|
||||
def paper_get_portfolio(account_id: int):
|
||||
return paper.get_portfolio(account_id)
|
||||
|
||||
|
||||
@app.get("/api/paper/accounts/{account_id}/trades")
|
||||
def paper_get_trades(account_id: int, limit: int = Query(100, le=500)):
|
||||
return {"ok": True, "trades": paper.get_trades(account_id, limit)}
|
||||
|
||||
|
||||
# ============ 静态前端 ============
|
||||
FRONTEND_DIR = os.path.join(os.path.dirname(BASE_DIR), "prototype")
|
||||
if os.path.isdir(FRONTEND_DIR):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""数据中台 ORM 模型(SQLAlchemy 2.0)。"""
|
||||
"""数据中台 ORM 模型(SQLAlchemy 2.0)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
@@ -64,6 +64,21 @@ class SectorDaily(Base):
|
||||
leader: Mapped[str] = mapped_column(String(40), default="")
|
||||
|
||||
|
||||
class SectorLeader(Base):
|
||||
"""板块每日龙头股(前5,按成交额)。"""
|
||||
__tablename__ = "sector_leaders"
|
||||
__table_args__ = (UniqueConstraint("date", "sector", "code", name="uq_sector_leader"),)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
date: Mapped[dt.date] = mapped_column(Date, index=True)
|
||||
sector: Mapped[str] = mapped_column(String(40), index=True)
|
||||
code: Mapped[str] = mapped_column(String(12))
|
||||
name: Mapped[str] = mapped_column(String(40))
|
||||
pct: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
price: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
amount: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
rank: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
||||
class FundFlowDaily(Base):
|
||||
"""行业资金流每日快照。"""
|
||||
__tablename__ = "fund_flow_daily"
|
||||
@@ -349,3 +364,68 @@ class IntradayEvent(Base):
|
||||
description: Mapped[str] = mapped_column(String(200), default="")
|
||||
detected_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now(), index=True)
|
||||
notified: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
|
||||
class PaperAccount(Base):
|
||||
"""模拟盘账户。"""
|
||||
__tablename__ = "paper_accounts"
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(50), default="默认模拟盘")
|
||||
initial_cash: Mapped[float] = mapped_column(Float, default=1_000_000.0)
|
||||
cash: Mapped[float] = mapped_column(Float, default=1_000_000.0)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class PaperTrade(Base):
|
||||
"""模拟盘交易记录。"""
|
||||
__tablename__ = "paper_trades"
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
account_id: Mapped[int] = mapped_column(Integer, index=True, default=1)
|
||||
date: Mapped[dt.date] = mapped_column(Date, index=True)
|
||||
code: Mapped[str] = mapped_column(String(12), index=True)
|
||||
name: Mapped[str] = mapped_column(String(40), default="")
|
||||
side: Mapped[str] = mapped_column(String(4)) # buy / sell
|
||||
price: Mapped[float] = mapped_column(Float)
|
||||
qty: Mapped[int] = mapped_column(Integer)
|
||||
fee: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
cash_before: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
cash_after: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
reason: Mapped[str] = mapped_column(String(60), default="")
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户表(用于鉴权)。"""
|
||||
__tablename__ = "users"
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(100))
|
||||
is_admin: Mapped[bool] = mapped_column(default=False)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class WatchlistGroup(Base):
|
||||
"""自选股分组。"""
|
||||
__tablename__ = "watchlist_groups"
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
description: Mapped[str] = mapped_column(String(200), default="")
|
||||
color: Mapped[str] = mapped_column(String(20), default="blue") # 分组颜色标识
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
is_default: Mapped[bool] = mapped_column(default=False)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class WatchlistItem(Base):
|
||||
"""自选股项目。"""
|
||||
__tablename__ = "watchlist_items"
|
||||
__table_args__ = (UniqueConstraint("group_id", "code", name="uq_watchlist_group_code"),)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(Integer, index=True)
|
||||
code: Mapped[str] = mapped_column(String(12), index=True)
|
||||
name: Mapped[str] = mapped_column(String(40), default="")
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
note: Mapped[str] = mapped_column(String(200), default="") # 个股备注
|
||||
added_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
222
backend/paper_trading.py
Normal file
222
backend/paper_trading.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""模拟盘核心逻辑。
|
||||
|
||||
- 多账户支持(默认账户 id=1)
|
||||
- 买卖按实时价(或收盘价)撮合,自动扣减/增加现金
|
||||
- 持仓计算使用移动加权平均成本法
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from db import get_session
|
||||
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote
|
||||
|
||||
DEFAULT_FEE_RATE = 0.0003
|
||||
|
||||
|
||||
def _get_price(code: str) -> float | None:
|
||||
with get_session() as s:
|
||||
m = s.execute(select(StockMetric.close).where(StockMetric.code == code)).scalar_one_or_none()
|
||||
if m:
|
||||
return float(m)
|
||||
row = s.execute(
|
||||
select(DailyQuote.close).where(DailyQuote.code == code)
|
||||
.order_by(DailyQuote.date.desc()).limit(1)
|
||||
).scalar_one_or_none()
|
||||
return float(row) if row else None
|
||||
|
||||
|
||||
# ── 账户管理 ──────────────────────────────────────────────
|
||||
|
||||
def ensure_default_account():
|
||||
"""确保默认账户(id=1)存在,启动时调用。"""
|
||||
with get_session() as s:
|
||||
if not s.get(PaperAccount, 1):
|
||||
s.add(PaperAccount(name="默认模拟盘", initial_cash=1_000_000.0, cash=1_000_000.0))
|
||||
s.commit()
|
||||
|
||||
|
||||
def list_accounts() -> list[dict]:
|
||||
with get_session() as s:
|
||||
rows = s.execute(select(PaperAccount).order_by(PaperAccount.id)).scalars().all()
|
||||
return [{"id": r.id, "name": r.name, "initial_cash": r.initial_cash,
|
||||
"cash": round(r.cash, 2), "is_active": r.is_active,
|
||||
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]
|
||||
|
||||
|
||||
def create_account(name: str, initial_cash: float) -> dict:
|
||||
with get_session() as s:
|
||||
acc = PaperAccount(name=name, initial_cash=initial_cash, cash=initial_cash)
|
||||
s.add(acc)
|
||||
s.commit()
|
||||
return {"ok": True, "id": acc.id}
|
||||
|
||||
|
||||
def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
|
||||
with get_session() as s:
|
||||
acc = s.get(PaperAccount, account_id)
|
||||
if not acc:
|
||||
return {"ok": False, "msg": "账户不存在"}
|
||||
if initial_cash is not None:
|
||||
acc.initial_cash = initial_cash
|
||||
acc.cash = acc.initial_cash
|
||||
for t in s.execute(
|
||||
select(PaperTrade).where(PaperTrade.account_id == account_id)
|
||||
).scalars():
|
||||
s.delete(t)
|
||||
s.commit()
|
||||
return {"ok": True, "msg": "账户已重置"}
|
||||
|
||||
|
||||
# ── 持仓计算(内部)────────────────────────────────────────
|
||||
|
||||
def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
|
||||
trades = s.execute(
|
||||
select(PaperTrade).where(PaperTrade.account_id == account_id)
|
||||
.order_by(PaperTrade.date, PaperTrade.id)
|
||||
).scalars().all()
|
||||
pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
|
||||
for t in trades:
|
||||
p = pos[t.code]
|
||||
p["name"] = t.name or p["name"]
|
||||
if t.side == "buy":
|
||||
p["cost"] += t.price * t.qty + t.fee
|
||||
p["qty"] += t.qty
|
||||
else:
|
||||
if p["qty"] > 0:
|
||||
avg = p["cost"] / p["qty"]
|
||||
qty = min(t.qty, p["qty"])
|
||||
p["cost"] -= avg * qty
|
||||
p["qty"] -= qty
|
||||
return [{"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]}
|
||||
for c, v in pos.items() if v["qty"] > 0]
|
||||
|
||||
|
||||
# ── 下单 ──────────────────────────────────────────────────
|
||||
|
||||
def place_order(account_id: int, code: str, side: str, qty: int,
|
||||
price: float | None = None, reason: str = "") -> dict:
|
||||
if qty <= 0:
|
||||
return {"ok": False, "msg": "数量必须大于 0"}
|
||||
if side not in ("buy", "sell"):
|
||||
return {"ok": False, "msg": "side 只能是 buy 或 sell"}
|
||||
|
||||
exec_price = price or _get_price(code)
|
||||
if not exec_price:
|
||||
return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
|
||||
|
||||
fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)
|
||||
|
||||
with get_session() as s:
|
||||
acc = s.get(PaperAccount, account_id)
|
||||
if not acc:
|
||||
return {"ok": False, "msg": "账户不存在"}
|
||||
|
||||
sec = s.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
|
||||
if side == "buy":
|
||||
cost = exec_price * qty + fee
|
||||
if acc.cash < cost:
|
||||
return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
|
||||
cash_before = acc.cash
|
||||
acc.cash -= cost
|
||||
else:
|
||||
holdings = _calc_holdings_in_session(account_id, s)
|
||||
pos = next((h for h in holdings if h["code"] == code), None)
|
||||
avail = pos["qty"] if pos else 0
|
||||
if avail < qty:
|
||||
return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty} 股"}
|
||||
cash_before = acc.cash
|
||||
acc.cash += exec_price * qty - fee
|
||||
|
||||
trade = PaperTrade(
|
||||
account_id=account_id,
|
||||
date=dt.date.today(),
|
||||
code=code, name=name, side=side,
|
||||
price=exec_price, qty=qty, fee=fee,
|
||||
cash_before=cash_before, cash_after=acc.cash,
|
||||
reason=reason,
|
||||
)
|
||||
s.add(trade)
|
||||
s.commit()
|
||||
return {"ok": True, "id": trade.id, "price": exec_price,
|
||||
"fee": fee, "cash_after": round(acc.cash, 2)}
|
||||
|
||||
|
||||
# ── 查询接口 ──────────────────────────────────────────────
|
||||
|
||||
def get_portfolio(account_id: int) -> dict:
|
||||
with get_session() as s:
|
||||
acc = s.get(PaperAccount, account_id)
|
||||
if not acc:
|
||||
return {"ok": False, "msg": "账户不存在"}
|
||||
cash = acc.cash
|
||||
initial = acc.initial_cash
|
||||
holdings_raw = _calc_holdings_in_session(account_id, s)
|
||||
|
||||
codes = [h["code"] for h in holdings_raw]
|
||||
px: dict[str, float] = {}
|
||||
if codes:
|
||||
with get_session() as s:
|
||||
for m in s.execute(
|
||||
select(StockMetric).where(StockMetric.code.in_(codes))
|
||||
).scalars():
|
||||
px[m.code] = m.close
|
||||
for c in [c for c in codes if c not in px]:
|
||||
row = s.execute(
|
||||
select(DailyQuote.close).where(DailyQuote.code == c)
|
||||
.order_by(DailyQuote.date.desc()).limit(1)
|
||||
).scalar_one_or_none()
|
||||
if row:
|
||||
px[c] = float(row)
|
||||
|
||||
holdings, mkt_val = [], 0.0
|
||||
for h in holdings_raw:
|
||||
avg = h["cost"] / h["qty"] if h["qty"] else 0.0
|
||||
cur = px.get(h["code"], avg)
|
||||
mv = cur * h["qty"]
|
||||
unreal = (cur - avg) * h["qty"]
|
||||
mkt_val += mv
|
||||
holdings.append({
|
||||
"code": h["code"], "name": h["name"], "qty": h["qty"],
|
||||
"avg_cost": round(avg, 3), "cur": round(cur, 3),
|
||||
"market_value": round(mv, 2),
|
||||
"unrealized": round(unreal, 2),
|
||||
"unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
|
||||
})
|
||||
holdings.sort(key=lambda x: x["unrealized"], reverse=True)
|
||||
|
||||
total_assets = cash + mkt_val
|
||||
total_pnl = total_assets - initial
|
||||
return {
|
||||
"ok": True,
|
||||
"account_id": account_id,
|
||||
"summary": {
|
||||
"initial_cash": round(initial, 2),
|
||||
"cash": round(cash, 2),
|
||||
"market_value": round(mkt_val, 2),
|
||||
"total_assets": round(total_assets, 2),
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
|
||||
"positions": len(holdings),
|
||||
},
|
||||
"holdings": holdings,
|
||||
}
|
||||
|
||||
|
||||
def get_trades(account_id: int, limit: int = 100) -> list[dict]:
|
||||
with get_session() as s:
|
||||
rows = s.execute(
|
||||
select(PaperTrade).where(PaperTrade.account_id == account_id)
|
||||
.order_by(PaperTrade.id.desc()).limit(limit)
|
||||
).scalars().all()
|
||||
return [{
|
||||
"id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name,
|
||||
"side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee,
|
||||
"cash_before": round(t.cash_before, 2), "cash_after": round(t.cash_after, 2),
|
||||
"reason": t.reason,
|
||||
} for t in rows]
|
||||
347
backend/position_cost.py
Normal file
347
backend/position_cost.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""持仓成本可视化增强"""
|
||||
import datetime as dt
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, func
|
||||
from db import get_session
|
||||
from models import Trade, DailyQuote, StockMetric
|
||||
|
||||
# A股交易成本配置
|
||||
COST_CONFIG = {
|
||||
"stamp_tax": 0.001, # 印花税 0.1%(仅卖出)
|
||||
"commission_rate": 0.0003, # 佣金费率 0.03%
|
||||
"commission_min": 5.0, # 最低佣金 5元
|
||||
"transfer_fee": 0.00001, # 过户费 0.001%(沪市)
|
||||
}
|
||||
|
||||
def calculate_trade_cost(price: float, qty: int, side: str, is_sh: bool = True) -> Dict:
|
||||
"""精确计算交易成本
|
||||
|
||||
Args:
|
||||
price: 成交价格
|
||||
qty: 成交数量
|
||||
side: buy/sell
|
||||
is_sh: 是否沪市(影响过户费)
|
||||
|
||||
Returns:
|
||||
成本明细字典
|
||||
"""
|
||||
amount = price * qty
|
||||
|
||||
# 佣金(买卖都有)
|
||||
commission = max(amount * COST_CONFIG["commission_rate"], COST_CONFIG["commission_min"])
|
||||
|
||||
# 印花税(仅卖出)
|
||||
stamp_tax = amount * COST_CONFIG["stamp_tax"] if side == "sell" else 0.0
|
||||
|
||||
# 过户费(沪市买卖都有,深市无)
|
||||
transfer_fee = amount * COST_CONFIG["transfer_fee"] if is_sh else 0.0
|
||||
|
||||
total_cost = commission + stamp_tax + transfer_fee
|
||||
|
||||
return {
|
||||
"amount": round(amount, 2),
|
||||
"commission": round(commission, 2),
|
||||
"stamp_tax": round(stamp_tax, 2),
|
||||
"transfer_fee": round(transfer_fee, 2),
|
||||
"total_cost": round(total_cost, 2),
|
||||
"cost_rate": round(total_cost / amount * 100, 4) if amount > 0 else 0.0
|
||||
}
|
||||
|
||||
def get_position_cost_lines(code: str) -> Dict:
|
||||
"""获取个股的持仓成本线数据(用于K线图标注)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"code": "600519",
|
||||
"name": "贵州茅台",
|
||||
"current_position": {
|
||||
"qty": 100,
|
||||
"avg_cost": 1680.5,
|
||||
"total_cost": 168050.0,
|
||||
"trades_count": 3
|
||||
},
|
||||
"cost_history": [
|
||||
{"date": "2024-01-15", "cost": 1650.0, "qty": 100, "action": "买入"},
|
||||
{"date": "2024-02-10", "cost": 1680.5, "qty": 100, "action": "补仓"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
with get_session() as s:
|
||||
trades = s.execute(
|
||||
select(Trade).where(Trade.code == code)
|
||||
.order_by(Trade.date, Trade.id)
|
||||
).scalars().all()
|
||||
|
||||
if not trades:
|
||||
return {"ok": False, "msg": "该股票无交易记录"}
|
||||
|
||||
# 计算持仓成本变化
|
||||
qty = 0
|
||||
cost = 0.0
|
||||
cost_history = []
|
||||
|
||||
for t in trades:
|
||||
is_sh = t.code.startswith("6")
|
||||
|
||||
if t.side == "buy":
|
||||
# 买入:加权平均成本
|
||||
old_qty = qty
|
||||
old_cost = cost
|
||||
|
||||
qty += t.qty
|
||||
cost += t.price * t.qty + t.fee
|
||||
|
||||
avg_cost = cost / qty if qty > 0 else 0
|
||||
action = "补仓" if old_qty > 0 else "买入"
|
||||
|
||||
cost_history.append({
|
||||
"date": t.date.isoformat(),
|
||||
"cost": round(avg_cost, 2),
|
||||
"qty": qty,
|
||||
"action": action,
|
||||
"trade_price": t.price,
|
||||
"trade_qty": t.qty
|
||||
})
|
||||
|
||||
else: # sell
|
||||
if qty <= 0:
|
||||
continue
|
||||
|
||||
avg_cost = cost / qty
|
||||
sell_qty = min(t.qty, qty)
|
||||
|
||||
# 卖出:减少持仓
|
||||
cost -= avg_cost * sell_qty
|
||||
qty -= sell_qty
|
||||
|
||||
action = "清仓" if qty == 0 else "减仓"
|
||||
|
||||
cost_history.append({
|
||||
"date": t.date.isoformat(),
|
||||
"cost": round(cost / qty, 2) if qty > 0 else 0,
|
||||
"qty": qty,
|
||||
"action": action,
|
||||
"trade_price": t.price,
|
||||
"trade_qty": sell_qty,
|
||||
"pnl": round((t.price - avg_cost) * sell_qty - t.fee, 2)
|
||||
})
|
||||
|
||||
# 当前持仓
|
||||
current_position = None
|
||||
if qty > 0:
|
||||
avg_cost = cost / qty
|
||||
|
||||
# 获取当前价格
|
||||
metric = s.execute(
|
||||
select(StockMetric).where(StockMetric.code == code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
current_price = metric.close if metric else avg_cost
|
||||
|
||||
current_position = {
|
||||
"qty": qty,
|
||||
"avg_cost": round(avg_cost, 2),
|
||||
"total_cost": round(cost, 2),
|
||||
"current_price": round(current_price, 2),
|
||||
"market_value": round(current_price * qty, 2),
|
||||
"unrealized_pnl": round((current_price - avg_cost) * qty, 2),
|
||||
"unrealized_pct": round((current_price / avg_cost - 1) * 100, 2) if avg_cost > 0 else 0,
|
||||
"trades_count": len([t for t in trades if t.side == "buy"])
|
||||
}
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"code": code,
|
||||
"name": trades[0].name,
|
||||
"current_position": current_position,
|
||||
"cost_history": cost_history
|
||||
}
|
||||
|
||||
def get_position_cost_distribution() -> Dict:
|
||||
"""获取所有持仓的成本分布(盈亏区间图)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"profitable": [...], # 盈利持仓
|
||||
"unprofitable": [...], # 亏损持仓
|
||||
"breakeven": [...] # 持平持仓
|
||||
}
|
||||
"""
|
||||
with get_session() as s:
|
||||
trades = s.execute(
|
||||
select(Trade).order_by(Trade.date, Trade.id)
|
||||
).scalars().all()
|
||||
|
||||
# 计算当前持仓
|
||||
pos = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
|
||||
|
||||
for t in trades:
|
||||
p = pos[t.code]
|
||||
p["name"] = t.name or p["name"]
|
||||
|
||||
if t.side == "buy":
|
||||
p["cost"] += t.price * t.qty + t.fee
|
||||
p["qty"] += t.qty
|
||||
else:
|
||||
if p["qty"] > 0:
|
||||
avg = p["cost"] / p["qty"]
|
||||
qty = min(t.qty, p["qty"])
|
||||
p["cost"] -= avg * qty
|
||||
p["qty"] -= qty
|
||||
|
||||
# 获取当前价格
|
||||
codes = [c for c, v in pos.items() if v["qty"] > 0]
|
||||
if not codes:
|
||||
return {"ok": True, "profitable": [], "unprofitable": [], "breakeven": []}
|
||||
|
||||
metrics = s.execute(
|
||||
select(StockMetric).where(StockMetric.code.in_(codes))
|
||||
).scalars().all()
|
||||
|
||||
price_map = {m.code: m.close for m in metrics}
|
||||
|
||||
# 分类统计
|
||||
profitable = []
|
||||
unprofitable = []
|
||||
breakeven = []
|
||||
|
||||
for code, p in pos.items():
|
||||
if p["qty"] <= 0:
|
||||
continue
|
||||
|
||||
avg_cost = p["cost"] / p["qty"]
|
||||
current_price = price_map.get(code, avg_cost)
|
||||
unrealized = (current_price - avg_cost) * p["qty"]
|
||||
unrealized_pct = (current_price / avg_cost - 1) * 100 if avg_cost > 0 else 0
|
||||
|
||||
item = {
|
||||
"code": code,
|
||||
"name": p["name"],
|
||||
"qty": p["qty"],
|
||||
"avg_cost": round(avg_cost, 2),
|
||||
"current_price": round(current_price, 2),
|
||||
"market_value": round(current_price * p["qty"], 2),
|
||||
"cost_value": round(p["cost"], 2),
|
||||
"unrealized": round(unrealized, 2),
|
||||
"unrealized_pct": round(unrealized_pct, 2)
|
||||
}
|
||||
|
||||
if unrealized_pct > 0.5:
|
||||
profitable.append(item)
|
||||
elif unrealized_pct < -0.5:
|
||||
unprofitable.append(item)
|
||||
else:
|
||||
breakeven.append(item)
|
||||
|
||||
# 排序
|
||||
profitable.sort(key=lambda x: x["unrealized"], reverse=True)
|
||||
unprofitable.sort(key=lambda x: x["unrealized"])
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"profitable": profitable,
|
||||
"unprofitable": unprofitable,
|
||||
"breakeven": breakeven,
|
||||
"summary": {
|
||||
"total_positions": len(codes),
|
||||
"profitable_count": len(profitable),
|
||||
"unprofitable_count": len(unprofitable),
|
||||
"breakeven_count": len(breakeven),
|
||||
"win_rate": round(len(profitable) / len(codes) * 100, 1) if codes else 0
|
||||
}
|
||||
}
|
||||
|
||||
def estimate_trade_cost(code: str, price: float, qty: int, side: str) -> Dict:
|
||||
"""估算交易成本(下单前预估)
|
||||
|
||||
Args:
|
||||
code: 股票代码
|
||||
price: 预计成交价
|
||||
qty: 交易数量
|
||||
side: buy/sell
|
||||
|
||||
Returns:
|
||||
成本明细和净值
|
||||
"""
|
||||
is_sh = code.startswith("6")
|
||||
cost_detail = calculate_trade_cost(price, qty, side, is_sh)
|
||||
|
||||
if side == "buy":
|
||||
net_amount = cost_detail["amount"] + cost_detail["total_cost"]
|
||||
msg = f"买入需支付: {round(net_amount, 2)} 元(含交易成本 {cost_detail['total_cost']} 元)"
|
||||
else:
|
||||
net_amount = cost_detail["amount"] - cost_detail["total_cost"]
|
||||
msg = f"卖出可获得: {round(net_amount, 2)} 元(扣除交易成本 {cost_detail['total_cost']} 元)"
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"code": code,
|
||||
"price": price,
|
||||
"qty": qty,
|
||||
"side": side,
|
||||
"cost_detail": cost_detail,
|
||||
"net_amount": round(net_amount, 2),
|
||||
"message": msg
|
||||
}
|
||||
|
||||
def get_cost_breakdown_for_position(code: str) -> Dict:
|
||||
"""获取持仓的详细成本拆解
|
||||
|
||||
Returns:
|
||||
{
|
||||
"total_cost": 168500.0,
|
||||
"purchase_amount": 168050.0, # 实际买入金额
|
||||
"commission": 350.0, # 累计佣金
|
||||
"stamp_tax": 0.0, # 累计印花税(买入无)
|
||||
"transfer_fee": 100.0, # 累计过户费
|
||||
"trades": [...] # 每笔交易明细
|
||||
}
|
||||
"""
|
||||
with get_session() as s:
|
||||
trades = s.execute(
|
||||
select(Trade).where(Trade.code == code, Trade.side == "buy")
|
||||
.order_by(Trade.date)
|
||||
).scalars().all()
|
||||
|
||||
if not trades:
|
||||
return {"ok": False, "msg": "该股票无买入记录"}
|
||||
|
||||
is_sh = code.startswith("6")
|
||||
|
||||
total_purchase = 0.0
|
||||
total_commission = 0.0
|
||||
total_stamp = 0.0
|
||||
total_transfer = 0.0
|
||||
trade_details = []
|
||||
|
||||
for t in trades:
|
||||
cost = calculate_trade_cost(t.price, t.qty, "buy", is_sh)
|
||||
|
||||
total_purchase += cost["amount"]
|
||||
total_commission += cost["commission"]
|
||||
total_stamp += cost["stamp_tax"]
|
||||
total_transfer += cost["transfer_fee"]
|
||||
|
||||
trade_details.append({
|
||||
"date": t.date.isoformat(),
|
||||
"price": t.price,
|
||||
"qty": t.qty,
|
||||
"amount": cost["amount"],
|
||||
"cost_detail": cost
|
||||
})
|
||||
|
||||
total_cost = total_purchase + total_commission + total_stamp + total_transfer
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"code": code,
|
||||
"name": trades[0].name,
|
||||
"total_cost": round(total_cost, 2),
|
||||
"purchase_amount": round(total_purchase, 2),
|
||||
"commission": round(total_commission, 2),
|
||||
"stamp_tax": round(total_stamp, 2),
|
||||
"transfer_fee": round(total_transfer, 2),
|
||||
"cost_rate": round((total_cost - total_purchase) / total_purchase * 100, 4) if total_purchase > 0 else 0,
|
||||
"trades": trade_details
|
||||
}
|
||||
88
backend/redis_cache.py
Normal file
88
backend/redis_cache.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Redis 缓存层,替代内存缓存,支持持久化和跨进程共享。"""
|
||||
import json
|
||||
import redis
|
||||
from typing import Any, Optional
|
||||
from config import REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self):
|
||||
self.client: Optional[redis.Redis] = None
|
||||
self.enabled = False
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
"""连接 Redis,失败时降级为禁用状态"""
|
||||
try:
|
||||
self.client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD if REDIS_PASSWORD else None,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=2,
|
||||
socket_timeout=2
|
||||
)
|
||||
self.client.ping()
|
||||
self.enabled = True
|
||||
print(f"✓ Redis 已连接: {REDIS_HOST}:{REDIS_PORT}")
|
||||
except Exception as e:
|
||||
self.enabled = False
|
||||
print(f"✗ Redis 连接失败,缓存已禁用: {e}")
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""获取缓存,自动反序列化 JSON"""
|
||||
if not self.enabled:
|
||||
return None
|
||||
try:
|
||||
value = self.client.get(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Redis get error: {e}")
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Any, expire: int = 3600):
|
||||
"""设置缓存,自动序列化为 JSON
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值(可序列化为JSON的对象)
|
||||
expire: 过期时间(秒),默认1小时
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
try:
|
||||
serialized = json.dumps(value, ensure_ascii=False, default=str)
|
||||
self.client.setex(key, expire, serialized)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Redis set error: {e}")
|
||||
return False
|
||||
|
||||
def delete(self, key: str):
|
||||
"""删除缓存"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
try:
|
||||
self.client.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Redis delete error: {e}")
|
||||
return False
|
||||
|
||||
def clear_pattern(self, pattern: str):
|
||||
"""批量删除匹配模式的键"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
try:
|
||||
keys = self.client.keys(pattern)
|
||||
if keys:
|
||||
return self.client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"Redis clear_pattern error: {e}")
|
||||
return 0
|
||||
|
||||
# 全局单例
|
||||
cache = RedisCache()
|
||||
@@ -8,3 +8,7 @@ APScheduler>=3.10.4
|
||||
psycopg2-binary>=2.9.9
|
||||
jieba>=0.42.1
|
||||
numpy>=1.26.0
|
||||
redis>=5.0.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
passlib[bcrypt]>=1.7.4
|
||||
python-multipart>=0.0.9
|
||||
|
||||
@@ -13,6 +13,7 @@ import alerts
|
||||
import report
|
||||
import signals
|
||||
import intraday_radar
|
||||
import trade_calendar
|
||||
|
||||
_scheduler: BackgroundScheduler | None = None
|
||||
_lock = threading.Lock()
|
||||
@@ -134,11 +135,16 @@ def start_scheduler():
|
||||
_job_signal_stats, CronTrigger(day_of_week="sat", hour=9, minute=0),
|
||||
id="signal_stats", replace_existing=True, misfire_grace_time=7200,
|
||||
)
|
||||
# 盘中异动扫描(交易时间每分钟)
|
||||
# 盘中异动扫描(交易时间每分钟)
|
||||
_scheduler.add_job(
|
||||
_safe_scan_intraday, IntervalTrigger(seconds=60),
|
||||
id="intraday_scan", replace_existing=True, max_instances=1,
|
||||
)
|
||||
# 每日早盘前推送日历事件提醒(持仓股除权、解禁、财报等)
|
||||
_scheduler.add_job(
|
||||
_job_calendar_alerts, CronTrigger(day_of_week="mon-fri", hour=8, minute=30),
|
||||
id="calendar_alerts", replace_existing=True, misfire_grace_time=3600,
|
||||
)
|
||||
_scheduler.start()
|
||||
return _scheduler
|
||||
|
||||
@@ -154,7 +160,13 @@ def _safe_scan_intraday():
|
||||
try:
|
||||
result = intraday_radar.scan_all()
|
||||
if result.get("count", 0) > 0:
|
||||
# 有新异动时自动推送
|
||||
intraday_radar.notify_events()
|
||||
except Exception as e:
|
||||
print("[intraday] scan error:", repr(e)[:120])
|
||||
|
||||
|
||||
def _job_calendar_alerts():
|
||||
try:
|
||||
trade_calendar.check_and_push_calendar_alerts()
|
||||
except Exception as e:
|
||||
print("[calendar] alert error:", repr(e)[:120])
|
||||
|
||||
179
backend/test_core_features.py
Normal file
179
backend/test_core_features.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""测试三大核心功能的脚本"""
|
||||
import requests
|
||||
import time
|
||||
import sys
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
|
||||
def test_health():
|
||||
"""测试健康检查接口"""
|
||||
print("\n=== 1. 测试健康检查 ===")
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api/health")
|
||||
data = resp.json()
|
||||
print(f"状态码: {resp.status_code}")
|
||||
print(f"响应: {data}")
|
||||
print(f"✓ AkShare: {'可用' if data.get('akshare') else '不可用'}")
|
||||
print(f"✓ Redis: {'已连接' if data.get('redis') else '未连接(将使用内存缓存)'}")
|
||||
print(f"✓ 鉴权: {'已启用' if data.get('auth') else '未启用'}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
return False
|
||||
|
||||
def test_cache_performance():
|
||||
"""测试 Redis 缓存性能"""
|
||||
print("\n=== 2. 测试 Redis 缓存性能 ===")
|
||||
try:
|
||||
# 第一次请求(缓存未命中)
|
||||
start = time.time()
|
||||
resp1 = requests.get(f"{BASE_URL}/api/indices")
|
||||
time1 = time.time() - start
|
||||
print(f"第一次请求耗时: {time1:.3f}秒")
|
||||
|
||||
# 第二次请求(应该命中缓存)
|
||||
start = time.time()
|
||||
resp2 = requests.get(f"{BASE_URL}/api/indices")
|
||||
time2 = time.time() - start
|
||||
print(f"第二次请求耗时: {time2:.3f}秒")
|
||||
|
||||
speedup = time1 / time2 if time2 > 0 else 1
|
||||
print(f"性能提升: {speedup:.1f}x")
|
||||
|
||||
if speedup > 2:
|
||||
print("✓ Redis 缓存生效")
|
||||
else:
|
||||
print("⚠ 缓存可能未生效或使用内存缓存")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
return False
|
||||
|
||||
def test_auth_login():
|
||||
"""测试登录功能"""
|
||||
print("\n=== 3. 测试认证系统 - 登录 ===")
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/api/auth/login",
|
||||
json={"username": "admin", "password": "admin123"}
|
||||
)
|
||||
print(f"状态码: {resp.status_code}")
|
||||
data = resp.json()
|
||||
|
||||
if resp.status_code == 200 and data.get("access_token"):
|
||||
print(f"✓ 登录成功")
|
||||
print(f" 用户名: {data.get('username')}")
|
||||
print(f" 管理员: {data.get('is_admin')}")
|
||||
print(f" Token: {data.get('access_token')[:50]}...")
|
||||
return data.get("access_token")
|
||||
else:
|
||||
print(f"✗ 登录失败: {data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
return None
|
||||
|
||||
def test_auth_protected(token):
|
||||
"""测试受保护的接口"""
|
||||
print("\n=== 4. 测试认证系统 - 受保护接口 ===")
|
||||
|
||||
# 测试无 Token 访问
|
||||
print("\n4.1 无 Token 访问管理接口:")
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api/admin/status")
|
||||
print(f"状态码: {resp.status_code}")
|
||||
if resp.status_code == 401:
|
||||
print("✓ 正确拦截未认证请求")
|
||||
else:
|
||||
print(f"⚠ 预期 401,实际 {resp.status_code}")
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
|
||||
# 测试有 Token 访问
|
||||
print("\n4.2 使用 Token 访问管理接口:")
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{BASE_URL}/api/admin/status",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
print(f"状态码: {resp.status_code}")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
print("✓ Token 认证成功")
|
||||
print(f" 数据库记录数: {data.get('counts', {})}")
|
||||
return True
|
||||
else:
|
||||
print(f"✗ 认证失败: {resp.json()}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
return False
|
||||
|
||||
def test_exception_handling():
|
||||
"""测试异常处理"""
|
||||
print("\n=== 5. 测试统一异常处理 ===")
|
||||
|
||||
# 测试参数验证错误
|
||||
print("\n5.1 测试参数验证错误:")
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api/kline?days=invalid")
|
||||
print(f"状态码: {resp.status_code}")
|
||||
data = resp.json()
|
||||
if resp.status_code == 422:
|
||||
print("✓ 参数验证错误处理正确")
|
||||
print(f" 错误信息: {data.get('error')}")
|
||||
else:
|
||||
print(f"⚠ 预期 422,实际 {resp.status_code}")
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
|
||||
# 测试业务逻辑错误
|
||||
print("\n5.2 测试业务逻辑错误:")
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api/backtest?symbol=600519&fast=20&slow=10")
|
||||
print(f"状态码: {resp.status_code}")
|
||||
data = resp.json()
|
||||
if not data.get('ok'):
|
||||
print("✓ 业务错误处理正确")
|
||||
print(f" 错误信息: {data.get('msg')}")
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
print("="*50)
|
||||
print("三大核心功能测试")
|
||||
print("="*50)
|
||||
|
||||
# 检查服务是否运行
|
||||
try:
|
||||
requests.get(f"{BASE_URL}/api/health", timeout=2)
|
||||
except:
|
||||
print(f"\n✗ 无法连接到服务: {BASE_URL}")
|
||||
print("请确保服务已启动: python main.py")
|
||||
sys.exit(1)
|
||||
|
||||
# 运行测试
|
||||
test_health()
|
||||
test_cache_performance()
|
||||
token = test_auth_login()
|
||||
|
||||
if token:
|
||||
test_auth_protected(token)
|
||||
else:
|
||||
print("\n⚠ 跳过受保护接口测试(登录失败)")
|
||||
|
||||
test_exception_handling()
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("测试完成!")
|
||||
print("="*50)
|
||||
print("\n下一步:")
|
||||
print("1. 如果 Redis 显示'未连接',请安装并启动 Redis")
|
||||
print("2. 如果登录失败,请运行: python cli.py init")
|
||||
print("3. 登录成功后,务必修改默认密码")
|
||||
print("4. 生产环境请修改 .env 中的 SECRET_KEY")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
338
backend/trade_calendar.py
Normal file
338
backend/trade_calendar.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""交易日历与关键事件提醒"""
|
||||
import datetime as dt
|
||||
from typing import List, Optional, Dict
|
||||
from sqlalchemy import select, and_
|
||||
from db import get_session
|
||||
from models import Trade, Security, CorporateEvent, AlertEvent
|
||||
import akshare_service as svc
|
||||
|
||||
try:
|
||||
import akshare as ak
|
||||
AK_OK = True
|
||||
except Exception:
|
||||
ak = None
|
||||
AK_OK = False
|
||||
|
||||
|
||||
def get_upcoming_dividends(days_ahead: int = 30) -> Dict:
|
||||
"""获取即将到来的除权除息日"""
|
||||
today = dt.date.today()
|
||||
end = today + dt.timedelta(days=days_ahead)
|
||||
|
||||
# 获取持仓股票代码
|
||||
with get_session() as s:
|
||||
trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all()
|
||||
|
||||
# 计算当前持仓
|
||||
pos = {}
|
||||
for t in trades:
|
||||
if t.code not in pos:
|
||||
pos[t.code] = {"qty": 0, "name": t.name}
|
||||
if t.side == "buy":
|
||||
pos[t.code]["qty"] += t.qty
|
||||
else:
|
||||
pos[t.code]["qty"] = max(0, pos[t.code]["qty"] - t.qty)
|
||||
|
||||
holding_codes = [c for c, v in pos.items() if v["qty"] > 0]
|
||||
|
||||
events = []
|
||||
if AK_OK and holding_codes:
|
||||
try:
|
||||
df = ak.stock_zh_a_dividend()
|
||||
if df is not None and not df.empty:
|
||||
for _, r in df.iterrows():
|
||||
code = str(r.get("代码", ""))
|
||||
if code not in holding_codes:
|
||||
continue
|
||||
ex_date_str = str(r.get("除权除息日", ""))
|
||||
if not ex_date_str or ex_date_str == "nan":
|
||||
continue
|
||||
try:
|
||||
ex_date = dt.date.fromisoformat(ex_date_str[:10])
|
||||
if today <= ex_date <= end:
|
||||
days_left = (ex_date - today).days
|
||||
events.append({
|
||||
"code": code,
|
||||
"name": pos[code]["name"],
|
||||
"event_type": "除权除息",
|
||||
"event_date": ex_date.isoformat(),
|
||||
"days_left": days_left,
|
||||
"detail": f"送股: {r.get('送股', 0)}, 转增: {r.get('转增', 0)}, 派息: {r.get('派息', 0)}",
|
||||
"is_holding": True,
|
||||
"urgency": "high" if days_left <= 3 else "medium"
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 补充从数据库中获取的事件
|
||||
with get_session() as s:
|
||||
db_events = s.execute(
|
||||
select(CorporateEvent).where(
|
||||
and_(
|
||||
CorporateEvent.event_type == "dividend",
|
||||
CorporateEvent.event_date >= today,
|
||||
CorporateEvent.event_date <= end
|
||||
)
|
||||
).order_by(CorporateEvent.event_date)
|
||||
).scalars().all()
|
||||
|
||||
for e in db_events:
|
||||
if any(ev["code"] == e.code for ev in events):
|
||||
continue
|
||||
days_left = (e.event_date - today).days
|
||||
events.append({
|
||||
"code": e.code,
|
||||
"name": e.name,
|
||||
"event_type": "除权除息",
|
||||
"event_date": e.event_date.isoformat(),
|
||||
"days_left": days_left,
|
||||
"detail": e.description,
|
||||
"is_holding": e.code in holding_codes,
|
||||
"urgency": "high" if days_left <= 3 else "medium"
|
||||
})
|
||||
|
||||
events.sort(key=lambda x: x["event_date"])
|
||||
return {"ok": True, "events": events, "count": len(events)}
|
||||
|
||||
|
||||
def get_unlock_calendar(days_ahead: int = 90) -> Dict:
|
||||
"""获取限售解禁日历"""
|
||||
today = dt.date.today()
|
||||
end = today + dt.timedelta(days=days_ahead)
|
||||
|
||||
events = []
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_restricted_release_summary_em()
|
||||
if df is not None and not df.empty:
|
||||
for _, r in df.head(50).iterrows():
|
||||
date_str = str(r.get("解禁日期", ""))
|
||||
if not date_str or date_str == "nan":
|
||||
continue
|
||||
try:
|
||||
unlock_date = dt.date.fromisoformat(date_str[:10])
|
||||
if today <= unlock_date <= end:
|
||||
amount = float(r.get("解禁数量", 0) or 0)
|
||||
market_val = float(r.get("解禁市值", 0) or 0)
|
||||
events.append({
|
||||
"code": str(r.get("代码", "")),
|
||||
"name": str(r.get("名称", "")),
|
||||
"event_type": "限售解禁",
|
||||
"event_date": unlock_date.isoformat(),
|
||||
"days_left": (unlock_date - today).days,
|
||||
"detail": f"解禁市值: {round(market_val/1e8, 2)}亿",
|
||||
"amount_billion": round(market_val / 1e8, 2),
|
||||
"urgency": "high" if market_val >= 10e8 else "medium"
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 从数据库补充
|
||||
with get_session() as s:
|
||||
db_events = s.execute(
|
||||
select(CorporateEvent).where(
|
||||
and_(
|
||||
CorporateEvent.event_type == "unlock",
|
||||
CorporateEvent.event_date >= today,
|
||||
CorporateEvent.event_date <= end
|
||||
)
|
||||
).order_by(CorporateEvent.event_date)
|
||||
).scalars().all()
|
||||
|
||||
for e in db_events:
|
||||
if any(ev["code"] == e.code for ev in events):
|
||||
continue
|
||||
events.append({
|
||||
"code": e.code,
|
||||
"name": e.name,
|
||||
"event_type": "限售解禁",
|
||||
"event_date": e.event_date.isoformat(),
|
||||
"days_left": (e.event_date - today).days,
|
||||
"detail": e.description,
|
||||
"amount_billion": e.amount,
|
||||
"urgency": "high" if e.amount >= 10 else "medium"
|
||||
})
|
||||
|
||||
events.sort(key=lambda x: x["event_date"])
|
||||
return {"ok": True, "events": events, "count": len(events)}
|
||||
|
||||
|
||||
def get_earnings_calendar(days_ahead: int = 30, holding_only: bool = False) -> Dict:
|
||||
"""获取财报披露日历"""
|
||||
today = dt.date.today()
|
||||
end = today + dt.timedelta(days=days_ahead)
|
||||
|
||||
# 获取持仓代码
|
||||
holding_codes = set()
|
||||
if holding_only:
|
||||
with get_session() as s:
|
||||
trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all()
|
||||
pos = {}
|
||||
for t in trades:
|
||||
if t.code not in pos:
|
||||
pos[t.code] = 0
|
||||
pos[t.code] += t.qty if t.side == "buy" else -t.qty
|
||||
holding_codes = {c for c, q in pos.items() if q > 0}
|
||||
|
||||
events = []
|
||||
if AK_OK:
|
||||
try:
|
||||
df = ak.stock_notice_report()
|
||||
if df is not None and not df.empty:
|
||||
for _, r in df.head(100).iterrows():
|
||||
date_str = str(r.get("公告日期", ""))
|
||||
code = str(r.get("代码", ""))
|
||||
if not date_str or date_str == "nan":
|
||||
continue
|
||||
if holding_only and code not in holding_codes:
|
||||
continue
|
||||
try:
|
||||
report_date = dt.date.fromisoformat(date_str[:10])
|
||||
if today <= report_date <= end:
|
||||
events.append({
|
||||
"code": code,
|
||||
"name": str(r.get("名称", "")),
|
||||
"event_type": "财报披露",
|
||||
"event_date": report_date.isoformat(),
|
||||
"days_left": (report_date - today).days,
|
||||
"report_type": str(r.get("公告类型", "")),
|
||||
"is_holding": code in holding_codes,
|
||||
"urgency": "high" if code in holding_codes else "low"
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 从数据库补充
|
||||
with get_session() as s:
|
||||
db_events = s.execute(
|
||||
select(CorporateEvent).where(
|
||||
and_(
|
||||
CorporateEvent.event_type == "earnings",
|
||||
CorporateEvent.event_date >= today,
|
||||
CorporateEvent.event_date <= end
|
||||
)
|
||||
).order_by(CorporateEvent.event_date)
|
||||
).scalars().all()
|
||||
|
||||
for e in db_events:
|
||||
if any(ev["code"] == e.code for ev in events):
|
||||
continue
|
||||
if holding_only and e.code not in holding_codes:
|
||||
continue
|
||||
events.append({
|
||||
"code": e.code,
|
||||
"name": e.name,
|
||||
"event_type": "财报披露",
|
||||
"event_date": e.event_date.isoformat(),
|
||||
"days_left": (e.event_date - today).days,
|
||||
"report_type": e.title,
|
||||
"is_holding": e.code in holding_codes,
|
||||
"urgency": "high" if e.code in holding_codes else "low"
|
||||
})
|
||||
|
||||
events.sort(key=lambda x: x["event_date"])
|
||||
return {"ok": True, "events": events, "count": len(events)}
|
||||
|
||||
|
||||
def get_all_upcoming_events(days_ahead: int = 30) -> Dict:
|
||||
"""获取所有即将到来的关键事件(综合视图)"""
|
||||
today = dt.date.today()
|
||||
all_events = []
|
||||
|
||||
# 合并所有事件
|
||||
for result in [
|
||||
get_upcoming_dividends(days_ahead),
|
||||
get_earnings_calendar(days_ahead),
|
||||
get_unlock_calendar(days_ahead)
|
||||
]:
|
||||
all_events.extend(result.get("events", []))
|
||||
|
||||
# 按日期排序
|
||||
all_events.sort(key=lambda x: x["event_date"])
|
||||
|
||||
# 按日期分组
|
||||
grouped = {}
|
||||
for event in all_events:
|
||||
date = event["event_date"]
|
||||
if date not in grouped:
|
||||
grouped[date] = []
|
||||
grouped[date].append(event)
|
||||
|
||||
# 生成日历视图
|
||||
calendar = []
|
||||
for date_str, events in sorted(grouped.items()):
|
||||
date = dt.date.fromisoformat(date_str)
|
||||
calendar.append({
|
||||
"date": date_str,
|
||||
"weekday": ["周一", "周二", "周三", "周四", "周五", "周六", "周日"][date.weekday()],
|
||||
"days_left": (date - today).days,
|
||||
"events": events,
|
||||
"has_high_urgency": any(e["urgency"] == "high" for e in events),
|
||||
"has_holding": any(e.get("is_holding", False) for e in events)
|
||||
})
|
||||
|
||||
# 紧急事件(3天内)
|
||||
urgent = [e for e in all_events if e.get("days_left", 99) <= 3]
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"calendar": calendar,
|
||||
"urgent": urgent,
|
||||
"total": len(all_events),
|
||||
"summary": {
|
||||
"dividends": len([e for e in all_events if e["event_type"] == "除权除息"]),
|
||||
"earnings": len([e for e in all_events if e["event_type"] == "财报披露"]),
|
||||
"unlocks": len([e for e in all_events if e["event_type"] == "限售解禁"]),
|
||||
"urgent": len(urgent)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def check_and_push_calendar_alerts() -> Dict:
|
||||
"""检查并推送日历事件预警(定时任务调用)"""
|
||||
try:
|
||||
from notifier import notify
|
||||
except Exception:
|
||||
return {"ok": False, "msg": "推送模块不可用"}
|
||||
|
||||
result = get_all_upcoming_events(days_ahead=7)
|
||||
urgent = result.get("urgent", [])
|
||||
|
||||
if not urgent:
|
||||
return {"ok": True, "msg": "无紧急事件", "pushed": 0}
|
||||
|
||||
# 生成推送内容
|
||||
lines = [f"📅 未来7天关键事件提醒({len(urgent)}条)\n"]
|
||||
for event in urgent[:10]: # 最多推送10条
|
||||
urgency_icon = "🔴" if event["urgency"] == "high" else "🟡"
|
||||
holding_icon = "💰" if event.get("is_holding") else ""
|
||||
lines.append(
|
||||
f"{urgency_icon}{holding_icon} {event['event_date']} "
|
||||
f"{event['name']}({event['code']}) "
|
||||
f"{event['event_type']} "
|
||||
f"({event['days_left']}天后)"
|
||||
)
|
||||
|
||||
message = "\n".join(lines)
|
||||
notify("【Blackdata】关键事件提醒", message)
|
||||
|
||||
# 写入站内通知
|
||||
with get_session() as s:
|
||||
for event in urgent[:10]:
|
||||
alert = AlertEvent(
|
||||
rule_id=0,
|
||||
code=event["code"],
|
||||
name=event["name"],
|
||||
message=f"{event['event_type']}: {event.get('detail', '')} ({event['days_left']}天后)",
|
||||
value=event.get("amount_billion", 0)
|
||||
)
|
||||
s.add(alert)
|
||||
s.commit()
|
||||
|
||||
return {"ok": True, "pushed": len(urgent), "msg": f"已推送 {len(urgent)} 条事件提醒"}
|
||||
345
backend/watchlist_manager.py
Normal file
345
backend/watchlist_manager.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""自选股分组管理"""
|
||||
import datetime as dt
|
||||
from typing import List, Dict, Optional
|
||||
from sqlalchemy import select, func
|
||||
from db import get_session
|
||||
from models import WatchlistGroup, WatchlistItem, Security
|
||||
import akshare_service as svc
|
||||
|
||||
# 预设分组
|
||||
DEFAULT_GROUPS = [
|
||||
{"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True},
|
||||
{"name": "观察池", "description": "待观察的潜力股", "color": "blue"},
|
||||
{"name": "持仓股", "description": "当前持仓的股票", "color": "green"},
|
||||
{"name": "概念股", "description": "热门概念板块", "color": "purple"},
|
||||
]
|
||||
|
||||
def init_default_groups():
|
||||
"""初始化默认分组(如果不存在)"""
|
||||
with get_session() as s:
|
||||
count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar()
|
||||
if count == 0:
|
||||
for idx, g in enumerate(DEFAULT_GROUPS):
|
||||
group = WatchlistGroup(
|
||||
name=g["name"],
|
||||
description=g["description"],
|
||||
color=g["color"],
|
||||
is_default=g.get("is_default", False),
|
||||
sort_order=idx
|
||||
)
|
||||
s.add(group)
|
||||
s.commit()
|
||||
print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)} 个")
|
||||
return True
|
||||
|
||||
def get_all_groups() -> List[Dict]:
|
||||
"""获取所有分组"""
|
||||
with get_session() as s:
|
||||
groups = s.execute(
|
||||
select(WatchlistGroup).order_by(WatchlistGroup.sort_order)
|
||||
).scalars().all()
|
||||
|
||||
result = []
|
||||
for g in groups:
|
||||
# 统计分组内股票数量
|
||||
count = s.execute(
|
||||
select(func.count()).select_from(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == g.id)
|
||||
).scalar()
|
||||
|
||||
result.append({
|
||||
"id": g.id,
|
||||
"name": g.name,
|
||||
"description": g.description,
|
||||
"color": g.color,
|
||||
"count": count,
|
||||
"is_default": g.is_default,
|
||||
"sort_order": g.sort_order
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def create_group(name: str, description: str = "", color: str = "blue") -> Dict:
|
||||
"""创建新分组"""
|
||||
with get_session() as s:
|
||||
# 获取当前最大排序号
|
||||
max_order = s.execute(
|
||||
select(func.max(WatchlistGroup.sort_order))
|
||||
).scalar() or 0
|
||||
|
||||
group = WatchlistGroup(
|
||||
name=name,
|
||||
description=description,
|
||||
color=color,
|
||||
sort_order=max_order + 1
|
||||
)
|
||||
s.add(group)
|
||||
s.commit()
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"id": group.id,
|
||||
"name": group.name
|
||||
}
|
||||
|
||||
def update_group(group_id: int, name: Optional[str] = None,
|
||||
description: Optional[str] = None, color: Optional[str] = None) -> Dict:
|
||||
"""更新分组信息"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
if name is not None:
|
||||
group.name = name
|
||||
if description is not None:
|
||||
group.description = description
|
||||
if color is not None:
|
||||
group.color = color
|
||||
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def delete_group(group_id: int) -> Dict:
|
||||
"""删除分组(同时删除分组内的股票)"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
if group.is_default:
|
||||
return {"ok": False, "msg": "默认分组不能删除"}
|
||||
|
||||
# 删除分组内的股票
|
||||
s.execute(
|
||||
WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id)
|
||||
)
|
||||
|
||||
# 删除分组
|
||||
s.delete(group)
|
||||
s.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
def reorder_groups(group_ids: List[int]) -> Dict:
|
||||
"""重新排序分组"""
|
||||
with get_session() as s:
|
||||
for idx, gid in enumerate(group_ids):
|
||||
group = s.get(WatchlistGroup, gid)
|
||||
if group:
|
||||
group.sort_order = idx
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict:
|
||||
"""获取分组内的股票列表"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
items = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id)
|
||||
.order_by(WatchlistItem.sort_order)
|
||||
).scalars().all()
|
||||
|
||||
codes = [item.code for item in items]
|
||||
|
||||
# 获取实时行情
|
||||
stocks = []
|
||||
if with_quotes and codes:
|
||||
quotes_data = svc.get_watchlist(codes)
|
||||
quotes_map = {s["code"]: s for s in quotes_data.get("list", [])}
|
||||
|
||||
for item in items:
|
||||
quote = quotes_map.get(item.code, {})
|
||||
stocks.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name or quote.get("name", ""),
|
||||
"price": quote.get("price", 0),
|
||||
"pct": quote.get("pct", 0),
|
||||
"change": quote.get("change", 0),
|
||||
"amount": quote.get("amount", 0),
|
||||
"note": item.note,
|
||||
"added_at": item.added_at.strftime("%Y-%m-%d")
|
||||
})
|
||||
else:
|
||||
for item in items:
|
||||
stocks.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name,
|
||||
"note": item.note,
|
||||
"added_at": item.added_at.strftime("%Y-%m-%d")
|
||||
})
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"color": group.color
|
||||
},
|
||||
"stocks": stocks
|
||||
}
|
||||
|
||||
def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict:
|
||||
"""添加股票到分组"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
# 检查是否已存在
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
return {"ok": False, "msg": "该股票已在分组中"}
|
||||
|
||||
# 获取股票名称
|
||||
sec = s.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
|
||||
# 获取当前最大排序号
|
||||
max_order = s.execute(
|
||||
select(func.max(WatchlistItem.sort_order))
|
||||
.where(WatchlistItem.group_id == group_id)
|
||||
).scalar() or 0
|
||||
|
||||
item = WatchlistItem(
|
||||
group_id=group_id,
|
||||
code=code,
|
||||
name=name,
|
||||
note=note,
|
||||
sort_order=max_order + 1
|
||||
)
|
||||
s.add(item)
|
||||
s.commit()
|
||||
|
||||
return {"ok": True, "id": item.id}
|
||||
|
||||
def remove_stock_from_group(item_id: int) -> Dict:
|
||||
"""从分组中移除股票"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
s.delete(item)
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def move_stock_to_group(item_id: int, target_group_id: int) -> Dict:
|
||||
"""将股票移动到另一个分组"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
target_group = s.get(WatchlistGroup, target_group_id)
|
||||
if not target_group:
|
||||
return {"ok": False, "msg": "目标分组不存在"}
|
||||
|
||||
# 检查目标分组是否已有该股票
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == target_group_id,
|
||||
WatchlistItem.code == item.code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
return {"ok": False, "msg": "目标分组已有该股票"}
|
||||
|
||||
item.group_id = target_group_id
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def batch_add_stocks(group_id: int, codes: List[str]) -> Dict:
|
||||
"""批量添加股票到分组"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
added = 0
|
||||
skipped = 0
|
||||
|
||||
for code in codes:
|
||||
# 检查是否已存在
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 获取股票名称
|
||||
sec = s.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
|
||||
item = WatchlistItem(
|
||||
group_id=group_id,
|
||||
code=code,
|
||||
name=name,
|
||||
sort_order=added
|
||||
)
|
||||
s.add(item)
|
||||
added += 1
|
||||
|
||||
s.commit()
|
||||
return {"ok": True, "added": added, "skipped": skipped}
|
||||
|
||||
def update_stock_note(item_id: int, note: str) -> Dict:
|
||||
"""更新股票备注"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
item.note = note
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def reorder_stocks(item_ids: List[int]) -> Dict:
|
||||
"""重新排序分组内的股票"""
|
||||
with get_session() as s:
|
||||
for idx, item_id in enumerate(item_ids):
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if item:
|
||||
item.sort_order = idx
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def search_stocks_across_groups(keyword: str) -> List[Dict]:
|
||||
"""跨分组搜索股票"""
|
||||
with get_session() as s:
|
||||
items = s.execute(
|
||||
select(WatchlistItem, WatchlistGroup)
|
||||
.join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id)
|
||||
.where(
|
||||
(WatchlistItem.code.like(f"%{keyword}%")) |
|
||||
(WatchlistItem.name.like(f"%{keyword}%"))
|
||||
)
|
||||
).all()
|
||||
|
||||
result = []
|
||||
for item, group in items:
|
||||
result.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name,
|
||||
"group_id": group.id,
|
||||
"group_name": group.name,
|
||||
"group_color": group.color,
|
||||
"note": item.note
|
||||
})
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user