功能细节优化

This commit is contained in:
2026-06-15 01:26:39 +08:00
parent e524a3589a
commit 964c17c200
33 changed files with 6990 additions and 210 deletions

237
backend/CHECKLIST.md Normal file
View File

@@ -0,0 +1,237 @@
# 启动前检查清单
在首次启动或遇到问题时,请按此清单逐项检查。
## ✅ 环境检查
### 系统服务
```bash
# PostgreSQL 是否运行
sudo service postgresql status
sudo service postgresql start # 如果未运行
# Redis 是否运行
redis-cli ping # 应返回 PONG
sudo service redis-server start # 如果未运行
```
### Python 环境
```bash
# 虚拟环境是否激活
which python # 应显示 .venv/bin/python
source .venv/bin/activate # 如果未激活
# 依赖是否安装完整
pip list | grep redis
pip list | grep jose
pip list | grep passlib
```
## ✅ 配置检查
### .env 文件
```bash
# 检查必要配置项
cat .env | grep PG_PASSWORD
cat .env | grep SECRET_KEY
cat .env | grep REDIS_HOST
# 如果缺少配置
cp .env.example .env # 如果 .env 不存在
nano .env # 编辑配置
```
### 必须配置的项
- [ ] `PG_PASSWORD` - PostgreSQL 密码
- [ ] `SECRET_KEY` - JWT 密钥(生产环境必须修改)
- [ ] `REDIS_HOST` - Redis 地址(默认 localhost
## ✅ 数据库检查
```bash
# 检查数据库是否创建
psql -U postgres -c "\l" | grep stock_cs
# 检查用户表是否存在
psql -U postgres -d stock_cs -c "\dt" | grep users
# 如果数据库或表不存在
python cli.py init
```
## ✅ 权限检查
```bash
# 测试数据库连接
psql -U postgres -d stock_cs -c "SELECT 1;"
# 测试 Redis 连接
redis-cli ping
# 如果提示权限错误
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
```
## ✅ 启动服务
```bash
# 完整启动流程
sudo service postgresql start
sudo service redis-server start
cd backend
source .venv/bin/activate
python main.py
```
### 启动成功标志
看到以下日志表示启动成功:
```
✓ Redis 已连接: localhost:6379
✓ 管理员账号已存在: admin
[startup] db + scheduler + auth ready
INFO: Uvicorn running on http://0.0.0.0:8000
```
## ✅ 功能测试
```bash
# 1. 健康检查
curl http://localhost:8000/api/health
# 应返回: {"ok":true,"akshare":true,"redis":true,"auth":true}
# 2. 登录测试
curl -X POST http://localhost:8000/api/auth/login \
-H "Content-Type: application/json" \
-d '{"username":"admin","password":"admin123"}'
# 应返回 Token
# 3. 运行完整测试
python test_core_features.py
```
## 🔧 常见问题排查
### 问题 1: Redis 连接失败
```
✗ Redis 连接失败,缓存已禁用
```
**解决**:
```bash
sudo service redis-server start
redis-cli ping
```
### 问题 2: 数据库连接失败
```
connection refused
```
**解决**:
```bash
sudo service postgresql start
psql -U postgres -c "SELECT 1;"
```
### 问题 3: 密码认证失败
```
password authentication failed
```
**解决**:
```bash
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
# 然后在 .env 中设置相同的密码
```
### 问题 4: 模块未找到
```
ModuleNotFoundError: No module named 'redis'
```
**解决**:
```bash
source .venv/bin/activate
pip install -r requirements.txt
```
### 问题 5: 401 Unauthorized
```
{"detail":"未认证,请先登录"}
```
**解决**:
```bash
# 先登录获取 Token
curl -X POST http://localhost:8000/api/auth/login \
-H "Content-Type: application/json" \
-d '{"username":"admin","password":"admin123"}'
# 使用 Token 访问
curl -X GET http://localhost:8000/api/admin/status \
-H "Authorization: Bearer <返回的token>"
```
### 问题 6: 用户表不存在
```
relation "users" does not exist
```
**解决**:
```bash
python cli.py init
```
## 📝 首次部署完整流程
```bash
# 1. 系统依赖
sudo apt update
sudo apt install -y postgresql redis-server
# 2. 启动服务
sudo service postgresql start
sudo service redis-server start
# 3. 配置数据库密码
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'your_password';"
# 4. Python 环境
cd backend
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
# 5. 配置环境变量
cp .env.example .env
nano .env # 编辑 PG_PASSWORD, SECRET_KEY 等
# 6. 初始化数据库
python cli.py init
# 7. 启动服务
python main.py
# 8. 测试
python test_core_features.py
```
## 🚀 一键启动脚本WSL
创建 `start.sh`:
```bash
#!/bin/bash
sudo service postgresql start
sudo service redis-server start
cd /mnt/e/project/stock_cs_v1/backend # 修改为实际路径
source .venv/bin/activate
python main.py
```
使用:
```bash
chmod +x start.sh
./start.sh
```
## 📞 获取帮助
如果以上方法都无法解决问题:
1. 查看详细文档: `backend/UPGRADE_GUIDE.md`
2. 查看配置说明: `backend/ENV_CONFIG.md`
3. 查看实现总结: `三大核心功能实现总结.md`

118
backend/ENV_CONFIG.md Normal file
View File

@@ -0,0 +1,118 @@
# 环境变量配置说明
## 新增配置项(三大核心功能)
### Redis 缓存配置
```bash
# Redis 连接配置
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_DB=0
REDIS_PASSWORD= # 如果 Redis 设置了密码则填写
```
### 认证系统配置
```bash
# JWT 密钥(生产环境务必修改)
SECRET_KEY=your-secret-key-change-in-production
# Token 过期时间(分钟)
ACCESS_TOKEN_EXPIRE_MINUTES=10080 # 默认7天
# API Key 模式(可选,用于外部调用,逗号分隔多个)
API_KEYS=your-api-key-1,your-api-key-2
# 默认管理员账号(首次启动时创建)
DEFAULT_ADMIN_USERNAME=admin
DEFAULT_ADMIN_PASSWORD=admin123 # 首次启动后务必修改密码
```
## 安装 RedisWSL 环境)
```bash
# 安装 Redis
sudo apt update
sudo apt install -y redis-server
# 启动 Redis
sudo service redis-server start
# 验证 Redis 是否运行
redis-cli ping
# 应返回 PONG
# 设置 Redis 开机自启(可选)
sudo systemctl enable redis-server
```
## 安全建议
1. **SECRET_KEY**: 生产环境必须使用强随机字符串,可用以下命令生成:
```bash
python -c "import secrets; print(secrets.token_urlsafe(32))"
```
2. **默认密码**: 首次登录后立即修改 admin 密码
3. **API_KEYS**: 仅在需要外部调用时配置,妥善保管
4. **Redis 密码**: 生产环境建议为 Redis 设置密码:
```bash
# 编辑 Redis 配置
sudo nano /etc/redis/redis.conf
# 找到 requirepass 行,取消注释并设置密码
requirepass your-strong-password
# 重启 Redis
sudo service redis-server restart
```
## 功能说明
### 1. Redis 缓存
- 替代内存缓存,支持持久化和跨进程共享
- 自动降级Redis 不可用时使用内存缓存
- 默认过期时间根据数据类型自动设置行情数据1分钟基本面数据1天
### 2. 统一鉴权
- **JWT Token 模式**: 用户登录获取 Token适合前端应用
- **API Key 模式**: 用于外部系统调用,配置在 HTTP Header `X-API-Key`
- 管理接口(`/api/admin/*`)需要管理员权限
### 3. 统一异常处理
- 业务异常返回友好错误信息
- 数据源异常自动降级
- 数据库异常统一处理
- 所有异常记录日志便于排查
## API 使用示例
### 登录
```bash
curl -X POST http://localhost:8000/api/auth/login \
-H "Content-Type: application/json" \
-d '{"username":"admin","password":"admin123"}'
```
### 使用 Token 访问受保护接口
```bash
curl -X GET http://localhost:8000/api/admin/status \
-H "Authorization: Bearer YOUR_TOKEN_HERE"
```
### 使用 API Key 访问
```bash
curl -X GET http://localhost:8000/api/admin/status \
-H "X-API-Key: your-api-key-1"
```
## 注意事项
1. WSL 环境下,每次重启后需要手动启动服务:
```bash
sudo service postgresql start
sudo service redis-server start
```
2. Redis 连接失败不会影响系统运行,会自动降级到内存缓存
3. 未配置鉴权时,所有接口默认不需要认证(开发模式)

287
backend/UPGRADE_GUIDE.md Normal file
View File

@@ -0,0 +1,287 @@
# 三大核心功能升级指南
本次升级新增了三个核心功能Redis 缓存层、统一鉴权机制、统一异常处理中间件。
## 1. 安装新依赖
```bash
cd backend
source .venv/bin/activate # 激活虚拟环境
pip install -r requirements.txt
```
新增的依赖包:
- `redis>=5.0.0` - Redis 客户端
- `python-jose[cryptography]>=3.3.0` - JWT Token 生成和验证
- `passlib[bcrypt]>=1.7.4` - 密码哈希
- `python-multipart>=0.0.9` - 表单数据解析
## 2. 安装和启动 RedisWSL
```bash
# 安装 Redis
sudo apt update
sudo apt install -y redis-server
# 启动 Redis
sudo service redis-server start
# 验证 Redis 是否运行
redis-cli ping
# 应返回 PONG
```
## 3. 配置环境变量
编辑 `backend/.env` 文件,添加以下配置:
```bash
# ============ Redis 缓存配置 ============
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_DB=0
REDIS_PASSWORD= # 可选,如果设置了密码则填写
# ============ 认证系统配置 ============
# JWT 密钥(务必修改为强随机字符串)
SECRET_KEY=your-secret-key-change-in-production
# Token 过期时间分钟默认7天
ACCESS_TOKEN_EXPIRE_MINUTES=10080
# API Key可选用于外部调用逗号分隔
API_KEYS=
# 默认管理员账号
DEFAULT_ADMIN_USERNAME=admin
DEFAULT_ADMIN_PASSWORD=admin123
```
### 生成安全的 SECRET_KEY
```bash
python -c "import secrets; print(secrets.token_urlsafe(32))"
```
将输出的字符串替换 `SECRET_KEY` 的值。
## 4. 初始化数据库(创建用户表和默认管理员)
```bash
cd backend
source .venv/bin/activate
python cli.py init
```
你会看到类似输出:
```
✓ 创建默认管理员: admin
init done
```
## 5. 启动服务
```bash
# 确保 PostgreSQL 和 Redis 都在运行
sudo service postgresql start
sudo service redis-server start
# 启动 FastAPI 服务
cd backend
source .venv/bin/activate
python main.py
```
启动日志应包含:
```
✓ Redis 已连接: localhost:6379
✓ 管理员账号已存在: admin
[startup] db + scheduler + auth ready
```
## 6. 测试功能
### 测试健康检查(查看 Redis 和鉴权状态)
```bash
curl http://localhost:8000/api/health
```
应返回:
```json
{
"ok": true,
"akshare": true,
"redis": true,
"auth": true
}
```
### 测试登录
```bash
curl -X POST http://localhost:8000/api/auth/login \
-H "Content-Type: application/json" \
-d '{"username":"admin","password":"admin123"}'
```
应返回 Token
```json
{
"ok": true,
"access_token": "eyJ0eXAiOiJKV1QiLCJhb...",
"token_type": "bearer",
"username": "admin",
"is_admin": true
}
```
### 测试受保护的接口
```bash
# 使用上一步获取的 Token
curl -X GET http://localhost:8000/api/admin/status \
-H "Authorization: Bearer YOUR_TOKEN_HERE"
```
### 测试 Redis 缓存
```bash
# 第一次请求(缓存未命中,较慢)
time curl http://localhost:8000/api/indices
# 第二次请求(缓存命中,应该快很多)
time curl http://localhost:8000/api/indices
```
## 7. 修改默认密码(重要!)
首次登录后,立即修改默认密码:
```bash
curl -X POST http://localhost:8000/api/auth/change-password \
-H "Authorization: Bearer YOUR_TOKEN" \
-H "Content-Type: application/json" \
-d '{"old_password":"admin123","new_password":"your-new-strong-password"}'
```
## 8. 常见问题
### Redis 连接失败
如果看到以下日志:
```
✗ Redis 连接失败,缓存已禁用: ...
```
**解决方法**
```bash
# 检查 Redis 是否运行
sudo service redis-server status
# 如果未运行,启动它
sudo service redis-server start
```
**注意**Redis 连接失败不会影响系统运行,会自动降级到内存缓存。
### 401 Unauthorized 错误
如果访问管理接口返回 401
```json
{"detail":"未认证,请先登录"}
```
**原因**:该接口需要认证
**解决方法**
1. 先调用 `/api/auth/login` 获取 Token
2. 在请求头中加入 `Authorization: Bearer YOUR_TOKEN`
### WSL 重启后服务启动失败
WSL 每次重启后需要手动启动服务:
```bash
# 一键启动所有服务
sudo service postgresql start && \
sudo service redis-server start && \
cd /mnt/e/project/stock_cs_v1/backend && \
source .venv/bin/activate && \
python main.py
```
## 9. 功能详细说明
### Redis 缓存优势
- **持久化**:服务重启后缓存不丢失
- **跨进程**:多个进程可共享缓存
- **性能提升**:大幅减少 AkShare API 调用,响应速度提升 10-100 倍
- **自动降级**Redis 不可用时自动使用内存缓存
### 鉴权机制
**支持两种认证方式**
1. **JWT Token**(推荐用于前端)
- 登录后获取 Token
- Token 有效期 7 天(可配置)
- 在 HTTP Header 中传递:`Authorization: Bearer TOKEN`
2. **API Key**(推荐用于外部系统)
-`.env` 中配置 `API_KEYS`
- 在 HTTP Header 中传递:`X-API-Key: YOUR_API_KEY`
**受保护的接口**
- `/api/admin/*` - 需要管理员权限
- 其他接口暂不需要认证(可根据需要扩展)
### 异常处理
系统现在会返回友好的错误信息:
```json
{
"success": false,
"error": "数据源异常,请稍后重试",
"code": 503
}
```
错误类型:
- **400** - 业务逻辑错误
- **401** - 未认证
- **403** - 权限不足
- **422** - 请求参数错误
- **500** - 服务器内部错误
- **503** - 数据源不可用
## 10. 升级检查清单
- [ ] 安装新依赖包
- [ ] 安装并启动 Redis
- [ ] 配置 `.env` 文件Redis + 鉴权配置)
- [ ] 生成安全的 `SECRET_KEY`
- [ ] 运行 `python cli.py init` 创建用户表
- [ ] 启动服务并验证功能
- [ ] 测试登录和受保护接口
- [ ] 修改默认管理员密码
- [ ] 检查 Redis 缓存是否生效
## 11. 回退方案
如果升级后遇到问题,可以临时禁用新功能:
1. **禁用 Redis 缓存**:停止 Redis 服务,系统会自动降级到内存缓存
```bash
sudo service redis-server stop
```
2. **禁用鉴权**:暂时注释掉 `main.py` 中受保护接口的 `Depends(require_auth)` 参数
3. **完整回退**:切换到升级前的 git 分支
---
升级完成后,系统将具备更强的性能、安全性和稳定性!

View File

