功能细节优化
This commit is contained in:
345
backend/watchlist_manager.py
Normal file
345
backend/watchlist_manager.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""自选股分组管理"""
|
||||
import datetime as dt
|
||||
from typing import List, Dict, Optional
|
||||
from sqlalchemy import select, func
|
||||
from db import get_session
|
||||
from models import WatchlistGroup, WatchlistItem, Security
|
||||
import akshare_service as svc
|
||||
|
||||
# 预设分组
|
||||
DEFAULT_GROUPS = [
|
||||
{"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True},
|
||||
{"name": "观察池", "description": "待观察的潜力股", "color": "blue"},
|
||||
{"name": "持仓股", "description": "当前持仓的股票", "color": "green"},
|
||||
{"name": "概念股", "description": "热门概念板块", "color": "purple"},
|
||||
]
|
||||
|
||||
def init_default_groups():
|
||||
"""初始化默认分组(如果不存在)"""
|
||||
with get_session() as s:
|
||||
count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar()
|
||||
if count == 0:
|
||||
for idx, g in enumerate(DEFAULT_GROUPS):
|
||||
group = WatchlistGroup(
|
||||
name=g["name"],
|
||||
description=g["description"],
|
||||
color=g["color"],
|
||||
is_default=g.get("is_default", False),
|
||||
sort_order=idx
|
||||
)
|
||||
s.add(group)
|
||||
s.commit()
|
||||
print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)} 个")
|
||||
return True
|
||||
|
||||
def get_all_groups() -> List[Dict]:
|
||||
"""获取所有分组"""
|
||||
with get_session() as s:
|
||||
groups = s.execute(
|
||||
select(WatchlistGroup).order_by(WatchlistGroup.sort_order)
|
||||
).scalars().all()
|
||||
|
||||
result = []
|
||||
for g in groups:
|
||||
# 统计分组内股票数量
|
||||
count = s.execute(
|
||||
select(func.count()).select_from(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == g.id)
|
||||
).scalar()
|
||||
|
||||
result.append({
|
||||
"id": g.id,
|
||||
"name": g.name,
|
||||
"description": g.description,
|
||||
"color": g.color,
|
||||
"count": count,
|
||||
"is_default": g.is_default,
|
||||
"sort_order": g.sort_order
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def create_group(name: str, description: str = "", color: str = "blue") -> Dict:
|
||||
"""创建新分组"""
|
||||
with get_session() as s:
|
||||
# 获取当前最大排序号
|
||||
max_order = s.execute(
|
||||
select(func.max(WatchlistGroup.sort_order))
|
||||
).scalar() or 0
|
||||
|
||||
group = WatchlistGroup(
|
||||
name=name,
|
||||
description=description,
|
||||
color=color,
|
||||
sort_order=max_order + 1
|
||||
)
|
||||
s.add(group)
|
||||
s.commit()
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"id": group.id,
|
||||
"name": group.name
|
||||
}
|
||||
|
||||
def update_group(group_id: int, name: Optional[str] = None,
|
||||
description: Optional[str] = None, color: Optional[str] = None) -> Dict:
|
||||
"""更新分组信息"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
if name is not None:
|
||||
group.name = name
|
||||
if description is not None:
|
||||
group.description = description
|
||||
if color is not None:
|
||||
group.color = color
|
||||
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def delete_group(group_id: int) -> Dict:
|
||||
"""删除分组(同时删除分组内的股票)"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
if group.is_default:
|
||||
return {"ok": False, "msg": "默认分组不能删除"}
|
||||
|
||||
# 删除分组内的股票
|
||||
s.execute(
|
||||
WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id)
|
||||
)
|
||||
|
||||
# 删除分组
|
||||
s.delete(group)
|
||||
s.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
def reorder_groups(group_ids: List[int]) -> Dict:
|
||||
"""重新排序分组"""
|
||||
with get_session() as s:
|
||||
for idx, gid in enumerate(group_ids):
|
||||
group = s.get(WatchlistGroup, gid)
|
||||
if group:
|
||||
group.sort_order = idx
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict:
|
||||
"""获取分组内的股票列表"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
items = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id)
|
||||
.order_by(WatchlistItem.sort_order)
|
||||
).scalars().all()
|
||||
|
||||
codes = [item.code for item in items]
|
||||
|
||||
# 获取实时行情
|
||||
stocks = []
|
||||
if with_quotes and codes:
|
||||
quotes_data = svc.get_watchlist(codes)
|
||||
quotes_map = {s["code"]: s for s in quotes_data.get("list", [])}
|
||||
|
||||
for item in items:
|
||||
quote = quotes_map.get(item.code, {})
|
||||
stocks.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name or quote.get("name", ""),
|
||||
"price": quote.get("price", 0),
|
||||
"pct": quote.get("pct", 0),
|
||||
"change": quote.get("change", 0),
|
||||
"amount": quote.get("amount", 0),
|
||||
"note": item.note,
|
||||
"added_at": item.added_at.strftime("%Y-%m-%d")
|
||||
})
|
||||
else:
|
||||
for item in items:
|
||||
stocks.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name,
|
||||
"note": item.note,
|
||||
"added_at": item.added_at.strftime("%Y-%m-%d")
|
||||
})
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"color": group.color
|
||||
},
|
||||
"stocks": stocks
|
||||
}
|
||||
|
||||
def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict:
|
||||
"""添加股票到分组"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
# 检查是否已存在
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
return {"ok": False, "msg": "该股票已在分组中"}
|
||||
|
||||
# 获取股票名称
|
||||
sec = s.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
|
||||
# 获取当前最大排序号
|
||||
max_order = s.execute(
|
||||
select(func.max(WatchlistItem.sort_order))
|
||||
.where(WatchlistItem.group_id == group_id)
|
||||
).scalar() or 0
|
||||
|
||||
item = WatchlistItem(
|
||||
group_id=group_id,
|
||||
code=code,
|
||||
name=name,
|
||||
note=note,
|
||||
sort_order=max_order + 1
|
||||
)
|
||||
s.add(item)
|
||||
s.commit()
|
||||
|
||||
return {"ok": True, "id": item.id}
|
||||
|
||||
def remove_stock_from_group(item_id: int) -> Dict:
|
||||
"""从分组中移除股票"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
s.delete(item)
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def move_stock_to_group(item_id: int, target_group_id: int) -> Dict:
|
||||
"""将股票移动到另一个分组"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
target_group = s.get(WatchlistGroup, target_group_id)
|
||||
if not target_group:
|
||||
return {"ok": False, "msg": "目标分组不存在"}
|
||||
|
||||
# 检查目标分组是否已有该股票
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == target_group_id,
|
||||
WatchlistItem.code == item.code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
return {"ok": False, "msg": "目标分组已有该股票"}
|
||||
|
||||
item.group_id = target_group_id
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def batch_add_stocks(group_id: int, codes: List[str]) -> Dict:
|
||||
"""批量添加股票到分组"""
|
||||
with get_session() as s:
|
||||
group = s.get(WatchlistGroup, group_id)
|
||||
if not group:
|
||||
return {"ok": False, "msg": "分组不存在"}
|
||||
|
||||
added = 0
|
||||
skipped = 0
|
||||
|
||||
for code in codes:
|
||||
# 检查是否已存在
|
||||
exists = s.execute(
|
||||
select(WatchlistItem)
|
||||
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 获取股票名称
|
||||
sec = s.get(Security, code)
|
||||
name = sec.name if sec else code
|
||||
|
||||
item = WatchlistItem(
|
||||
group_id=group_id,
|
||||
code=code,
|
||||
name=name,
|
||||
sort_order=added
|
||||
)
|
||||
s.add(item)
|
||||
added += 1
|
||||
|
||||
s.commit()
|
||||
return {"ok": True, "added": added, "skipped": skipped}
|
||||
|
||||
def update_stock_note(item_id: int, note: str) -> Dict:
|
||||
"""更新股票备注"""
|
||||
with get_session() as s:
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if not item:
|
||||
return {"ok": False, "msg": "股票不存在"}
|
||||
|
||||
item.note = note
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def reorder_stocks(item_ids: List[int]) -> Dict:
|
||||
"""重新排序分组内的股票"""
|
||||
with get_session() as s:
|
||||
for idx, item_id in enumerate(item_ids):
|
||||
item = s.get(WatchlistItem, item_id)
|
||||
if item:
|
||||
item.sort_order = idx
|
||||
s.commit()
|
||||
return {"ok": True}
|
||||
|
||||
def search_stocks_across_groups(keyword: str) -> List[Dict]:
|
||||
"""跨分组搜索股票"""
|
||||
with get_session() as s:
|
||||
items = s.execute(
|
||||
select(WatchlistItem, WatchlistGroup)
|
||||
.join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id)
|
||||
.where(
|
||||
(WatchlistItem.code.like(f"%{keyword}%")) |
|
||||
(WatchlistItem.name.like(f"%{keyword}%"))
|
||||
)
|
||||
).all()
|
||||
|
||||
result = []
|
||||
for item, group in items:
|
||||
result.append({
|
||||
"id": item.id,
|
||||
"code": item.code,
|
||||
"name": item.name,
|
||||
"group_id": group.id,
|
||||
"group_name": group.name,
|
||||
"group_color": group.color,
|
||||
"note": item.note
|
||||
})
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user