"""自选股分组管理""" import datetime as dt from typing import List, Dict, Optional from sqlalchemy import select, func from db import get_session from models import WatchlistGroup, WatchlistItem, Security import akshare_service as svc # 预设分组 DEFAULT_GROUPS = [ {"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True}, {"name": "观察池", "description": "待观察的潜力股", "color": "blue"}, {"name": "持仓股", "description": "当前持仓的股票", "color": "green"}, {"name": "概念股", "description": "热门概念板块", "color": "purple"}, ] def init_default_groups(): """初始化默认分组(如果不存在)""" with get_session() as s: count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar() if count == 0: for idx, g in enumerate(DEFAULT_GROUPS): group = WatchlistGroup( name=g["name"], description=g["description"], color=g["color"], is_default=g.get("is_default", False), sort_order=idx ) s.add(group) s.commit() print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)} 个") return True def get_all_groups() -> List[Dict]: """获取所有分组""" with get_session() as s: groups = s.execute( select(WatchlistGroup).order_by(WatchlistGroup.sort_order) ).scalars().all() result = [] for g in groups: # 统计分组内股票数量 count = s.execute( select(func.count()).select_from(WatchlistItem) .where(WatchlistItem.group_id == g.id) ).scalar() result.append({ "id": g.id, "name": g.name, "description": g.description, "color": g.color, "count": count, "is_default": g.is_default, "sort_order": g.sort_order }) return result def create_group(name: str, description: str = "", color: str = "blue") -> Dict: """创建新分组""" with get_session() as s: # 获取当前最大排序号 max_order = s.execute( select(func.max(WatchlistGroup.sort_order)) ).scalar() or 0 group = WatchlistGroup( name=name, description=description, color=color, sort_order=max_order + 1 ) s.add(group) s.commit() return { "ok": True, "id": group.id, "name": group.name } def update_group(group_id: int, name: Optional[str] = None, description: Optional[str] = None, color: Optional[str] = None) -> Dict: """更新分组信息""" with get_session() as s: group = s.get(WatchlistGroup, group_id) if not group: return {"ok": False, "msg": "分组不存在"} if name is not None: group.name = name if description is not None: group.description = description if color is not None: group.color = color s.commit() return {"ok": True} def delete_group(group_id: int) -> Dict: """删除分组(同时删除分组内的股票)""" with get_session() as s: group = s.get(WatchlistGroup, group_id) if not group: return {"ok": False, "msg": "分组不存在"} if group.is_default: return {"ok": False, "msg": "默认分组不能删除"} # 删除分组内的股票 s.execute( WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id) ) # 删除分组 s.delete(group) s.commit() return {"ok": True} def reorder_groups(group_ids: List[int]) -> Dict: """重新排序分组""" with get_session() as s: for idx, gid in enumerate(group_ids): group = s.get(WatchlistGroup, gid) if group: group.sort_order = idx s.commit() return {"ok": True} def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict: """获取分组内的股票列表""" with get_session() as s: group = s.get(WatchlistGroup, group_id) if not group: return {"ok": False, "msg": "分组不存在"} items = s.execute( select(WatchlistItem) .where(WatchlistItem.group_id == group_id) .order_by(WatchlistItem.sort_order) ).scalars().all() codes = [item.code for item in items] # 获取实时行情 stocks = [] if with_quotes and codes: quotes_data = svc.get_watchlist(codes) quotes_map = {s["code"]: s for s in quotes_data.get("list", [])} for item in items: quote = quotes_map.get(item.code, {}) stocks.append({ "id": item.id, "code": item.code, "name": item.name or quote.get("name", ""), "price": quote.get("price", 0), "pct": quote.get("pct", 0), "change": quote.get("change", 0), "amount": quote.get("amount", 0), "note": item.note, "added_at": item.added_at.strftime("%Y-%m-%d") }) else: for item in items: stocks.append({ "id": item.id, "code": item.code, "name": item.name, "note": item.note, "added_at": item.added_at.strftime("%Y-%m-%d") }) return { "ok": True, "group": { "id": group.id, "name": group.name, "description": group.description, "color": group.color }, "stocks": stocks } def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict: """添加股票到分组""" with get_session() as s: group = s.get(WatchlistGroup, group_id) if not group: return {"ok": False, "msg": "分组不存在"} # 检查是否已存在 exists = s.execute( select(WatchlistItem) .where(WatchlistItem.group_id == group_id, WatchlistItem.code == code) ).scalar_one_or_none() if exists: return {"ok": False, "msg": "该股票已在分组中"} # 获取股票名称 sec = s.get(Security, code) name = sec.name if sec else code # 获取当前最大排序号 max_order = s.execute( select(func.max(WatchlistItem.sort_order)) .where(WatchlistItem.group_id == group_id) ).scalar() or 0 item = WatchlistItem( group_id=group_id, code=code, name=name, note=note, sort_order=max_order + 1 ) s.add(item) s.commit() return {"ok": True, "id": item.id} def remove_stock_from_group(item_id: int) -> Dict: """从分组中移除股票""" with get_session() as s: item = s.get(WatchlistItem, item_id) if not item: return {"ok": False, "msg": "股票不存在"} s.delete(item) s.commit() return {"ok": True} def move_stock_to_group(item_id: int, target_group_id: int) -> Dict: """将股票移动到另一个分组""" with get_session() as s: item = s.get(WatchlistItem, item_id) if not item: return {"ok": False, "msg": "股票不存在"} target_group = s.get(WatchlistGroup, target_group_id) if not target_group: return {"ok": False, "msg": "目标分组不存在"} # 检查目标分组是否已有该股票 exists = s.execute( select(WatchlistItem) .where(WatchlistItem.group_id == target_group_id, WatchlistItem.code == item.code) ).scalar_one_or_none() if exists: return {"ok": False, "msg": "目标分组已有该股票"} item.group_id = target_group_id s.commit() return {"ok": True} def batch_add_stocks(group_id: int, codes: List[str]) -> Dict: """批量添加股票到分组""" with get_session() as s: group = s.get(WatchlistGroup, group_id) if not group: return {"ok": False, "msg": "分组不存在"} added = 0 skipped = 0 for code in codes: # 检查是否已存在 exists = s.execute( select(WatchlistItem) .where(WatchlistItem.group_id == group_id, WatchlistItem.code == code) ).scalar_one_or_none() if exists: skipped += 1 continue # 获取股票名称 sec = s.get(Security, code) name = sec.name if sec else code item = WatchlistItem( group_id=group_id, code=code, name=name, sort_order=added ) s.add(item) added += 1 s.commit() return {"ok": True, "added": added, "skipped": skipped} def update_stock_note(item_id: int, note: str) -> Dict: """更新股票备注""" with get_session() as s: item = s.get(WatchlistItem, item_id) if not item: return {"ok": False, "msg": "股票不存在"} item.note = note s.commit() return {"ok": True} def reorder_stocks(item_ids: List[int]) -> Dict: """重新排序分组内的股票""" with get_session() as s: for idx, item_id in enumerate(item_ids): item = s.get(WatchlistItem, item_id) if item: item.sort_order = idx s.commit() return {"ok": True} def search_stocks_across_groups(keyword: str) -> List[Dict]: """跨分组搜索股票""" with get_session() as s: items = s.execute( select(WatchlistItem, WatchlistGroup) .join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id) .where( (WatchlistItem.code.like(f"%{keyword}%")) | (WatchlistItem.name.like(f"%{keyword}%")) ) ).all() result = [] for item, group in items: result.append({ "id": item.id, "code": item.code, "name": item.name, "group_id": group.id, "group_name": group.name, "group_color": group.color, "note": item.note }) return result