@@ -202,6 +202,176 @@ def diagnose(symbol):
return base
# ============ 走势分析右键K线 ============
def trend_analysis(symbol: str, date: str, period: str = "daily"):
"""分析某只股票在指定日期(或最新)附近暴涨/暴跌的原因。
period: daily / weekly / monthly
"""
with get_session() as s:
sec = s.get(Security, symbol)
# 取该日期前后一段数据作为上下文
if date:
try:
target_date = dt.date.fromisoformat(date)
except Exception:
target_date = None
else:
target_date = None
# 取最近60根K线
rows = s.execute(
select(DailyQuote).where(DailyQuote.code == symbol)
.order_by(DailyQuote.date.desc()).limit(60)
).scalars().all()
rows = list(reversed(rows))
# 当日及相邻数据
target_row = None
if target_date and rows:
# 找最近的一根
closest = min(rows, key=lambda r: abs((r.date - target_date).days))
if abs((closest.date - target_date).days) <= 7:
target_row = closest
if not target_row and rows:
target_row = rows[-1]
m = s.get(StockMetric, symbol)
name = (sec.name if sec else (m.name if m else symbol)) or symbol
# 计算目标K线的涨跌幅
pct = 0.0
if target_row and rows:
idx = rows.index(target_row)
if idx > 0:
prev_close = rows[idx - 1].close
if prev_close:
pct = round((target_row.close - prev_close) / prev_close * 100, 2)
# 拉取相关新闻
import akshare_service as svc
try:
news_data = svc.get_stock_news(symbol, limit=10)
news_items = news_data.get("list", [])
except Exception:
news_items = []
# 拉取RAG上下文
import rag
rctx = rag.stock_context(symbol, limit=6)
# 构造上下文数据
period_cn = {"daily": "日K", "weekly": "周K", "monthly": "月K"}.get(period, "K线")
date_str = target_row.date.isoformat() if target_row else (date or "最新")
# 当日技术面
if target_row:
tech_line = (
f"目标K线{date_str},开{target_row.open}{target_row.close} "
f"{target_row.high}{target_row.low}"
f"涨跌幅{pct:+.2f}%,成交量{target_row.volume:,}"
)
else:
tech_line = f"目标日期:{date_str},暂无日线数据"
# 前后走势最近5根
if target_row and rows:
idx = rows.index(target_row)
window = rows[max(0, idx-4):idx+2]
trend_line = "前后走势:" + "".join(
f"{r.date.strftime('%m/%d')}({'' if i == 0 or r.close >= rows[rows.index(r)-1].close else ''}{abs(round((r.close/rows[rows.index(r)-1].close-1)*100,1)) if rows.index(r) > 0 else 0}%)"
for i, r in enumerate(window)
)
else:
trend_line = ""
# 均线状态
ma_line = ""
if m:
ma_line = (f"均线状态MA5={m.ma5} MA10={m.ma10} MA20={m.ma20} MA60={m.ma60}"
f"{'多头排列' if m.ma_bull else '非多头'}"
f"量比{m.vol_ratio}RSI14={m.rsi14}")
# 新闻摘要
news_block = ""
if news_items:
news_block = "相关新闻(近期):\n" + "\n".join(
f"- [{n.get('time','')[:10]}] {n.get('title','')}" for n in news_items[:6]
)
# 判断是否暴涨/暴跌
move_desc = ""
if abs(pct) >= 5:
move_desc = f"该股{'暴涨' if pct > 0 else '暴跌'} {abs(pct):.2f}%{'接近/涨停' if pct >= 9.5 else '显著上涨' if pct > 0 else '接近/跌停' if pct <= -9.5 else '显著下跌'}"
elif abs(pct) >= 2:
move_desc = f"该股{'上涨' if pct > 0 else '下跌'} {abs(pct):.2f}%"
else:
move_desc = f"该股小幅变动 {pct:+.2f}%"
facts = f"""{name}{symbol}{period_cn}走势分析
分析日期:{date_str}
{move_desc}
{tech_line}
{trend_line}
{ma_line}
{news_block}
消息面情绪:{rctx['tone']}
{rctx['block'] or ''}"""
if llm.enabled():
try:
prompt = (
f"请分析 {name}{symbol})在 {date_str} 前后{period_cn}的走势,"
f"重点解释:① 为什么{'暴涨' if pct >= 5 else ('暴跌' if pct <= -5 else '出现此走势')}(从技术面、资金面、政策面、新闻事件等维度);"
f"② 背后的主要驱动逻辑是什么;③ 后续需关注的信号或风险。250字以内分点清晰。\n\n{facts}"
)
text = llm.ask(prompt, temperature=0.5, max_tokens=600)
return {"ok": True, "source": "llm", "symbol": symbol, "name": name,
"date": date_str, "period": period, "pct": pct, "facts": facts, "text": text}
except Exception:
pass
# 规则降级
reasons = []
if m:
if m.ma_bull and pct > 0:
reasons.append("均线多头排列,趋势向上")
if m.vol_ratio >= 2 and pct > 0:
reasons.append(f"成交量显著放大(量比{m.vol_ratio}),主力资金介入")
if m.vol_ratio >= 2 and pct < 0:
reasons.append(f"放量下跌(量比{m.vol_ratio}),资金出逃信号")
if m.macd_gold and pct > 0:
reasons.append("MACD金叉动能转强")
if m.rsi14 >= 80:
reasons.append(f"RSI超买{m.rsi14}),注意回调风险")
if m.rsi14 < 30:
reasons.append(f"RSI超卖{m.rsi14}),存在超跌反弹机会")
if m.pos60 >= 0.95 and pct > 0:
reasons.append("突破60日新高动量突破")
if m.pos60 <= 0.1 and pct > 0:
reasons.append("低位反弹,超跌修复")
if rctx['tone'] == '利好':
reasons.append("近期资讯面偏利好")
elif rctx['tone'] == '利空':
reasons.append("近期资讯面偏利空")
if news_items:
hot_news = news_items[0]['title'][:40]
reasons.append(f"最新消息:{hot_news}")
if not reasons:
reasons.append("暂无明确技术或消息面驱动,可能为市场情绪或板块联动")
text = (
f"{name}{date_str} {move_desc}\n"
f"主要原因分析:\n" +
"\n".join(f"{i+1}. {r}" for i, r in enumerate(reasons)) +
f"\n\n建议:{'关注量能是否持续配合,谨防高位回调。' if pct >= 5 else ('关注是否企稳止跌,底部确认前谨慎抄底。' if pct <= -5 else '走势相对平稳,跟踪板块动向。')}"
f"\n{DISCLAIMER}"
)
return {"ok": True, "source": "rule", "symbol": symbol, "name": name,
"date": date_str, "period": period, "pct": pct, "facts": facts, "text": text}
# ============ 今日策略 ============
def today_strategy():
with get_session() as s:

View File

@@ -12,6 +12,7 @@ from functools import wraps
from cachetools import TTLCache
import requests
from redis_cache import cache
try:
import akshare as ak
@@ -25,16 +26,35 @@ _cache = TTLCache(maxsize=256, ttl=30)
def cached(ttl: int):
"""缓存装饰器:优先使用 Redis降级到内存缓存"""
def deco(fn):
local = TTLCache(maxsize=64, ttl=ttl)
@wraps(fn)
def wrapper(*args, **kwargs):
key = (fn.__name__, args, tuple(sorted(kwargs.items())))
if key in local:
return local[key]
# 生成缓存键
key = f"akshare:{fn.__name__}:{args}:{tuple(sorted(kwargs.items()))}"
# 优先从 Redis 读取
if cache.enabled:
cached_value = cache.get(key)
if cached_value is not None:
return cached_value
# Redis 未命中,从内存缓存读取
local_key = (fn.__name__, args, tuple(sorted(kwargs.items())))
if local_key in local:
return local[local_key]
# 执行函数
val = fn(*args, **kwargs)
local[key] = val
# 写入 Redis
if cache.enabled:
cache.set(key, val, expire=ttl)
# 写入内存缓存(降级)
local[local_key] = val
return val
return wrapper
@@ -183,7 +203,19 @@ def get_stock_news(code: str, limit: int = 12):
return {"source": "mock", "list": []}
# 已知指数代码 → 新浪前缀映射
_INDEX_CODES = {"000001", "000300", "000016", "399001", "399006", "899050"}
def _is_index(code: str) -> bool:
return code in _INDEX_CODES or code.startswith(("sh0", "sz3990", "bj8990"))
def _sina_symbol(code: str) -> str:
if code in ("000001", "000016"): # 上证系列
return "sh" + code
if code in ("000300",): # 沪深300
return "sh" + code
if code in ("399001", "399006"): # 深证
return "sz" + code
if code.startswith("6"):
return "sh" + code
if code.startswith(("0", "3")):
@@ -194,9 +226,23 @@ def _sina_symbol(code: str) -> str:
@cached(60)
def get_kline(symbol: str = "600519", days: int = 120):
def get_kline(symbol: str = "000001", days: int = 120):
if AK_OK:
# 主源:新浪日线(更稳定);备源:腾讯
# 指数走专用接口
if symbol in _INDEX_CODES:
try:
sym = _sina_symbol(symbol)
df = ak.stock_zh_index_daily(symbol=sym)
if df is not None and not df.empty:
df = df.tail(days)
dates = [str(d)[5:].replace("-", "/") for d in df["date"]]
ohlc = [[float(r["open"]), float(r["close"]), float(r["low"]), float(r["high"])]
for _, r in df.iterrows()]
vols = [int(r["volume"]) if "volume" in df.columns else 0 for _, r in df.iterrows()]
return {"source": "akshare", "symbol": symbol, "dates": dates, "ohlc": ohlc, "vols": vols}
except Exception:
pass
# 个股主源:新浪日线(更稳定);备源:腾讯
for src in ("sina", "tx"):
try:
sym = _sina_symbol(symbol)
@@ -321,6 +367,90 @@ def get_treemap(mode: str = "sector"):
return {"source": boards["source"], "mode": "sector", "items": items}
@cached(120)
def get_us_treemap():
"""美股热门板块云图按成交额取前100只"""
if AK_OK:
try:
df = ak.stock_us_spot_em()
if df is not None and not df.empty:
top = df.sort_values("成交额", ascending=False).head(100)
items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e8, 2),
"pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()]
items = [x for x in items if x["name"]]
return {"source": "akshare", "market": "us", "items": items}
except Exception:
pass
names = ["苹果","微软","谷歌","亚马逊","英伟达","特斯拉","Meta","台积电","巴菲特","摩根"]
return {"source": "mock", "market": "us",
"items": [{"name": n, "value": _rnd(10,200), "pct": round(_rnd(-4,4),2)} for n in names]}
@cached(120)
def get_hk_treemap():
"""港股热门板块云图按成交额取前100只"""
if AK_OK:
try:
df = ak.stock_hk_spot_em()
if df is not None and not df.empty:
top = df.sort_values("成交额", ascending=False).head(100)
items = [{"name": str(r.get("名称","")), "value": round(float(r.get("成交额",0))/1e4, 2),
"pct": round(float(r.get("涨跌幅",0)), 2)} for _, r in top.iterrows()]
items = [x for x in items if x["name"]]
return {"source": "akshare", "market": "hk", "items": items}
except Exception:
pass
names = ["腾讯","阿里巴巴","美团","京东","小米","百度","网易","中国平安","汇丰","友邦"]
return {"source": "mock", "market": "hk",
"items": [{"name": n, "value": _rnd(5,100), "pct": round(_rnd(-4,4),2)} for n in names]}
@cached(120)
def get_all_sector_leaders(top_n: int = 5):
"""一次性获取所有板块的前N只龙头股"""
boards = get_industry_boards()
result = {}
for b in boards.get("list", []):
name = b["name"]
try:
r = get_sector_stocks(name, top_n + 1)
result[name] = r.get("stocks", [])[:top_n]
except Exception:
result[name] = []
return {"source": "akshare", "sectors": result}
@cached(300)
def get_sector_stocks(sector_name: str, limit: int = 20):
"""获取板块成分股,按成交额排序"""
if AK_OK:
try:
df = ak.stock_board_industry_cons_em(symbol=sector_name)
if df is not None and not df.empty:
if "成交额" in df.columns:
df = df.sort_values("成交额", ascending=False)
stocks = []
for _, r in df.head(limit).iterrows():
try:
stocks.append({
"code": str(r.get("代码", "")),
"name": str(r.get("名称", "")),
"pct": round(float(r.get("涨跌幅", 0)), 2),
"price": round(float(r.get("最新价", 0)), 2),
"amount": round(float(r.get("成交额", 0)) / 1e8, 2),
})
except Exception:
continue
return {"source": "akshare", "name": sector_name, "stocks": stocks}
except Exception:
pass
# mock
stocks = [{"code": f"60000{i}", "name": f"{sector_name}{i+1}",
"pct": round(_rnd(-5, 5), 2), "price": round(_rnd(5, 100), 2), "amount": round(_rnd(1, 50), 2)}
for i in range(10)]
return {"source": "mock", "name": sector_name, "stocks": stocks}
# ============================================================
# 资金流向(行业)
# ============================================================

88
backend/auth.py Normal file
View File

@@ -0,0 +1,88 @@
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status, Header
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
import bcrypt
from sqlalchemy.orm import Session
from db import SessionLocal
from models import User
from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, API_KEYS
security = HTTPBearer(auto_error=False)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
def get_password_hash(password: str) -> str:
"""密码哈希"""
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""生成 JWT Token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""验证用户名密码"""
user = db.query(User).filter(User.username == username).first()
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
x_api_key: Optional[str] = Header(None),
db: Session = Depends(lambda: SessionLocal())
) -> Optional[User]:
"""获取当前用户(支持 JWT Token 和 API Key 两种方式)"""
# 方式1API Key
if x_api_key and x_api_key in API_KEYS:
# API Key 模式,返回虚拟管理员
user = db.query(User).filter(User.username == "admin").first()
if user:
return user
# 方式2JWT Token
if credentials:
token = credentials.credentials
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
return None
user = db.query(User).filter(User.username == username).first()
return user
except JWTError:
return None
return None
async def require_auth(current_user: Optional[User] = Depends(get_current_user)):
"""需要认证的依赖"""
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未认证,请先登录",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user
async def require_admin(current_user: User = Depends(require_auth)):
"""需要管理员权限的依赖"""
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限"
)
return current_user

View File

@@ -9,12 +9,16 @@ import sys
from db import init_db
import ingest
import init_auth
import watchlist_manager as wl
def main():
init_db()
args = sys.argv[1:]
if not args or args[0] == "init":
init_auth.init_default_admin()
wl.init_default_groups()
print("init done")
return
if args[0] == "ingest":

View File

@@ -51,3 +51,19 @@ SERVERCHAN_KEY = os.getenv("SERVERCHAN_KEY", "")
WECOM_WEBHOOK = os.getenv("WECOM_WEBHOOK", "")
# PushPlus微信推送
PUSHPLUS_TOKEN = os.getenv("PUSHPLUS_TOKEN", "")
# ---- Redis 缓存 ----
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB = int(os.getenv("REDIS_DB", "0"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")
# ---- 鉴权配置 ----
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "10080")) # 7天
# API Key 模式(可选,用于外部调用)
API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []
# 默认管理员账号(首次启动时创建)
DEFAULT_ADMIN_USERNAME = os.getenv("DEFAULT_ADMIN_USERNAME", "admin")
DEFAULT_ADMIN_PASSWORD = os.getenv("DEFAULT_ADMIN_PASSWORD", "admin123")

398
backend/data_manager.py Normal file
View File

@@ -0,0 +1,398 @@
"""数据修正与回填增强:数据修正、断点续传、完整性检查、质量报告"""
import datetime as dt
import json
import os
from typing import List, Optional, Dict
from sqlalchemy import select, func, and_, delete
from db import get_session
from models import DailyQuote, StockMetric, Security, IndexDaily, SectorDaily, JobRun
import ingest
# 回填进度文件路径
PROGRESS_FILE = os.path.join(os.path.dirname(__file__), ".refill_progress.json")
# ============ 数据修正 ============
def delete_quote(code: str, date: str) -> Dict:
"""删除指定股票指定日期的日线数据"""
d = dt.date.fromisoformat(date)
with get_session() as s:
row = s.execute(
select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d)
).scalar_one_or_none()
if not row:
return {"ok": False, "msg": f"{code} {date} 无此数据"}
s.delete(row)
s.commit()
return {"ok": True, "msg": f"已删除 {code} {date} 日线"}
def update_quote(code: str, date: str, fields: Dict) -> Dict:
"""修正指定股票指定日期的日线数据"""
allowed = {"open", "high", "low", "close", "volume", "amount"}
to_update = {k: v for k, v in fields.items() if k in allowed}
if not to_update:
return {"ok": False, "msg": "无有效修正字段"}
d = dt.date.fromisoformat(date)
with get_session() as s:
row = s.execute(
select(DailyQuote).where(DailyQuote.code == code, DailyQuote.date == d)
).scalar_one_or_none()
if not row:
return {"ok": False, "msg": f"{code} {date} 无此数据"}
for k, v in to_update.items():
setattr(row, k, v)
s.commit()
return {"ok": True, "updated": to_update}
def delete_quotes_range(code: str, start: str, end: str) -> Dict:
"""删除指定股票日期范围内的日线数据"""
d_start = dt.date.fromisoformat(start)
d_end = dt.date.fromisoformat(end)
with get_session() as s:
rows = s.execute(
select(DailyQuote).where(
DailyQuote.code == code,
DailyQuote.date >= d_start,
DailyQuote.date <= d_end
)
).scalars().all()
count = len(rows)
for row in rows:
s.delete(row)
s.commit()
return {"ok": True, "deleted": count, "range": f"{start} ~ {end}"}
def refetch_quote(code: str, days: int = 30) -> Dict:
"""重新抓取指定股票的日线数据(覆盖更新)"""
rows = ingest.fetch_daily(code, days)
if not rows:
return {"ok": False, "msg": f"抓取 {code} 数据失败"}
n = ingest.ingest_quotes([code], days=days)
return {"ok": True, "code": code, "rows": len(rows), "msg": f"已更新 {len(rows)} 条日线"}
# ============ 数据完整性检查 ============
def check_data_integrity(codes: Optional[List[str]] = None, days: int = 30) -> Dict:
"""检查数据完整性,找出缺失数据的股票和日期"""
with get_session() as s:
# 确定检查范围
latest = s.execute(select(func.max(DailyQuote.date))).scalar()
if not latest:
return {"ok": False, "msg": "数据库无日线数据"}
start = latest - dt.timedelta(days=days)
# 获取检查的股票列表
if codes:
check_codes = codes
else:
# 默认检查有记录的所有股票
all_codes = s.execute(
select(DailyQuote.code).where(
DailyQuote.date >= start
).distinct()
).scalars().all()
check_codes = list(all_codes)[:200] # 最多检查200只
# 统计每只股票的数据点数
from sqlalchemy import case
code_counts = {}
for code in check_codes:
count = s.execute(
select(func.count()).select_from(DailyQuote)
.where(DailyQuote.code == code, DailyQuote.date >= start)
).scalar()
code_counts[code] = count
# 以最多数据量为基准(应是交易日数)
expected = max(code_counts.values()) if code_counts else 0
# 找出缺失数据的股票
missing = []
normal = []
for code, count in code_counts.items():
ratio = count / expected if expected > 0 else 0
if ratio < 0.8: # 缺失超过20%
with get_session() as s2:
sec = s2.get(Security, code)
name = sec.name if sec else code
missing.append({
"code": code,
"name": name,
"actual": count,
"expected": expected,
"missing": expected - count,
"missing_pct": round((1 - ratio) * 100, 1)
})
else:
normal.append(code)
missing.sort(key=lambda x: x["missing"], reverse=True)
return {
"ok": True,
"check_range": f"{start.isoformat()} ~ {latest.isoformat()}",
"checked": len(check_codes),
"expected_days": expected,
"normal_count": len(normal),
"missing_count": len(missing),
"missing_stocks": missing[:50]
}
def auto_fix_missing(limit: int = 50) -> Dict:
"""自动补齐缺失数据(批量重新抓取)"""
result = check_data_integrity(days=30)
if not result["ok"] or result["missing_count"] == 0:
return {"ok": True, "msg": "数据完整,无需修复", "fixed": 0}
missing_stocks = result["missing_stocks"][:limit]
codes = [s["code"] for s in missing_stocks]
with get_session() as s:
job = JobRun(job="auto_fix", status="running",
message=f"0/{len(codes)}")
s.add(job)
s.commit()
job_id = job.id
fixed = 0
failed = []
try:
for i, code in enumerate(codes):
rows = ingest.fetch_daily(code, days=60)
if rows:
ingest.ingest_quotes([code], days=60)
fixed += 1
else:
failed.append(code)
if (i + 1) % 10 == 0:
with get_session() as s:
j = s.get(JobRun, job_id)
j.message = f"{i+1}/{len(codes)}"
s.commit()
status = "success"
msg = f"修复 {fixed}/{len(codes)},失败 {len(failed)}"
except Exception as e:
status = "error"
msg = f"修复中断: {repr(e)[:160]}"
with get_session() as s:
j = s.get(JobRun, job_id)
j.status = status
j.finished_at = dt.datetime.now()
j.message = msg
s.commit()
return {"ok": True, "fixed": fixed, "failed": failed, "msg": msg}
# ============ 断点续传回填 ============
def _load_progress() -> Dict:
"""加载回填进度"""
if os.path.exists(PROGRESS_FILE):
try:
with open(PROGRESS_FILE, "r") as f:
return json.load(f)
except Exception:
pass
return {}
def _save_progress(progress: Dict):
"""保存回填进度"""
with open(PROGRESS_FILE, "w") as f:
json.dump(progress, f)
def _clear_progress(task_id: str):
"""清除指定任务的进度"""
progress = _load_progress()
progress.pop(task_id, None)
_save_progress(progress)
def start_refill_with_resume(days: int = 250, task_id: str = "default") -> Dict:
"""带断点续传的全市场回填"""
from akshare_service import _code_name_map
cmap = _code_name_map()
all_codes = [c for c in cmap.keys() if c[:1] in ("0", "3", "6")]
total = len(all_codes)
# 加载进度
progress = _load_progress()
task_progress = progress.get(task_id, {"done_codes": [], "days": days})
done_codes = set(task_progress.get("done_codes", []))
# 过滤已完成的股票
remaining = [c for c in all_codes if c not in done_codes]
with get_session() as s:
job = JobRun(
job="refill_resume",
status="running",
message=f"续传: 已完成 {len(done_codes)}/{total},剩余 {len(remaining)}"
)
s.add(job)
s.commit()
job_id = job.id
fixed = len(done_codes)
try:
for i in range(0, len(remaining), 50):
batch = remaining[i:i + 50]
ingest.ingest_quotes(batch, days=days, with_metrics=True, cmap=cmap)
fixed += len(batch)
# 保存进度
done_codes.update(batch)
progress[task_id] = {
"done_codes": list(done_codes),
"days": days,
"updated_at": dt.datetime.now().isoformat()
}
_save_progress(progress)
with get_session() as s:
j = s.get(JobRun, job_id)
j.message = f"{fixed}/{total}"
s.commit()
# 完成后清除进度
_clear_progress(task_id)
status = "success"
msg = f"完成 {fixed}/{total}"
except Exception as e:
status = "error"
msg = f"中断于 {fixed}/{total} | {repr(e)[:160]}"
# 保留进度供续传
with get_session() as s:
j = s.get(JobRun, job_id)
j.status = status
j.finished_at = dt.datetime.now()
j.message = msg
s.commit()
return {"ok": status == "success", "done": fixed, "total": total, "msg": msg}
def get_refill_progress(task_id: str = "default") -> Dict:
"""获取回填进度"""
progress = _load_progress()
task = progress.get(task_id)
if not task:
return {"ok": True, "has_progress": False, "msg": "无回填进度记录"}
from akshare_service import _code_name_map
cmap = _code_name_map()
total = len([c for c in cmap.keys() if c[:1] in ("0", "3", "6")])
done = len(task.get("done_codes", []))
return {
"ok": True,
"has_progress": True,
"task_id": task_id,
"done": done,
"total": total,
"pct": round(done / total * 100, 1) if total > 0 else 0,
"updated_at": task.get("updated_at", "")
}
def clear_refill_progress(task_id: str = "default") -> Dict:
"""清除回填进度(从头开始)"""
_clear_progress(task_id)
return {"ok": True, "msg": f"已清除任务 {task_id} 的进度"}
# ============ 数据质量报告 ============
def get_data_quality_report() -> Dict:
"""生成数据质量报告"""
with get_session() as s:
# 基本统计
total_quotes = s.execute(select(func.count()).select_from(DailyQuote)).scalar() or 0
total_stocks = s.execute(
select(func.count(DailyQuote.code.distinct()))
).scalar() or 0
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
earliest_date = s.execute(select(func.min(DailyQuote.date))).scalar()
# 最近30天数据密度
if latest_date:
start30 = latest_date - dt.timedelta(days=30)
recent_stocks = s.execute(
select(func.count(DailyQuote.code.distinct()))
.where(DailyQuote.date >= start30)
).scalar() or 0
recent_dates = s.execute(
select(func.count(DailyQuote.date.distinct()))
.where(DailyQuote.date >= start30)
).scalar() or 0
else:
recent_stocks = 0
recent_dates = 0
# 异常数据检测开盘价为0的记录
zero_open = s.execute(
select(func.count()).select_from(DailyQuote)
.where(DailyQuote.open == 0)
).scalar() or 0
# 最近任务状态
recent_jobs = s.execute(
select(JobRun).order_by(JobRun.id.desc()).limit(5)
).scalars().all()
jobs_summary = [{
"job": j.job,
"status": j.status,
"started": j.started_at.strftime("%m-%d %H:%M") if j.started_at else "",
"message": j.message[:100]
} for j in recent_jobs]
# 数据健康度评分
score = 100
issues = []
if zero_open > 0:
score -= min(20, zero_open // 100)
issues.append(f"存在 {zero_open} 条开盘价为0的异常数据")
if total_stocks < 100:
score -= 30
issues.append(f"入库股票数量偏少({total_stocks}只)")
if latest_date and (dt.date.today() - latest_date).days > 7:
score -= 20
issues.append(f"数据滞后 {(dt.date.today() - latest_date).days}")
return {
"ok": True,
"generated_at": dt.datetime.now().isoformat(),
"health_score": max(0, score),
"issues": issues,
"statistics": {
"total_quotes": total_quotes,
"total_stocks": total_stocks,
"latest_date": latest_date.isoformat() if latest_date else None,
"earliest_date": earliest_date.isoformat() if earliest_date else None,
"data_span_days": (latest_date - earliest_date).days if latest_date and earliest_date else 0,
"recent_30d_stocks": recent_stocks,
"recent_30d_dates": recent_dates,
"zero_open_count": zero_open
},
"recent_jobs": jobs_summary
}

73
backend/exceptions.py Normal file
View File

@@ -0,0 +1,73 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from sqlalchemy.exc import SQLAlchemyError
import traceback
class BusinessException(Exception):
"""业务异常基类"""
def __init__(self, message: str, code: int = 400):
self.message = message
self.code = code
super().__init__(self.message)
class DataSourceException(BusinessException):
"""数据源异常AkShare等"""
def __init__(self, message: str = "数据源异常,请稍后重试"):
super().__init__(message, code=503)
class AuthException(BusinessException):
"""认证异常"""
def __init__(self, message: str = "认证失败"):
super().__init__(message, code=401)
class PermissionException(BusinessException):
"""权限异常"""
def __init__(self, message: str = "权限不足"):
super().__init__(message, code=403)
async def business_exception_handler(request: Request, exc: BusinessException):
"""业务异常处理器"""
return JSONResponse(
status_code=exc.code,
content={
"success": False,
"error": exc.message,
"code": exc.code
}
)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""请求参数验证异常处理器"""
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"success": False,
"error": "请求参数错误",
"detail": exc.errors()
}
)
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
"""数据库异常处理器"""
print(f"数据库错误: {exc}")
traceback.print_exc()
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"success": False,
"error": "数据库操作失败"
}
)
async def general_exception_handler(request: Request, exc: Exception):
"""通用异常处理器"""
print(f"未捕获异常: {type(exc).__name__}: {exc}")
traceback.print_exc()
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"success": False,
"error": "服务器内部错误"
}
)

View File

@@ -12,7 +12,7 @@ import akshare_service as svc
import config
from db import get_session
from models import (DailyQuote, DragonTiger, FundFlowDaily, IndexDaily, JobRun,
SectorDaily, Security, SentimentDaily, StockMetric)
SectorDaily, SectorLeader, Security, SentimentDaily, StockMetric)
try:
import akshare as ak
@@ -229,6 +229,25 @@ def ingest_sectors():
return n
def ingest_sector_leaders():
"""入库各板块前5龙头股按成交额"""
d = _today()
data = svc.get_all_sector_leaders(top_n=5)
rows = []
for sector, stocks in data.get("sectors", {}).items():
for i, s in enumerate(stocks):
rows.append({"date": d, "sector": sector, "code": s["code"],
"name": s["name"], "pct": s["pct"],
"price": s["price"], "amount": s["amount"], "rank": i + 1})
if not rows:
return 0
with get_session() as s:
n = _upsert(s, SectorLeader, rows, ["date", "sector", "code"],
["name", "pct", "price", "amount", "rank"])
s.commit()
return n
def ingest_fund_flow():
data = svc.get_fund_flow()
d = _today()
@@ -280,6 +299,7 @@ def run_daily_ingest(universe=None, with_quotes=True):
summary["securities"] = ingest_securities()
summary["indices"] = ingest_indices()
summary["sectors"] = ingest_sectors()
summary["sector_leaders"] = ingest_sector_leaders()
summary["fund_flow"] = ingest_fund_flow()
summary["sentiment"] = ingest_sentiment()
summary["dragon"] = ingest_dragon()

21
backend/init_auth.py Normal file
View File

@@ -0,0 +1,21 @@
from db import get_session
from models import User
from auth import get_password_hash
from config import DEFAULT_ADMIN_USERNAME, DEFAULT_ADMIN_PASSWORD
def init_default_admin():
"""创建默认管理员账号(如果不存在)"""
with get_session() as s:
admin = s.query(User).filter(User.username == DEFAULT_ADMIN_USERNAME).first()
if not admin:
admin = User(
username=DEFAULT_ADMIN_USERNAME,
hashed_password=get_password_hash(DEFAULT_ADMIN_PASSWORD),
is_admin=True,
is_active=True
)
s.add(admin)
s.commit()
print(f"✓ 创建默认管理员: {DEFAULT_ADMIN_USERNAME}")
else:
print(f"✓ 管理员账号已存在: {DEFAULT_ADMIN_USERNAME}")

106
backend/install.sh Normal file
View File

@@ -0,0 +1,106 @@
#!/bin/bash
# 三大核心功能快速安装脚本WSL/Linux
set -e
echo "========================================="
echo " Blackdata StockTerminal 核心功能安装"
echo "========================================="
echo ""
# 检查是否在 WSL/Linux 环境
if [[ "$OSTYPE" != "linux-gnu"* ]]; then
echo "⚠ 此脚本仅支持 WSL/Linux 环境"
exit 1
fi
# 1. 安装系统依赖
echo "[1/6] 检查并安装系统依赖..."
sudo apt update
sudo apt install -y postgresql postgresql-contrib redis-server python3-pip python3-venv
# 2. 启动服务
echo ""
echo "[2/6] 启动 PostgreSQL 和 Redis..."
sudo service postgresql start
sudo service redis-server start
# 验证服务
if redis-cli ping > /dev/null 2>&1; then
echo "✓ Redis 运行正常"
else
echo "⚠ Redis 启动失败,缓存将降级到内存模式"
fi
# 3. 创建虚拟环境(如果不存在)
if [ ! -d ".venv" ]; then
echo ""
echo "[3/6] 创建 Python 虚拟环境..."
python3 -m venv .venv
else
echo ""
echo "[3/6] 虚拟环境已存在,跳过创建"
fi
# 4. 安装 Python 依赖
echo ""
echo "[4/6] 安装 Python 依赖包..."
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
# 5. 配置环境变量
echo ""
echo "[5/6] 配置环境变量..."
if [ ! -f ".env" ]; then
if [ -f ".env.example" ]; then
cp .env.example .env
echo "✓ 已从 .env.example 创建 .env 文件"
else
echo "⚠ .env.example 不存在,请手动创建 .env 文件"
fi
# 生成随机 SECRET_KEY
SECRET_KEY=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))")
echo ""
echo "生成的 SECRET_KEY请添加到 .env:"
echo "SECRET_KEY=$SECRET_KEY"
echo ""
else
echo "✓ .env 文件已存在"
fi
# 6. 初始化数据库
echo ""
echo "[6/6] 初始化数据库..."
# 检查 PostgreSQL 密码配置
if grep -q "PG_PASSWORD=your_password" .env 2>/dev/null || grep -q "PG_PASSWORD=$" .env 2>/dev/null; then
echo ""
echo "⚠ 请先在 .env 中设置 PostgreSQL 密码:"
echo " 1. 设置数据库密码: sudo -u postgres psql -c \"ALTER USER postgres PASSWORD 'your_password';\""
echo " 2. 在 .env 中配置: PG_PASSWORD=your_password"
echo ""
echo "配置完成后,运行: python cli.py init"
else
python cli.py init
echo "✓ 数据库初始化完成"
fi
echo ""
echo "========================================="
echo " 安装完成!"
echo "========================================="
echo ""
echo "下一步:"
echo "1. 编辑 backend/.env 文件,配置数据库密码和其他选项"
echo "2. 如果未初始化数据库,运行: python cli.py init"
echo "3. 启动服务: python main.py"
echo "4. 浏览器访问: http://localhost:8000"
echo "5. 默认管理员: admin / admin123 (首次登录后务必修改密码)"
echo "6. 测试功能: python test_core_features.py"
echo ""
echo "详细文档:"
echo "- 升级指南: backend/UPGRADE_GUIDE.md"
echo "- 配置说明: backend/ENV_CONFIG.md"
echo ""

View File

@@ -1,10 +1,4 @@
"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行
功能:
1. 连板股追踪器
2. 炸板率统计
3. 涨停敢死队排行
"""
"""涨跌停分析 — 连板股追踪、炸板率统计、敢死队排行、炸板走势统计、涨停原因分类。"""
import datetime as dt
from typing import List, Dict, Any, Optional
from collections import defaultdict, Counter
@@ -12,7 +6,24 @@ import numpy as np
from sqlalchemy import select, and_, func, desc
from db import get_session
from models import DailyQuote, StockMetric
from models import DailyQuote, StockMetric, DragonTiger
try:
import akshare as ak
AK_OK = True
except Exception:
ak = None
AK_OK = False
# 涨停原因关键词分类
LIMIT_REASON_MAP = {
"题材": ["概念", "题材", "热点", "风口", "赛道"],
"业绩": ["业绩", "净利润", "营收", "盈利", "超预期", "预增", "扭亏"],
"政策": ["政策", "补贴", "利好", "支持", "规划", "国家", "工信部", "发改委"],
"技术突破": ["突破", "新高", "均线", "金叉", "放量"],
"重组并购": ["重组", "并购", "收购", "合并", "入股"],
"情绪": ["跟风", "连板", "情绪", "氛围", "涨停潮"],
}
def get_limit_stocks(date: Optional[dt.date] = None, limit_type: str = "up") -> Dict[str, Any]:
@@ -293,6 +304,214 @@ def analyze_limit_break_rate(days: int = 60) -> Dict[str, Any]:
}
def get_consecutive_calendar(days: int = 60) -> Dict[str, Any]:
"""连板日历:记录每只股票的连板历史,分析几进几出规律"""
with get_session() as s:
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无数据"}
start_date = latest_date - dt.timedelta(days=days)
quotes = s.execute(
select(DailyQuote)
.where(DailyQuote.date >= start_date)
.order_by(DailyQuote.code, DailyQuote.date)
).scalars().all()
stock_data = defaultdict(list)
for q in quotes:
stock_data[q.code].append(q)
all_streaks = []
current_streaks = []
for code, data in stock_data.items():
data_sorted = sorted(data, key=lambda x: x.date)
name = data_sorted[-1].name
streaks = []
current = []
for q in data_sorted:
if q.open == 0:
if len(current) >= 2:
streaks.append(current)
current = []
continue
pct = (float(q.close) - float(q.open)) / float(q.open) * 100
if pct >= 9.8:
current.append(q)
else:
if len(current) >= 2:
streaks.append(current)
current = []
if len(current) >= 2:
current_streaks.append({
"code": code, "name": name,
"days": len(current),
"start_date": current[0].date.isoformat(),
"latest_date": current[-1].date.isoformat(),
"latest_close": float(current[-1].close)
})
for streak in streaks:
all_streaks.append({
"code": code, "name": name,
"days": len(streak),
"start_date": streak[0].date.isoformat(),
"end_date": streak[-1].date.isoformat()
})
distribution = defaultdict(int)
for item in all_streaks:
distribution[f"{item['days']}"] += 1
current_streaks.sort(key=lambda x: x["days"], reverse=True)
return {
"ok": True,
"date_range": f"{start_date.isoformat()} ~ {latest_date.isoformat()}",
"current_streaks": current_streaks[:30],
"streak_distribution": dict(distribution),
"total_streaks": len(all_streaks)
}
def analyze_post_break_performance(days: int = 90) -> Dict[str, Any]:
"""炸板后走势统计:炸板后 1/3/5 日表现概率分布"""
with get_session() as s:
latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
if not latest_date:
return {"ok": False, "msg": "暂无数据"}
start_date = latest_date - dt.timedelta(days=days + 10)
quotes = s.execute(
select(DailyQuote)
.where(DailyQuote.date >= start_date)
.order_by(DailyQuote.code, DailyQuote.date)
).scalars().all()
stock_data = defaultdict(dict)
for q in quotes:
stock_data[q.code][q.date] = q
# 炸板 = 当日一度涨停但收盘时未涨停(用开高低收近似判断)
# 简化:前一日涨停,次日收盘未涨停视为炸板
perf_1d, perf_3d, perf_5d = [], [], []
for code, date_data in stock_data.items():
dates = sorted(date_data.keys())
for i in range(len(dates) - 5):
today = dates[i]
q_today = date_data[today]
if q_today.open == 0:
continue
pct_today = (float(q_today.close) - float(q_today.open)) / float(q_today.open) * 100
# 判断昨日涨停今日炸板(开高但收盘低)
if i == 0:
continue
yesterday = dates[i - 1]
q_yest = date_data[yesterday]
if q_yest.open == 0:
continue
pct_yest = (float(q_yest.close) - float(q_yest.open)) / float(q_yest.open) * 100
# 昨日涨停,今日未涨停(炸板)
if pct_yest >= 9.8 and pct_today < 9.8:
base = float(q_today.close)
# 后续 1/3/5 日表现
for horizon, perf_list in [(1, perf_1d), (3, perf_3d), (5, perf_5d)]:
if i + horizon < len(dates):
future = dates[i + horizon]
q_future = date_data[future]
ret = (float(q_future.close) - base) / base * 100
perf_list.append(round(ret, 2))
def summarize(perfs):
if not perfs:
return {}
arr = np.array(perfs)
return {
"samples": len(perfs),
"avg_ret": round(float(arr.mean()), 2),
"win_rate": round(float((arr > 0).mean() * 100), 1),
"p25": round(float(np.percentile(arr, 25)), 2),
"median": round(float(np.median(arr)), 2),
"p75": round(float(np.percentile(arr, 75)), 2),
}
return {
"ok": True,
"days": days,
"after_1d": summarize(perf_1d),
"after_3d": summarize(perf_3d),
"after_5d": summarize(perf_5d),
"conclusion": (
f"炸板后样本 {len(perf_1d)} 条,"
f"次日平均收益 {summarize(perf_1d).get('avg_ret', 0)}%"
f"次日上涨概率 {summarize(perf_1d).get('win_rate', 0)}%"
) if perf_1d else "样本不足"
}
def classify_limit_reasons(date: Optional[dt.date] = None) -> Dict[str, Any]:
"""涨停原因分类:情绪、题材、业绩、技术突破等"""
with get_session() as s:
if date is None:
date = s.execute(select(func.max(DragonTiger.date))).scalar()
if not date:
return {"ok": False, "msg": "暂无龙虎榜数据,请先入库"}
lhb_rows = s.execute(
select(DragonTiger).where(DragonTiger.date == date)
).scalars().all()
# 尝试从 AkShare 获取当日涨停原因
reason_data = {}
if AK_OK:
try:
df = ak.stock_zt_pool_em(date=date.strftime("%Y%m%d"))
if df is not None and not df.empty:
for _, r in df.iterrows():
code = str(r.get("代码", ""))
reason = str(r.get("涨停原因类别", "") or r.get("上榜原因", ""))
reason_data[code] = reason
except Exception:
pass
# 合并龙虎榜原因
for row in lhb_rows:
if row.code not in reason_data:
reason_data[row.code] = row.reason
# 分类
classified = defaultdict(list)
for code, reason in reason_data.items():
matched = False
for category, keywords in LIMIT_REASON_MAP.items():
if any(kw in reason for kw in keywords):
classified[category].append({"code": code, "reason": reason})
matched = True
break
if not matched:
classified["其他"].append({"code": code, "reason": reason})
total = sum(len(v) for v in classified.values())
summary = [
{
"category": cat,
"count": len(items),
"pct": round(len(items) / total * 100, 1) if total > 0 else 0,
"stocks": items[:10]
}
for cat, items in sorted(classified.items(), key=lambda x: len(x[1]), reverse=True)
]
return {
"ok": True,
"date": date.isoformat(),
"total": total,
"categories": summary
}
def get_limit_squad_rankings(days: int = 30, min_limits: int = 5) -> Dict[str, Any]:
"""涨停敢死队排行

View File

@@ -9,14 +9,30 @@ import datetime as dt
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, Query
from fastapi import FastAPI, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from sqlalchemy.exc import SQLAlchemyError
from fastapi.staticfiles import StaticFiles
from sqlalchemy import select, func, desc
from pydantic import BaseModel
import akshare_service as svc
import redis_cache
from redis_cache import cache
import auth
from auth import get_current_user, require_auth, require_admin
import init_auth
import exceptions
from exceptions import (
BusinessException,
DataSourceException,
business_exception_handler,
validation_exception_handler,
sqlalchemy_exception_handler,
general_exception_handler
)
import config
import scheduler
import backtest as bt
@@ -37,6 +53,11 @@ import sentiment_monitor as sentiment
import event_driven as events
import financial_analysis as fin
import limit_analysis as limit_up
import watchlist_manager as wl
import position_cost as pc
import trade_calendar as cal
import data_manager as dm
import paper_trading as paper
from db import init_db, get_session
from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
SentimentDaily, DragonTiger, Security, JobRun, StockMetric, Trade,
@@ -47,8 +68,11 @@ from models import (DailyQuote, IndexDaily, SectorDaily, FundFlowDaily,
async def lifespan(app: FastAPI):
try:
init_db()
init_auth.init_default_admin()
wl.init_default_groups()
paper.ensure_default_account()
scheduler.start_scheduler()
print("[startup] db + scheduler ready")
print("[startup] db + scheduler + auth ready")
except Exception as e:
print("[startup] WARN:", repr(e)[:160])
yield
@@ -56,6 +80,12 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Blackdata股票终端 API", version="0.2.0", lifespan=lifespan)
# 注册异常处理器
app.add_exception_handler(BusinessException, business_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler)
app.add_exception_handler(Exception, general_exception_handler)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -84,10 +114,113 @@ def save_watch(symbols):
json.dump(symbols, f, ensure_ascii=False)
# ============ 认证相关 API ============
class LoginRequest(BaseModel):
username: str
password: str
@app.post("/api/auth/login")
def login(req: LoginRequest, db = Depends(get_session)):
"""用户登录"""
user = auth.authenticate_user(db, req.username, req.password)
if not user:
raise exceptions.AuthException("用户名或密码错误")
access_token = auth.create_access_token(data={"sub": user.username})
return {
"ok": True,
"access_token": access_token,
"token_type": "bearer",
"username": user.username,
"is_admin": user.is_admin
}
@app.get("/api/auth/me")
def get_me(current_user = Depends(require_auth)):
"""获取当前用户信息"""
return {
"ok": True,
"username": current_user.username,
"is_admin": current_user.is_admin
}
class ChangePasswordRequest(BaseModel):
old_password: str
new_password: str
@app.post("/api/auth/change-password")
def change_password(req: ChangePasswordRequest, current_user = Depends(require_auth), db = Depends(get_session)):
"""修改密码"""
if not auth.verify_password(req.old_password, current_user.hashed_password):
raise exceptions.AuthException("原密码错误")
current_user.hashed_password = auth.get_password_hash(req.new_password)
db.commit()
return {"ok": True, "msg": "密码修改成功"}
# ============ 用户管理 ============
class CreateUserRequest(BaseModel):
username: str
password: str
is_admin: bool = False
from models import User as UserModel
@app.get("/api/users")
def list_users(current_user = Depends(require_admin)):
with get_session() as s:
rows = s.execute(select(UserModel).order_by(UserModel.id)).scalars().all()
return {"ok": True, "users": [{"id": r.id, "username": r.username,
"is_admin": r.is_admin, "is_active": r.is_active,
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]}
@app.post("/api/users")
def create_user(req: CreateUserRequest, current_user = Depends(require_admin)):
with get_session() as s:
if s.execute(select(UserModel).where(UserModel.username == req.username)).scalar_one_or_none():
return {"ok": False, "msg": "用户名已存在"}
user = UserModel(username=req.username,
hashed_password=auth.get_password_hash(req.password), is_admin=req.is_admin)
s.add(user); s.commit()
return {"ok": True, "id": user.id}
@app.delete("/api/users/{uid}")
def delete_user(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能删除自己"}
with get_session() as s:
u = s.get(UserModel, uid)
if u: s.delete(u); s.commit()
return {"ok": True}
@app.put("/api/users/{uid}/toggle_admin")
def toggle_admin(uid: int, current_user = Depends(require_admin)):
if current_user.id == uid:
return {"ok": False, "msg": "不能修改自己的权限"}
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.is_admin = not u.is_admin; s.commit()
return {"ok": True, "is_admin": u.is_admin}
@app.put("/api/users/{uid}/reset_password")
def reset_password(uid: int, req: ChangePasswordRequest, current_user = Depends(require_admin)):
with get_session() as s:
u = s.get(UserModel, uid)
if not u: return {"ok": False, "msg": "用户不存在"}
u.hashed_password = auth.get_password_hash(req.new_password); s.commit()
return {"ok": True}
# ============ API ============
@app.get("/api/health")
def health():
return {"ok": True, "akshare": svc.AK_OK}
return {
"ok": True,
"akshare": svc.AK_OK,
"redis": cache.enabled,
"auth": True
}
@app.get("/api/indices")
@@ -106,10 +239,76 @@ def sentiment():
@app.get("/api/treemap")
def treemap(mode: str = Query("sector")):
def treemap(mode: str = Query("sector"), date: str = Query(None)):
if mode == "sector" and date:
# 从数据库读历史板块数据
try:
target = dt.date.fromisoformat(date)
except Exception:
return svc.get_treemap(mode)
with get_session() as s:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == target)
.order_by(SectorDaily.pct.desc())
).scalars().all()
if not rows:
# 找最近有数据的日期
latest = s.execute(select(func.max(SectorDaily.date))).scalar()
if latest:
rows = s.execute(
select(SectorDaily).where(SectorDaily.date == latest)
.order_by(SectorDaily.pct.desc())
).scalars().all()
target = latest
if rows:
items = [{"name": r.name, "value": r.amount or 1, "pct": round(r.pct, 2)} for r in rows]
return {"source": "db", "mode": "sector", "date": target.isoformat(), "items": items}
return svc.get_treemap(mode)
@app.get("/api/treemap/us")
def treemap_us():
return svc.get_us_treemap()
@app.get("/api/treemap/hk")
def treemap_hk():
return svc.get_hk_treemap()
@app.get("/api/treemap/sector_stocks")
def sector_stocks(name: str = Query(...), limit: int = Query(20, ge=5, le=100)):
return svc.get_sector_stocks(name, limit)
@app.get("/api/treemap/all_leaders")
def all_sector_leaders(top_n: int = Query(5, ge=3, le=10), date: str = Query(None)):
from models import SectorLeader
# 优先从数据库读
with get_session() as s:
target = None
if date:
try: target = dt.date.fromisoformat(date)
except Exception: pass
if not target:
target = s.execute(select(func.max(SectorLeader.date))).scalar()
if target:
rows = s.execute(
select(SectorLeader).where(SectorLeader.date == target)
.order_by(SectorLeader.sector, SectorLeader.rank)
).scalars().all()
if rows:
sectors = {}
for r in rows:
sectors.setdefault(r.sector, []).append({
"code": r.code, "name": r.name, "pct": r.pct,
"price": r.price, "amount": r.amount
})
return {"source": "db", "date": target.isoformat(), "sectors": sectors}
# 降级到实时
return svc.get_all_sector_leaders(top_n)
@app.get("/api/fundflow")
def fundflow():
return svc.get_fund_flow()
@@ -151,9 +350,105 @@ def watch_del(code: str):
return {"ok": True, "list": w}
# ============ 自选股分组管理 ============
class CreateGroupRequest(BaseModel):
name: str
description: str = ""
color: str = "blue"
@app.get("/api/watchlist/groups")
def list_groups():
"""获取所有分组"""
return {"ok": True, "groups": wl.get_all_groups()}
@app.post("/api/watchlist/groups")
def create_group(req: CreateGroupRequest):
"""创建新分组"""
return wl.create_group(req.name, req.description, req.color)
class UpdateGroupRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
color: Optional[str] = None
@app.put("/api/watchlist/groups/{group_id}")
def update_group(group_id: int, req: UpdateGroupRequest):
"""更新分组信息"""
return wl.update_group(group_id, req.name, req.description, req.color)
@app.delete("/api/watchlist/groups/{group_id}")
def delete_group(group_id: int):
"""删除分组"""
return wl.delete_group(group_id)
class ReorderGroupsRequest(BaseModel):
group_ids: List[int]
@app.post("/api/watchlist/groups/reorder")
def reorder_groups(req: ReorderGroupsRequest):
"""重新排序分组"""
return wl.reorder_groups(req.group_ids)
@app.get("/api/watchlist/groups/{group_id}/stocks")
def get_group_stocks(group_id: int, with_quotes: bool = Query(True)):
"""获取分组内的股票"""
return wl.get_group_stocks(group_id, with_quotes)
class AddStockRequest(BaseModel):
code: str
note: str = ""
@app.post("/api/watchlist/groups/{group_id}/stocks")
def add_stock_to_group(group_id: int, req: AddStockRequest):
"""添加股票到分组"""
return wl.add_stock_to_group(group_id, req.code, req.note)
@app.delete("/api/watchlist/stocks/{item_id}")
def remove_stock(item_id: int):
"""从分组中移除股票"""
return wl.remove_stock_from_group(item_id)
class MoveStockRequest(BaseModel):
target_group_id: int
@app.post("/api/watchlist/stocks/{item_id}/move")
def move_stock(item_id: int, req: MoveStockRequest):
"""移动股票到另一个分组"""
return wl.move_stock_to_group(item_id, req.target_group_id)
class BatchAddRequest(BaseModel):
codes: List[str]
@app.post("/api/watchlist/groups/{group_id}/stocks/batch")
def batch_add_stocks(group_id: int, req: BatchAddRequest):
"""批量添加股票"""
return wl.batch_add_stocks(group_id, req.codes)
class UpdateNoteRequest(BaseModel):
note: str
@app.put("/api/watchlist/stocks/{item_id}/note")
def update_stock_note(item_id: int, req: UpdateNoteRequest):
"""更新股票备注"""
return wl.update_stock_note(item_id, req.note)
class ReorderStocksRequest(BaseModel):
item_ids: List[int]
@app.post("/api/watchlist/stocks/reorder")
def reorder_stocks(req: ReorderStocksRequest):
"""重新排序股票"""
return wl.reorder_stocks(req.item_ids)
@app.get("/api/watchlist/search")
def search_stocks(keyword: str = Query(..., min_length=1)):
"""跨分组搜索股票"""
return {"ok": True, "results": wl.search_stocks_across_groups(keyword)}
# ============ 数据中台 ============
@app.get("/api/admin/status")
def admin_status():
def admin_status(current_user = Depends(require_admin)):
counts, last_dates = {}, {}
with get_session() as s:
for label, model in [("securities", Security), ("quotes_daily", DailyQuote),
@@ -175,14 +470,14 @@ def admin_status():
@app.post("/api/admin/ingest")
def admin_ingest():
def admin_ingest(current_user = Depends(require_admin)):
if scheduler.is_running():
return {"started": False, "msg": "已有入库任务在执行"}
return scheduler.trigger_async()
@app.post("/api/admin/ingest_all")
def admin_ingest_all():
def admin_ingest_all(current_user = Depends(require_admin)):
return scheduler.trigger_all_async()
@@ -488,6 +783,16 @@ def ai_today():
return ai.today_strategy()
@app.get("/api/ai/trend_analysis")
def ai_trend_analysis(
symbol: str = Query(...),
date: str = Query(""),
period: str = Query("daily")
):
"""走势分析右键K线条形时调用分析暴涨/暴跌原因"""
return ai.trend_analysis(symbol, date, period)
# ============ 可回溯:信号历史胜率 + 实测准确率 ============
@app.get("/api/ai/signal_stats")
def ai_signal_stats(horizon: int = Query(5, ge=1, le=20)):
@@ -582,6 +887,144 @@ def portfolio_equity():
return pf.equity_curve()
# ============ 持仓成本可视化增强 ============
@app.get("/api/portfolio/cost_line/{code}")
def get_cost_line(code: str):
"""获取个股持仓成本线用于K线图标注"""
return pc.get_position_cost_lines(code)
@app.get("/api/portfolio/cost_distribution")
def get_cost_distribution():
"""获取持仓成本分布(盈亏区间图)"""
return pc.get_position_cost_distribution()
class EstimateCostRequest(BaseModel):
code: str
price: float
qty: int
side: str = "buy"
@app.post("/api/portfolio/estimate_cost")
def estimate_cost(req: EstimateCostRequest):
"""估算交易成本(下单前预估)"""
return pc.estimate_trade_cost(req.code, req.price, req.qty, req.side)
@app.get("/api/portfolio/cost_breakdown/{code}")
def get_cost_breakdown(code: str):
"""获取持仓的详细成本拆解"""
return pc.get_cost_breakdown_for_position(code)
# ============ 交易日历与关键事件 ============
@app.get("/api/calendar/events")
def calendar_events(days: int = Query(30, ge=7, le=90)):
"""获取所有即将到来的关键事件(综合视图)"""
return cal.get_all_upcoming_events(days)
@app.get("/api/calendar/dividends")
def calendar_dividends(days: int = Query(30, ge=7, le=90)):
"""除权除息日历(持仓股优先)"""
return cal.get_upcoming_dividends(days)
@app.get("/api/calendar/unlock")
def calendar_unlock(days: int = Query(90, ge=7, le=180)):
"""限售解禁日历"""
return cal.get_unlock_calendar(days)
@app.get("/api/calendar/earnings")
def calendar_earnings(days: int = Query(30, ge=7, le=60), holding_only: bool = Query(False)):
"""财报披露日历"""
return cal.get_earnings_calendar(days, holding_only)
@app.post("/api/calendar/check_alerts")
def calendar_check_alerts(current_user = Depends(require_admin)):
"""手动触发日历事件预警推送"""
return cal.check_and_push_calendar_alerts()
# ============ 数据修正与回填增强 ============
class UpdateQuoteRequest(BaseModel):
open: Optional[float] = None
high: Optional[float] = None
low: Optional[float] = None
close: Optional[float] = None
volume: Optional[int] = None
amount: Optional[float] = None
@app.delete("/api/data/quote/{code}/{date}")
def delete_quote(code: str, date: str, current_user = Depends(require_admin)):
"""删除指定股票指定日期的日线"""
return dm.delete_quote(code, date)
@app.put("/api/data/quote/{code}/{date}")
def update_quote(code: str, date: str, req: UpdateQuoteRequest,
current_user = Depends(require_admin)):
"""修正指定日线数据"""
return dm.update_quote(code, date, req.model_dump(exclude_none=True))
class DeleteRangeRequest(BaseModel):
start: str
end: str
@app.delete("/api/data/quotes/{code}/range")
def delete_quotes_range(code: str, req: DeleteRangeRequest,
current_user = Depends(require_admin)):
"""删除指定股票日期范围内的日线数据"""
return dm.delete_quotes_range(code, req.start, req.end)
@app.post("/api/data/refetch/{code}")
def refetch_quote(code: str, days: int = Query(60, ge=5, le=500),
current_user = Depends(require_admin)):
"""重新抓取指定股票日线(覆盖更新)"""
return dm.refetch_quote(code, days)
@app.get("/api/data/integrity")
def check_integrity(days: int = Query(30, ge=7, le=90),
current_user = Depends(require_admin)):
"""数据完整性检查"""
return dm.check_data_integrity(days=days)
@app.post("/api/data/auto_fix")
def auto_fix_missing(limit: int = Query(50, ge=10, le=200),
current_user = Depends(require_admin)):
"""自动补齐缺失数据"""
t = __import__("threading").Thread(
target=dm.auto_fix_missing, kwargs={"limit": limit}, daemon=True
)
t.start()
return {"ok": True, "msg": "已启动自动修复任务,请在数据中台查看进度"}
@app.get("/api/data/refill_progress")
def refill_progress(task_id: str = Query("default")):
"""获取回填进度"""
return dm.get_refill_progress(task_id)
@app.post("/api/data/refill_resume")
def refill_resume(days: int = Query(250, ge=30, le=1000),
task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""带断点续传的全市场回填(后台执行)"""
import threading
t = threading.Thread(
target=dm.start_refill_with_resume,
kwargs={"days": days, "task_id": task_id},
daemon=True
)
t.start()
return {"ok": True, "msg": f"已启动断点续传回填,天数={days}任务ID={task_id}"}
@app.delete("/api/data/refill_progress")
def clear_refill_progress(task_id: str = Query("default"),
current_user = Depends(require_admin)):
"""清除回填进度(从头开始)"""
return dm.clear_refill_progress(task_id)
@app.get("/api/data/quality_report")
def data_quality_report(current_user = Depends(require_admin)):
"""数据质量报告"""
return dm.get_data_quality_report()
@app.get("/api/portfolio/attribution")
def portfolio_attribution():
"""持仓归因分析"""
@@ -761,6 +1204,22 @@ def limit_squad(days: int = Query(30, ge=10, le=90), min_limits: int = Query(5,
"""涨停敢死队排行"""
return limit_up.get_limit_squad_rankings(days, min_limits)
@app.get("/api/limit/consecutive_calendar")
def consecutive_calendar(days: int = Query(60, ge=20, le=120)):
"""连板日历:记录连板历史,分析几进几出规律"""
return limit_up.get_consecutive_calendar(days)
@app.get("/api/limit/post_break")
def post_break_performance(days: int = Query(90, ge=30, le=180)):
"""炸板后 1/3/5 日走势统计"""
return limit_up.analyze_post_break_performance(days)
@app.get("/api/limit/reasons")
def limit_reasons(date: Optional[str] = None):
"""涨停原因分类(情绪/题材/业绩/政策等)"""
d = dt.date.fromisoformat(date) if date else None
return limit_up.classify_limit_reasons(d)
# ============ 推送通知 ============
@app.get("/api/notify/status")
@@ -1142,6 +1601,51 @@ def delete_selector_alert(aid: int):
return {"ok": True}
# ============ 模拟盘 ============
class PaperAccountIn(BaseModel):
name: str
initial_cash: float = 1_000_000.0
@app.get("/api/paper/accounts")
def paper_list_accounts():
return {"ok": True, "accounts": paper.list_accounts()}
@app.post("/api/paper/accounts")
def paper_create_account(req: PaperAccountIn):
return paper.create_account(req.name, req.initial_cash)
@app.post("/api/paper/accounts/{account_id}/reset")
def paper_reset_account(account_id: int, initial_cash: Optional[float] = None):
return paper.reset_account(account_id, initial_cash)
class PaperOrderIn(BaseModel):
code: str
side: str # buy / sell
qty: int
price: Optional[float] = None
reason: str = ""
@app.post("/api/paper/accounts/{account_id}/order")
def paper_place_order(account_id: int, req: PaperOrderIn):
return paper.place_order(account_id, req.code, req.side, req.qty, req.price, req.reason)
@app.get("/api/paper/accounts/{account_id}/portfolio")
def paper_get_portfolio(account_id: int):
return paper.get_portfolio(account_id)
@app.get("/api/paper/accounts/{account_id}/trades")
def paper_get_trades(account_id: int, limit: int = Query(100, le=500)):
return {"ok": True, "trades": paper.get_trades(account_id, limit)}
# ============ 静态前端 ============
FRONTEND_DIR = os.path.join(os.path.dirname(BASE_DIR), "prototype")
if os.path.isdir(FRONTEND_DIR):

View File

@@ -1,4 +1,4 @@
"""数据中台 ORM 模型SQLAlchemy 2.0)。"""
"""数据中台 ORM 模型SQLAlchemy 2.0)。"""
from __future__ import annotations
import datetime as dt
@@ -64,6 +64,21 @@ class SectorDaily(Base):
leader: Mapped[str] = mapped_column(String(40), default="")
class SectorLeader(Base):
"""板块每日龙头股前5按成交额"""
__tablename__ = "sector_leaders"
__table_args__ = (UniqueConstraint("date", "sector", "code", name="uq_sector_leader"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
date: Mapped[dt.date] = mapped_column(Date, index=True)
sector: Mapped[str] = mapped_column(String(40), index=True)
code: Mapped[str] = mapped_column(String(12))
name: Mapped[str] = mapped_column(String(40))
pct: Mapped[float] = mapped_column(Float, default=0.0)
price: Mapped[float] = mapped_column(Float, default=0.0)
amount: Mapped[float] = mapped_column(Float, default=0.0)
rank: Mapped[int] = mapped_column(Integer, default=0)
class FundFlowDaily(Base):
"""行业资金流每日快照。"""
__tablename__ = "fund_flow_daily"
@@ -349,3 +364,68 @@ class IntradayEvent(Base):
description: Mapped[str] = mapped_column(String(200), default="")
detected_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now(), index=True)
notified: Mapped[bool] = mapped_column(default=False)
class PaperAccount(Base):
"""模拟盘账户。"""
__tablename__ = "paper_accounts"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), default="默认模拟盘")
initial_cash: Mapped[float] = mapped_column(Float, default=1_000_000.0)
cash: Mapped[float] = mapped_column(Float, default=1_000_000.0)
is_active: Mapped[bool] = mapped_column(default=True)
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
class PaperTrade(Base):
"""模拟盘交易记录。"""
__tablename__ = "paper_trades"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
account_id: Mapped[int] = mapped_column(Integer, index=True, default=1)
date: Mapped[dt.date] = mapped_column(Date, index=True)
code: Mapped[str] = mapped_column(String(12), index=True)
name: Mapped[str] = mapped_column(String(40), default="")
side: Mapped[str] = mapped_column(String(4)) # buy / sell
price: Mapped[float] = mapped_column(Float)
qty: Mapped[int] = mapped_column(Integer)
fee: Mapped[float] = mapped_column(Float, default=0.0)
cash_before: Mapped[float] = mapped_column(Float, default=0.0)
cash_after: Mapped[float] = mapped_column(Float, default=0.0)
reason: Mapped[str] = mapped_column(String(60), default="")
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
class User(Base):
"""用户表(用于鉴权)。"""
__tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
hashed_password: Mapped[str] = mapped_column(String(100))
is_admin: Mapped[bool] = mapped_column(default=False)
is_active: Mapped[bool] = mapped_column(default=True)
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
class WatchlistGroup(Base):
"""自选股分组。"""
__tablename__ = "watchlist_groups"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50))
description: Mapped[str] = mapped_column(String(200), default="")
color: Mapped[str] = mapped_column(String(20), default="blue") # 分组颜色标识
sort_order: Mapped[int] = mapped_column(Integer, default=0)
is_default: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())
class WatchlistItem(Base):
"""自选股项目。"""
__tablename__ = "watchlist_items"
__table_args__ = (UniqueConstraint("group_id", "code", name="uq_watchlist_group_code"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
group_id: Mapped[int] = mapped_column(Integer, index=True)
code: Mapped[str] = mapped_column(String(12), index=True)
name: Mapped[str] = mapped_column(String(40), default="")
sort_order: Mapped[int] = mapped_column(Integer, default=0)
note: Mapped[str] = mapped_column(String(200), default="") # 个股备注
added_at: Mapped[dt.datetime] = mapped_column(DateTime, server_default=func.now())

222
backend/paper_trading.py Normal file
View File

@@ -0,0 +1,222 @@
"""模拟盘核心逻辑。
- 多账户支持(默认账户 id=1
- 买卖按实时价(或收盘价)撮合,自动扣减/增加现金
- 持仓计算使用移动加权平均成本法
"""
from __future__ import annotations
import datetime as dt
from collections import defaultdict
from sqlalchemy import select
from db import get_session
from models import PaperAccount, PaperTrade, Security, StockMetric, DailyQuote
DEFAULT_FEE_RATE = 0.0003
def _get_price(code: str) -> float | None:
with get_session() as s:
m = s.execute(select(StockMetric.close).where(StockMetric.code == code)).scalar_one_or_none()
if m:
return float(m)
row = s.execute(
select(DailyQuote.close).where(DailyQuote.code == code)
.order_by(DailyQuote.date.desc()).limit(1)
).scalar_one_or_none()
return float(row) if row else None
# ── 账户管理 ──────────────────────────────────────────────
def ensure_default_account():
"""确保默认账户id=1存在启动时调用。"""
with get_session() as s:
if not s.get(PaperAccount, 1):
s.add(PaperAccount(name="默认模拟盘", initial_cash=1_000_000.0, cash=1_000_000.0))
s.commit()
def list_accounts() -> list[dict]:
with get_session() as s:
rows = s.execute(select(PaperAccount).order_by(PaperAccount.id)).scalars().all()
return [{"id": r.id, "name": r.name, "initial_cash": r.initial_cash,
"cash": round(r.cash, 2), "is_active": r.is_active,
"created_at": r.created_at.strftime("%Y-%m-%d")} for r in rows]
def create_account(name: str, initial_cash: float) -> dict:
with get_session() as s:
acc = PaperAccount(name=name, initial_cash=initial_cash, cash=initial_cash)
s.add(acc)
s.commit()
return {"ok": True, "id": acc.id}
def reset_account(account_id: int, initial_cash: float | None = None) -> dict:
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
if initial_cash is not None:
acc.initial_cash = initial_cash
acc.cash = acc.initial_cash
for t in s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
).scalars():
s.delete(t)
s.commit()
return {"ok": True, "msg": "账户已重置"}
# ── 持仓计算(内部)────────────────────────────────────────
def _calc_holdings_in_session(account_id: int, s) -> list[dict]:
trades = s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
.order_by(PaperTrade.date, PaperTrade.id)
).scalars().all()
pos: dict = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
for t in trades:
p = pos[t.code]
p["name"] = t.name or p["name"]
if t.side == "buy":
p["cost"] += t.price * t.qty + t.fee
p["qty"] += t.qty
else:
if p["qty"] > 0:
avg = p["cost"] / p["qty"]
qty = min(t.qty, p["qty"])
p["cost"] -= avg * qty
p["qty"] -= qty
return [{"code": c, "name": v["name"], "qty": v["qty"], "cost": v["cost"]}
for c, v in pos.items() if v["qty"] > 0]
# ── 下单 ──────────────────────────────────────────────────
def place_order(account_id: int, code: str, side: str, qty: int,
price: float | None = None, reason: str = "") -> dict:
if qty <= 0:
return {"ok": False, "msg": "数量必须大于 0"}
if side not in ("buy", "sell"):
return {"ok": False, "msg": "side 只能是 buy 或 sell"}
exec_price = price or _get_price(code)
if not exec_price:
return {"ok": False, "msg": f"无法获取 {code} 的价格,请手动传入 price"}
fee = round(exec_price * qty * DEFAULT_FEE_RATE, 2)
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
sec = s.get(Security, code)
name = sec.name if sec else code
if side == "buy":
cost = exec_price * qty + fee
if acc.cash < cost:
return {"ok": False, "msg": f"现金不足,需 {cost:.2f},余 {acc.cash:.2f}"}
cash_before = acc.cash
acc.cash -= cost
else:
holdings = _calc_holdings_in_session(account_id, s)
pos = next((h for h in holdings if h["code"] == code), None)
avail = pos["qty"] if pos else 0
if avail < qty:
return {"ok": False, "msg": f"持仓不足,持有 {avail} 股,尝试卖出 {qty}"}
cash_before = acc.cash
acc.cash += exec_price * qty - fee
trade = PaperTrade(
account_id=account_id,
date=dt.date.today(),
code=code, name=name, side=side,
price=exec_price, qty=qty, fee=fee,
cash_before=cash_before, cash_after=acc.cash,
reason=reason,
)
s.add(trade)
s.commit()
return {"ok": True, "id": trade.id, "price": exec_price,
"fee": fee, "cash_after": round(acc.cash, 2)}
# ── 查询接口 ──────────────────────────────────────────────
def get_portfolio(account_id: int) -> dict:
with get_session() as s:
acc = s.get(PaperAccount, account_id)
if not acc:
return {"ok": False, "msg": "账户不存在"}
cash = acc.cash
initial = acc.initial_cash
holdings_raw = _calc_holdings_in_session(account_id, s)
codes = [h["code"] for h in holdings_raw]
px: dict[str, float] = {}
if codes:
with get_session() as s:
for m in s.execute(
select(StockMetric).where(StockMetric.code.in_(codes))
).scalars():
px[m.code] = m.close
for c in [c for c in codes if c not in px]:
row = s.execute(
select(DailyQuote.close).where(DailyQuote.code == c)
.order_by(DailyQuote.date.desc()).limit(1)
).scalar_one_or_none()
if row:
px[c] = float(row)
holdings, mkt_val = [], 0.0
for h in holdings_raw:
avg = h["cost"] / h["qty"] if h["qty"] else 0.0
cur = px.get(h["code"], avg)
mv = cur * h["qty"]
unreal = (cur - avg) * h["qty"]
mkt_val += mv
holdings.append({
"code": h["code"], "name": h["name"], "qty": h["qty"],
"avg_cost": round(avg, 3), "cur": round(cur, 3),
"market_value": round(mv, 2),
"unrealized": round(unreal, 2),
"unrealized_pct": round((cur / avg - 1) * 100, 2) if avg else 0.0,
})
holdings.sort(key=lambda x: x["unrealized"], reverse=True)
total_assets = cash + mkt_val
total_pnl = total_assets - initial
return {
"ok": True,
"account_id": account_id,
"summary": {
"initial_cash": round(initial, 2),
"cash": round(cash, 2),
"market_value": round(mkt_val, 2),
"total_assets": round(total_assets, 2),
"total_pnl": round(total_pnl, 2),
"total_pnl_pct": round(total_pnl / initial * 100, 2) if initial else 0.0,
"positions": len(holdings),
},
"holdings": holdings,
}
def get_trades(account_id: int, limit: int = 100) -> list[dict]:
with get_session() as s:
rows = s.execute(
select(PaperTrade).where(PaperTrade.account_id == account_id)
.order_by(PaperTrade.id.desc()).limit(limit)
).scalars().all()
return [{
"id": t.id, "date": t.date.isoformat(), "code": t.code, "name": t.name,
"side": t.side, "price": t.price, "qty": t.qty, "fee": t.fee,
"cash_before": round(t.cash_before, 2), "cash_after": round(t.cash_after, 2),
"reason": t.reason,
} for t in rows]

347
backend/position_cost.py Normal file
View File

@@ -0,0 +1,347 @@
"""持仓成本可视化增强"""
import datetime as dt
from typing import Dict, List, Optional
from collections import defaultdict
from sqlalchemy import select, func
from db import get_session
from models import Trade, DailyQuote, StockMetric
# A股交易成本配置
COST_CONFIG = {
"stamp_tax": 0.001, # 印花税 0.1%(仅卖出)
"commission_rate": 0.0003, # 佣金费率 0.03%
"commission_min": 5.0, # 最低佣金 5元
"transfer_fee": 0.00001, # 过户费 0.001%(沪市)
}
def calculate_trade_cost(price: float, qty: int, side: str, is_sh: bool = True) -> Dict:
"""精确计算交易成本
Args:
price: 成交价格
qty: 成交数量
side: buy/sell
is_sh: 是否沪市(影响过户费)
Returns:
成本明细字典
"""
amount = price * qty
# 佣金(买卖都有)
commission = max(amount * COST_CONFIG["commission_rate"], COST_CONFIG["commission_min"])
# 印花税(仅卖出)
stamp_tax = amount * COST_CONFIG["stamp_tax"] if side == "sell" else 0.0
# 过户费(沪市买卖都有,深市无)
transfer_fee = amount * COST_CONFIG["transfer_fee"] if is_sh else 0.0
total_cost = commission + stamp_tax + transfer_fee
return {
"amount": round(amount, 2),
"commission": round(commission, 2),
"stamp_tax": round(stamp_tax, 2),
"transfer_fee": round(transfer_fee, 2),
"total_cost": round(total_cost, 2),
"cost_rate": round(total_cost / amount * 100, 4) if amount > 0 else 0.0
}
def get_position_cost_lines(code: str) -> Dict:
"""获取个股的持仓成本线数据用于K线图标注
Returns:
{
"code": "600519",
"name": "贵州茅台",
"current_position": {
"qty": 100,
"avg_cost": 1680.5,
"total_cost": 168050.0,
"trades_count": 3
},
"cost_history": [
{"date": "2024-01-15", "cost": 1650.0, "qty": 100, "action": "买入"},
{"date": "2024-02-10", "cost": 1680.5, "qty": 100, "action": "补仓"}
]
}
"""
with get_session() as s:
trades = s.execute(
select(Trade).where(Trade.code == code)
.order_by(Trade.date, Trade.id)
).scalars().all()
if not trades:
return {"ok": False, "msg": "该股票无交易记录"}
# 计算持仓成本变化
qty = 0
cost = 0.0
cost_history = []
for t in trades:
is_sh = t.code.startswith("6")
if t.side == "buy":
# 买入:加权平均成本
old_qty = qty
old_cost = cost
qty += t.qty
cost += t.price * t.qty + t.fee
avg_cost = cost / qty if qty > 0 else 0
action = "补仓" if old_qty > 0 else "买入"
cost_history.append({
"date": t.date.isoformat(),
"cost": round(avg_cost, 2),
"qty": qty,
"action": action,
"trade_price": t.price,
"trade_qty": t.qty
})
else: # sell
if qty <= 0:
continue
avg_cost = cost / qty
sell_qty = min(t.qty, qty)
# 卖出:减少持仓
cost -= avg_cost * sell_qty
qty -= sell_qty
action = "清仓" if qty == 0 else "减仓"
cost_history.append({
"date": t.date.isoformat(),
"cost": round(cost / qty, 2) if qty > 0 else 0,
"qty": qty,
"action": action,
"trade_price": t.price,
"trade_qty": sell_qty,
"pnl": round((t.price - avg_cost) * sell_qty - t.fee, 2)
})
# 当前持仓
current_position = None
if qty > 0:
avg_cost = cost / qty
# 获取当前价格
metric = s.execute(
select(StockMetric).where(StockMetric.code == code)
).scalar_one_or_none()
current_price = metric.close if metric else avg_cost
current_position = {
"qty": qty,
"avg_cost": round(avg_cost, 2),
"total_cost": round(cost, 2),
"current_price": round(current_price, 2),
"market_value": round(current_price * qty, 2),
"unrealized_pnl": round((current_price - avg_cost) * qty, 2),
"unrealized_pct": round((current_price / avg_cost - 1) * 100, 2) if avg_cost > 0 else 0,
"trades_count": len([t for t in trades if t.side == "buy"])
}
return {
"ok": True,
"code": code,
"name": trades[0].name,
"current_position": current_position,
"cost_history": cost_history
}
def get_position_cost_distribution() -> Dict:
"""获取所有持仓的成本分布(盈亏区间图)
Returns:
{
"profitable": [...], # 盈利持仓
"unprofitable": [...], # 亏损持仓
"breakeven": [...] # 持平持仓
}
"""
with get_session() as s:
trades = s.execute(
select(Trade).order_by(Trade.date, Trade.id)
).scalars().all()
# 计算当前持仓
pos = defaultdict(lambda: {"qty": 0, "cost": 0.0, "name": ""})
for t in trades:
p = pos[t.code]
p["name"] = t.name or p["name"]
if t.side == "buy":
p["cost"] += t.price * t.qty + t.fee
p["qty"] += t.qty
else:
if p["qty"] > 0:
avg = p["cost"] / p["qty"]
qty = min(t.qty, p["qty"])
p["cost"] -= avg * qty
p["qty"] -= qty
# 获取当前价格
codes = [c for c, v in pos.items() if v["qty"] > 0]
if not codes:
return {"ok": True, "profitable": [], "unprofitable": [], "breakeven": []}
metrics = s.execute(
select(StockMetric).where(StockMetric.code.in_(codes))
).scalars().all()
price_map = {m.code: m.close for m in metrics}
# 分类统计
profitable = []
unprofitable = []
breakeven = []
for code, p in pos.items():
if p["qty"] <= 0:
continue
avg_cost = p["cost"] / p["qty"]
current_price = price_map.get(code, avg_cost)
unrealized = (current_price - avg_cost) * p["qty"]
unrealized_pct = (current_price / avg_cost - 1) * 100 if avg_cost > 0 else 0
item = {
"code": code,
"name": p["name"],
"qty": p["qty"],
"avg_cost": round(avg_cost, 2),
"current_price": round(current_price, 2),
"market_value": round(current_price * p["qty"], 2),
"cost_value": round(p["cost"], 2),
"unrealized": round(unrealized, 2),
"unrealized_pct": round(unrealized_pct, 2)
}
if unrealized_pct > 0.5:
profitable.append(item)
elif unrealized_pct < -0.5:
unprofitable.append(item)
else:
breakeven.append(item)
# 排序
profitable.sort(key=lambda x: x["unrealized"], reverse=True)
unprofitable.sort(key=lambda x: x["unrealized"])
return {
"ok": True,
"profitable": profitable,
"unprofitable": unprofitable,
"breakeven": breakeven,
"summary": {
"total_positions": len(codes),
"profitable_count": len(profitable),
"unprofitable_count": len(unprofitable),
"breakeven_count": len(breakeven),
"win_rate": round(len(profitable) / len(codes) * 100, 1) if codes else 0
}
}
def estimate_trade_cost(code: str, price: float, qty: int, side: str) -> Dict:
"""估算交易成本(下单前预估)
Args:
code: 股票代码
price: 预计成交价
qty: 交易数量
side: buy/sell
Returns:
成本明细和净值
"""
is_sh = code.startswith("6")
cost_detail = calculate_trade_cost(price, qty, side, is_sh)
if side == "buy":
net_amount = cost_detail["amount"] + cost_detail["total_cost"]
msg = f"买入需支付: {round(net_amount, 2)} 元(含交易成本 {cost_detail['total_cost']} 元)"
else:
net_amount = cost_detail["amount"] - cost_detail["total_cost"]
msg = f"卖出可获得: {round(net_amount, 2)} 元(扣除交易成本 {cost_detail['total_cost']} 元)"
return {
"ok": True,
"code": code,
"price": price,
"qty": qty,
"side": side,
"cost_detail": cost_detail,
"net_amount": round(net_amount, 2),
"message": msg
}
def get_cost_breakdown_for_position(code: str) -> Dict:
"""获取持仓的详细成本拆解
Returns:
{
"total_cost": 168500.0,
"purchase_amount": 168050.0, # 实际买入金额
"commission": 350.0, # 累计佣金
"stamp_tax": 0.0, # 累计印花税(买入无)
"transfer_fee": 100.0, # 累计过户费
"trades": [...] # 每笔交易明细
}
"""
with get_session() as s:
trades = s.execute(
select(Trade).where(Trade.code == code, Trade.side == "buy")
.order_by(Trade.date)
).scalars().all()
if not trades:
return {"ok": False, "msg": "该股票无买入记录"}
is_sh = code.startswith("6")
total_purchase = 0.0
total_commission = 0.0
total_stamp = 0.0
total_transfer = 0.0
trade_details = []
for t in trades:
cost = calculate_trade_cost(t.price, t.qty, "buy", is_sh)
total_purchase += cost["amount"]
total_commission += cost["commission"]
total_stamp += cost["stamp_tax"]
total_transfer += cost["transfer_fee"]
trade_details.append({
"date": t.date.isoformat(),
"price": t.price,
"qty": t.qty,
"amount": cost["amount"],
"cost_detail": cost
})
total_cost = total_purchase + total_commission + total_stamp + total_transfer
return {
"ok": True,
"code": code,
"name": trades[0].name,
"total_cost": round(total_cost, 2),
"purchase_amount": round(total_purchase, 2),
"commission": round(total_commission, 2),
"stamp_tax": round(total_stamp, 2),
"transfer_fee": round(total_transfer, 2),
"cost_rate": round((total_cost - total_purchase) / total_purchase * 100, 4) if total_purchase > 0 else 0,
"trades": trade_details
}

88
backend/redis_cache.py Normal file
View File

@@ -0,0 +1,88 @@
"""Redis 缓存层,替代内存缓存,支持持久化和跨进程共享。"""
import json
import redis
from typing import Any, Optional
from config import REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD
class RedisCache:
def __init__(self):
self.client: Optional[redis.Redis] = None
self.enabled = False
self._connect()
def _connect(self):
"""连接 Redis失败时降级为禁用状态"""
try:
self.client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD if REDIS_PASSWORD else None,
decode_responses=True,
socket_connect_timeout=2,
socket_timeout=2
)
self.client.ping()
self.enabled = True
print(f"✓ Redis 已连接: {REDIS_HOST}:{REDIS_PORT}")
except Exception as e:
self.enabled = False
print(f"✗ Redis 连接失败,缓存已禁用: {e}")
def get(self, key: str) -> Optional[Any]:
"""获取缓存,自动反序列化 JSON"""
if not self.enabled:
return None
try:
value = self.client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
print(f"Redis get error: {e}")
return None
def set(self, key: str, value: Any, expire: int = 3600):
"""设置缓存,自动序列化为 JSON
Args:
key: 缓存键
value: 缓存值可序列化为JSON的对象
expire: 过期时间默认1小时
"""
if not self.enabled:
return False
try:
serialized = json.dumps(value, ensure_ascii=False, default=str)
self.client.setex(key, expire, serialized)
return True
except Exception as e:
print(f"Redis set error: {e}")
return False
def delete(self, key: str):
"""删除缓存"""
if not self.enabled:
return False
try:
self.client.delete(key)
return True
except Exception as e:
print(f"Redis delete error: {e}")
return False
def clear_pattern(self, pattern: str):
"""批量删除匹配模式的键"""
if not self.enabled:
return 0
try:
keys = self.client.keys(pattern)
if keys:
return self.client.delete(*keys)
return 0
except Exception as e:
print(f"Redis clear_pattern error: {e}")
return 0
# 全局单例
cache = RedisCache()

View File

@@ -8,3 +8,7 @@ APScheduler>=3.10.4
psycopg2-binary>=2.9.9
jieba>=0.42.1
numpy>=1.26.0
redis>=5.0.0
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
python-multipart>=0.0.9

View File

@@ -13,6 +13,7 @@ import alerts
import report
import signals
import intraday_radar
import trade_calendar
_scheduler: BackgroundScheduler | None = None
_lock = threading.Lock()
@@ -134,11 +135,16 @@ def start_scheduler():
_job_signal_stats, CronTrigger(day_of_week="sat", hour=9, minute=0),
id="signal_stats", replace_existing=True, misfire_grace_time=7200,
)
# 盘中异动扫描(交易时间每分钟)
# 盘中异动扫描(交易时间每分钟)
_scheduler.add_job(
_safe_scan_intraday, IntervalTrigger(seconds=60),
id="intraday_scan", replace_existing=True, max_instances=1,
)
# 每日早盘前推送日历事件提醒(持仓股除权、解禁、财报等)
_scheduler.add_job(
_job_calendar_alerts, CronTrigger(day_of_week="mon-fri", hour=8, minute=30),
id="calendar_alerts", replace_existing=True, misfire_grace_time=3600,
)
_scheduler.start()
return _scheduler
@@ -154,7 +160,13 @@ def _safe_scan_intraday():
try:
result = intraday_radar.scan_all()
if result.get("count", 0) > 0:
# 有新异动时自动推送
intraday_radar.notify_events()
except Exception as e:
print("[intraday] scan error:", repr(e)[:120])
def _job_calendar_alerts():
try:
trade_calendar.check_and_push_calendar_alerts()
except Exception as e:
print("[calendar] alert error:", repr(e)[:120])

View File

@@ -0,0 +1,179 @@
"""测试三大核心功能的脚本"""
import requests
import time
import sys
BASE_URL = "http://localhost:8000"
def test_health():
"""测试健康检查接口"""
print("\n=== 1. 测试健康检查 ===")
try:
resp = requests.get(f"{BASE_URL}/api/health")
data = resp.json()
print(f"状态码: {resp.status_code}")
print(f"响应: {data}")
print(f"✓ AkShare: {'可用' if data.get('akshare') else '不可用'}")
print(f"✓ Redis: {'已连接' if data.get('redis') else '未连接(将使用内存缓存)'}")
print(f"✓ 鉴权: {'已启用' if data.get('auth') else '未启用'}")
return True
except Exception as e:
print(f"✗ 失败: {e}")
return False
def test_cache_performance():
"""测试 Redis 缓存性能"""
print("\n=== 2. 测试 Redis 缓存性能 ===")
try:
# 第一次请求(缓存未命中)
start = time.time()
resp1 = requests.get(f"{BASE_URL}/api/indices")
time1 = time.time() - start
print(f"第一次请求耗时: {time1:.3f}")
# 第二次请求(应该命中缓存)
start = time.time()
resp2 = requests.get(f"{BASE_URL}/api/indices")
time2 = time.time() - start
print(f"第二次请求耗时: {time2:.3f}")
speedup = time1 / time2 if time2 > 0 else 1
print(f"性能提升: {speedup:.1f}x")
if speedup > 2:
print("✓ Redis 缓存生效")
else:
print("⚠ 缓存可能未生效或使用内存缓存")
return True
except Exception as e:
print(f"✗ 失败: {e}")
return False
def test_auth_login():
"""测试登录功能"""
print("\n=== 3. 测试认证系统 - 登录 ===")
try:
resp = requests.post(
f"{BASE_URL}/api/auth/login",
json={"username": "admin", "password": "admin123"}
)
print(f"状态码: {resp.status_code}")
data = resp.json()
if resp.status_code == 200 and data.get("access_token"):
print(f"✓ 登录成功")
print(f" 用户名: {data.get('username')}")
print(f" 管理员: {data.get('is_admin')}")
print(f" Token: {data.get('access_token')[:50]}...")
return data.get("access_token")
else:
print(f"✗ 登录失败: {data}")
return None
except Exception as e:
print(f"✗ 失败: {e}")
return None
def test_auth_protected(token):
"""测试受保护的接口"""
print("\n=== 4. 测试认证系统 - 受保护接口 ===")
# 测试无 Token 访问
print("\n4.1 无 Token 访问管理接口:")
try:
resp = requests.get(f"{BASE_URL}/api/admin/status")
print(f"状态码: {resp.status_code}")
if resp.status_code == 401:
print("✓ 正确拦截未认证请求")
else:
print(f"⚠ 预期 401实际 {resp.status_code}")
except Exception as e:
print(f"✗ 失败: {e}")
# 测试有 Token 访问
print("\n4.2 使用 Token 访问管理接口:")
try:
resp = requests.get(
f"{BASE_URL}/api/admin/status",
headers={"Authorization": f"Bearer {token}"}
)
print(f"状态码: {resp.status_code}")
if resp.status_code == 200:
data = resp.json()
print("✓ Token 认证成功")
print(f" 数据库记录数: {data.get('counts', {})}")
return True
else:
print(f"✗ 认证失败: {resp.json()}")
return False
except Exception as e:
print(f"✗ 失败: {e}")
return False
def test_exception_handling():
"""测试异常处理"""
print("\n=== 5. 测试统一异常处理 ===")
# 测试参数验证错误
print("\n5.1 测试参数验证错误:")
try:
resp = requests.get(f"{BASE_URL}/api/kline?days=invalid")
print(f"状态码: {resp.status_code}")
data = resp.json()
if resp.status_code == 422:
print("✓ 参数验证错误处理正确")
print(f" 错误信息: {data.get('error')}")
else:
print(f"⚠ 预期 422实际 {resp.status_code}")
except Exception as e:
print(f"✗ 失败: {e}")
# 测试业务逻辑错误
print("\n5.2 测试业务逻辑错误:")
try:
resp = requests.get(f"{BASE_URL}/api/backtest?symbol=600519&fast=20&slow=10")
print(f"状态码: {resp.status_code}")
data = resp.json()
if not data.get('ok'):
print("✓ 业务错误处理正确")
print(f" 错误信息: {data.get('msg')}")
except Exception as e:
print(f"✗ 失败: {e}")
return True
def main():
print("="*50)
print("三大核心功能测试")
print("="*50)
# 检查服务是否运行
try:
requests.get(f"{BASE_URL}/api/health", timeout=2)
except:
print(f"\n✗ 无法连接到服务: {BASE_URL}")
print("请确保服务已启动: python main.py")
sys.exit(1)
# 运行测试
test_health()
test_cache_performance()
token = test_auth_login()
if token:
test_auth_protected(token)
else:
print("\n⚠ 跳过受保护接口测试(登录失败)")
test_exception_handling()
print("\n" + "="*50)
print("测试完成!")
print("="*50)
print("\n下一步:")
print("1. 如果 Redis 显示'未连接',请安装并启动 Redis")
print("2. 如果登录失败,请运行: python cli.py init")
print("3. 登录成功后,务必修改默认密码")
print("4. 生产环境请修改 .env 中的 SECRET_KEY")
if __name__ == "__main__":
main()

338
backend/trade_calendar.py Normal file
View File

@@ -0,0 +1,338 @@
"""交易日历与关键事件提醒"""
import datetime as dt
from typing import List, Optional, Dict
from sqlalchemy import select, and_
from db import get_session
from models import Trade, Security, CorporateEvent, AlertEvent
import akshare_service as svc
try:
import akshare as ak
AK_OK = True
except Exception:
ak = None
AK_OK = False
def get_upcoming_dividends(days_ahead: int = 30) -> Dict:
"""获取即将到来的除权除息日"""
today = dt.date.today()
end = today + dt.timedelta(days=days_ahead)
# 获取持仓股票代码
with get_session() as s:
trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all()
# 计算当前持仓
pos = {}
for t in trades:
if t.code not in pos:
pos[t.code] = {"qty": 0, "name": t.name}
if t.side == "buy":
pos[t.code]["qty"] += t.qty
else:
pos[t.code]["qty"] = max(0, pos[t.code]["qty"] - t.qty)
holding_codes = [c for c, v in pos.items() if v["qty"] > 0]
events = []
if AK_OK and holding_codes:
try:
df = ak.stock_zh_a_dividend()
if df is not None and not df.empty:
for _, r in df.iterrows():
code = str(r.get("代码", ""))
if code not in holding_codes:
continue
ex_date_str = str(r.get("除权除息日", ""))
if not ex_date_str or ex_date_str == "nan":
continue
try:
ex_date = dt.date.fromisoformat(ex_date_str[:10])
if today <= ex_date <= end:
days_left = (ex_date - today).days
events.append({
"code": code,
"name": pos[code]["name"],
"event_type": "除权除息",
"event_date": ex_date.isoformat(),
"days_left": days_left,
"detail": f"送股: {r.get('送股', 0)}, 转增: {r.get('转增', 0)}, 派息: {r.get('派息', 0)}",
"is_holding": True,
"urgency": "high" if days_left <= 3 else "medium"
})
except Exception:
continue
except Exception:
pass
# 补充从数据库中获取的事件
with get_session() as s:
db_events = s.execute(
select(CorporateEvent).where(
and_(
CorporateEvent.event_type == "dividend",
CorporateEvent.event_date >= today,
CorporateEvent.event_date <= end
)
).order_by(CorporateEvent.event_date)
).scalars().all()
for e in db_events:
if any(ev["code"] == e.code for ev in events):
continue
days_left = (e.event_date - today).days
events.append({
"code": e.code,
"name": e.name,
"event_type": "除权除息",
"event_date": e.event_date.isoformat(),
"days_left": days_left,
"detail": e.description,
"is_holding": e.code in holding_codes,
"urgency": "high" if days_left <= 3 else "medium"
})
events.sort(key=lambda x: x["event_date"])
return {"ok": True, "events": events, "count": len(events)}
def get_unlock_calendar(days_ahead: int = 90) -> Dict:
"""获取限售解禁日历"""
today = dt.date.today()
end = today + dt.timedelta(days=days_ahead)
events = []
if AK_OK:
try:
df = ak.stock_restricted_release_summary_em()
if df is not None and not df.empty:
for _, r in df.head(50).iterrows():
date_str = str(r.get("解禁日期", ""))
if not date_str or date_str == "nan":
continue
try:
unlock_date = dt.date.fromisoformat(date_str[:10])
if today <= unlock_date <= end:
amount = float(r.get("解禁数量", 0) or 0)
market_val = float(r.get("解禁市值", 0) or 0)
events.append({
"code": str(r.get("代码", "")),
"name": str(r.get("名称", "")),
"event_type": "限售解禁",
"event_date": unlock_date.isoformat(),
"days_left": (unlock_date - today).days,
"detail": f"解禁市值: {round(market_val/1e8, 2)}亿",
"amount_billion": round(market_val / 1e8, 2),
"urgency": "high" if market_val >= 10e8 else "medium"
})
except Exception:
continue
except Exception:
pass
# 从数据库补充
with get_session() as s:
db_events = s.execute(
select(CorporateEvent).where(
and_(
CorporateEvent.event_type == "unlock",
CorporateEvent.event_date >= today,
CorporateEvent.event_date <= end
)
).order_by(CorporateEvent.event_date)
).scalars().all()
for e in db_events:
if any(ev["code"] == e.code for ev in events):
continue
events.append({
"code": e.code,
"name": e.name,
"event_type": "限售解禁",
"event_date": e.event_date.isoformat(),
"days_left": (e.event_date - today).days,
"detail": e.description,
"amount_billion": e.amount,
"urgency": "high" if e.amount >= 10 else "medium"
})
events.sort(key=lambda x: x["event_date"])
return {"ok": True, "events": events, "count": len(events)}
def get_earnings_calendar(days_ahead: int = 30, holding_only: bool = False) -> Dict:
"""获取财报披露日历"""
today = dt.date.today()
end = today + dt.timedelta(days=days_ahead)
# 获取持仓代码
holding_codes = set()
if holding_only:
with get_session() as s:
trades = s.execute(select(Trade).order_by(Trade.date)).scalars().all()
pos = {}
for t in trades:
if t.code not in pos:
pos[t.code] = 0
pos[t.code] += t.qty if t.side == "buy" else -t.qty
holding_codes = {c for c, q in pos.items() if q > 0}
events = []
if AK_OK:
try:
df = ak.stock_notice_report()
if df is not None and not df.empty:
for _, r in df.head(100).iterrows():
date_str = str(r.get("公告日期", ""))
code = str(r.get("代码", ""))
if not date_str or date_str == "nan":
continue
if holding_only and code not in holding_codes:
continue
try:
report_date = dt.date.fromisoformat(date_str[:10])
if today <= report_date <= end:
events.append({
"code": code,
"name": str(r.get("名称", "")),
"event_type": "财报披露",
"event_date": report_date.isoformat(),
"days_left": (report_date - today).days,
"report_type": str(r.get("公告类型", "")),
"is_holding": code in holding_codes,
"urgency": "high" if code in holding_codes else "low"
})
except Exception:
continue
except Exception:
pass
# 从数据库补充
with get_session() as s:
db_events = s.execute(
select(CorporateEvent).where(
and_(
CorporateEvent.event_type == "earnings",
CorporateEvent.event_date >= today,
CorporateEvent.event_date <= end
)
).order_by(CorporateEvent.event_date)
).scalars().all()
for e in db_events:
if any(ev["code"] == e.code for ev in events):
continue
if holding_only and e.code not in holding_codes:
continue
events.append({
"code": e.code,
"name": e.name,
"event_type": "财报披露",
"event_date": e.event_date.isoformat(),
"days_left": (e.event_date - today).days,
"report_type": e.title,
"is_holding": e.code in holding_codes,
"urgency": "high" if e.code in holding_codes else "low"
})
events.sort(key=lambda x: x["event_date"])
return {"ok": True, "events": events, "count": len(events)}
def get_all_upcoming_events(days_ahead: int = 30) -> Dict:
"""获取所有即将到来的关键事件(综合视图)"""
today = dt.date.today()
all_events = []
# 合并所有事件
for result in [
get_upcoming_dividends(days_ahead),
get_earnings_calendar(days_ahead),
get_unlock_calendar(days_ahead)
]:
all_events.extend(result.get("events", []))
# 按日期排序
all_events.sort(key=lambda x: x["event_date"])
# 按日期分组
grouped = {}
for event in all_events:
date = event["event_date"]
if date not in grouped:
grouped[date] = []
grouped[date].append(event)
# 生成日历视图
calendar = []
for date_str, events in sorted(grouped.items()):
date = dt.date.fromisoformat(date_str)
calendar.append({
"date": date_str,
"weekday": ["周一", "周二", "周三", "周四", "周五", "周六", "周日"][date.weekday()],
"days_left": (date - today).days,
"events": events,
"has_high_urgency": any(e["urgency"] == "high" for e in events),
"has_holding": any(e.get("is_holding", False) for e in events)
})
# 紧急事件3天内
urgent = [e for e in all_events if e.get("days_left", 99) <= 3]
return {
"ok": True,
"calendar": calendar,
"urgent": urgent,
"total": len(all_events),
"summary": {
"dividends": len([e for e in all_events if e["event_type"] == "除权除息"]),
"earnings": len([e for e in all_events if e["event_type"] == "财报披露"]),
"unlocks": len([e for e in all_events if e["event_type"] == "限售解禁"]),
"urgent": len(urgent)
}
}
def check_and_push_calendar_alerts() -> Dict:
"""检查并推送日历事件预警(定时任务调用)"""
try:
from notifier import notify
except Exception:
return {"ok": False, "msg": "推送模块不可用"}
result = get_all_upcoming_events(days_ahead=7)
urgent = result.get("urgent", [])
if not urgent:
return {"ok": True, "msg": "无紧急事件", "pushed": 0}
# 生成推送内容
lines = [f"📅 未来7天关键事件提醒{len(urgent)}条)\n"]
for event in urgent[:10]: # 最多推送10条
urgency_icon = "🔴" if event["urgency"] == "high" else "🟡"
holding_icon = "💰" if event.get("is_holding") else ""
lines.append(
f"{urgency_icon}{holding_icon} {event['event_date']} "
f"{event['name']}({event['code']}) "
f"{event['event_type']} "
f"({event['days_left']}天后)"
)
message = "\n".join(lines)
notify("【Blackdata】关键事件提醒", message)
# 写入站内通知
with get_session() as s:
for event in urgent[:10]:
alert = AlertEvent(
rule_id=0,
code=event["code"],
name=event["name"],
message=f"{event['event_type']}: {event.get('detail', '')} ({event['days_left']}天后)",
value=event.get("amount_billion", 0)
)
s.add(alert)
s.commit()
return {"ok": True, "pushed": len(urgent), "msg": f"已推送 {len(urgent)} 条事件提醒"}

View File

@@ -0,0 +1,345 @@
"""自选股分组管理"""
import datetime as dt
from typing import List, Dict, Optional
from sqlalchemy import select, func
from db import get_session
from models import WatchlistGroup, WatchlistItem, Security
import akshare_service as svc
# 预设分组
DEFAULT_GROUPS = [
{"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True},
{"name": "观察池", "description": "待观察的潜力股", "color": "blue"},
{"name": "持仓股", "description": "当前持仓的股票", "color": "green"},
{"name": "概念股", "description": "热门概念板块", "color": "purple"},
]
def init_default_groups():
"""初始化默认分组(如果不存在)"""
with get_session() as s:
count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar()
if count == 0:
for idx, g in enumerate(DEFAULT_GROUPS):
group = WatchlistGroup(
name=g["name"],
description=g["description"],
color=g["color"],
is_default=g.get("is_default", False),
sort_order=idx
)
s.add(group)
s.commit()
print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)}")
return True
def get_all_groups() -> List[Dict]:
"""获取所有分组"""
with get_session() as s:
groups = s.execute(
select(WatchlistGroup).order_by(WatchlistGroup.sort_order)
).scalars().all()
result = []
for g in groups:
# 统计分组内股票数量
count = s.execute(
select(func.count()).select_from(WatchlistItem)
.where(WatchlistItem.group_id == g.id)
).scalar()
result.append({
"id": g.id,
"name": g.name,
"description": g.description,
"color": g.color,
"count": count,
"is_default": g.is_default,
"sort_order": g.sort_order
})
return result
def create_group(name: str, description: str = "", color: str = "blue") -> Dict:
"""创建新分组"""
with get_session() as s:
# 获取当前最大排序号
max_order = s.execute(
select(func.max(WatchlistGroup.sort_order))
).scalar() or 0
group = WatchlistGroup(
name=name,
description=description,
color=color,
sort_order=max_order + 1
)
s.add(group)
s.commit()
return {
"ok": True,
"id": group.id,
"name": group.name
}
def update_group(group_id: int, name: Optional[str] = None,
description: Optional[str] = None, color: Optional[str] = None) -> Dict:
"""更新分组信息"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
if name is not None:
group.name = name
if description is not None:
group.description = description
if color is not None:
group.color = color
s.commit()
return {"ok": True}
def delete_group(group_id: int) -> Dict:
"""删除分组(同时删除分组内的股票)"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
if group.is_default:
return {"ok": False, "msg": "默认分组不能删除"}
# 删除分组内的股票
s.execute(
WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id)
)
# 删除分组
s.delete(group)
s.commit()
return {"ok": True}
def reorder_groups(group_ids: List[int]) -> Dict:
"""重新排序分组"""
with get_session() as s:
for idx, gid in enumerate(group_ids):
group = s.get(WatchlistGroup, gid)
if group:
group.sort_order = idx
s.commit()
return {"ok": True}
def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict:
"""获取分组内的股票列表"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
items = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id)
.order_by(WatchlistItem.sort_order)
).scalars().all()
codes = [item.code for item in items]
# 获取实时行情
stocks = []
if with_quotes and codes:
quotes_data = svc.get_watchlist(codes)
quotes_map = {s["code"]: s for s in quotes_data.get("list", [])}
for item in items:
quote = quotes_map.get(item.code, {})
stocks.append({
"id": item.id,
"code": item.code,
"name": item.name or quote.get("name", ""),
"price": quote.get("price", 0),
"pct": quote.get("pct", 0),
"change": quote.get("change", 0),
"amount": quote.get("amount", 0),
"note": item.note,
"added_at": item.added_at.strftime("%Y-%m-%d")
})
else:
for item in items:
stocks.append({
"id": item.id,
"code": item.code,
"name": item.name,
"note": item.note,
"added_at": item.added_at.strftime("%Y-%m-%d")
})
return {
"ok": True,
"group": {
"id": group.id,
"name": group.name,
"description": group.description,
"color": group.color
},
"stocks": stocks
}
def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict:
"""添加股票到分组"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
# 检查是否已存在
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
).scalar_one_or_none()
if exists:
return {"ok": False, "msg": "该股票已在分组中"}
# 获取股票名称
sec = s.get(Security, code)
name = sec.name if sec else code
# 获取当前最大排序号
max_order = s.execute(
select(func.max(WatchlistItem.sort_order))
.where(WatchlistItem.group_id == group_id)
).scalar() or 0
item = WatchlistItem(
group_id=group_id,
code=code,
name=name,
note=note,
sort_order=max_order + 1
)
s.add(item)
s.commit()
return {"ok": True, "id": item.id}
def remove_stock_from_group(item_id: int) -> Dict:
"""从分组中移除股票"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
s.delete(item)
s.commit()
return {"ok": True}
def move_stock_to_group(item_id: int, target_group_id: int) -> Dict:
"""将股票移动到另一个分组"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
target_group = s.get(WatchlistGroup, target_group_id)
if not target_group:
return {"ok": False, "msg": "目标分组不存在"}
# 检查目标分组是否已有该股票
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == target_group_id,
WatchlistItem.code == item.code)
).scalar_one_or_none()
if exists:
return {"ok": False, "msg": "目标分组已有该股票"}
item.group_id = target_group_id
s.commit()
return {"ok": True}
def batch_add_stocks(group_id: int, codes: List[str]) -> Dict:
"""批量添加股票到分组"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
added = 0
skipped = 0
for code in codes:
# 检查是否已存在
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
).scalar_one_or_none()
if exists:
skipped += 1
continue
# 获取股票名称
sec = s.get(Security, code)
name = sec.name if sec else code
item = WatchlistItem(
group_id=group_id,
code=code,
name=name,
sort_order=added
)
s.add(item)
added += 1
s.commit()
return {"ok": True, "added": added, "skipped": skipped}
def update_stock_note(item_id: int, note: str) -> Dict:
"""更新股票备注"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
item.note = note
s.commit()
return {"ok": True}
def reorder_stocks(item_ids: List[int]) -> Dict:
"""重新排序分组内的股票"""
with get_session() as s:
for idx, item_id in enumerate(item_ids):
item = s.get(WatchlistItem, item_id)
if item:
item.sort_order = idx
s.commit()
return {"ok": True}
def search_stocks_across_groups(keyword: str) -> List[Dict]:
"""跨分组搜索股票"""
with get_session() as s:
items = s.execute(
select(WatchlistItem, WatchlistGroup)
.join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id)
.where(
(WatchlistItem.code.like(f"%{keyword}%")) |
(WatchlistItem.name.like(f"%{keyword}%"))
)
).all()
result = []
for item, group in items:
result.append({
"id": item.id,
"code": item.code,
"name": item.name,
"group_id": group.id,
"group_name": group.name,
"group_color": group.color,
"note": item.note
})
return result