功能细节优化

This commit is contained in:
2026-06-15 01:26:39 +08:00
parent e524a3589a
commit 964c17c200
33 changed files with 6990 additions and 210 deletions

View File

@@ -0,0 +1,345 @@
"""自选股分组管理"""
import datetime as dt
from typing import List, Dict, Optional
from sqlalchemy import select, func
from db import get_session
from models import WatchlistGroup, WatchlistItem, Security
import akshare_service as svc
# 预设分组
DEFAULT_GROUPS = [
{"name": "核心自选", "description": "重点关注的核心股票", "color": "red", "is_default": True},
{"name": "观察池", "description": "待观察的潜力股", "color": "blue"},
{"name": "持仓股", "description": "当前持仓的股票", "color": "green"},
{"name": "概念股", "description": "热门概念板块", "color": "purple"},
]
def init_default_groups():
"""初始化默认分组(如果不存在)"""
with get_session() as s:
count = s.execute(select(func.count()).select_from(WatchlistGroup)).scalar()
if count == 0:
for idx, g in enumerate(DEFAULT_GROUPS):
group = WatchlistGroup(
name=g["name"],
description=g["description"],
color=g["color"],
is_default=g.get("is_default", False),
sort_order=idx
)
s.add(group)
s.commit()
print(f"✓ 创建默认自选股分组: {len(DEFAULT_GROUPS)}")
return True
def get_all_groups() -> List[Dict]:
"""获取所有分组"""
with get_session() as s:
groups = s.execute(
select(WatchlistGroup).order_by(WatchlistGroup.sort_order)
).scalars().all()
result = []
for g in groups:
# 统计分组内股票数量
count = s.execute(
select(func.count()).select_from(WatchlistItem)
.where(WatchlistItem.group_id == g.id)
).scalar()
result.append({
"id": g.id,
"name": g.name,
"description": g.description,
"color": g.color,
"count": count,
"is_default": g.is_default,
"sort_order": g.sort_order
})
return result
def create_group(name: str, description: str = "", color: str = "blue") -> Dict:
"""创建新分组"""
with get_session() as s:
# 获取当前最大排序号
max_order = s.execute(
select(func.max(WatchlistGroup.sort_order))
).scalar() or 0
group = WatchlistGroup(
name=name,
description=description,
color=color,
sort_order=max_order + 1
)
s.add(group)
s.commit()
return {
"ok": True,
"id": group.id,
"name": group.name
}
def update_group(group_id: int, name: Optional[str] = None,
description: Optional[str] = None, color: Optional[str] = None) -> Dict:
"""更新分组信息"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
if name is not None:
group.name = name
if description is not None:
group.description = description
if color is not None:
group.color = color
s.commit()
return {"ok": True}
def delete_group(group_id: int) -> Dict:
"""删除分组(同时删除分组内的股票)"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
if group.is_default:
return {"ok": False, "msg": "默认分组不能删除"}
# 删除分组内的股票
s.execute(
WatchlistItem.__table__.delete().where(WatchlistItem.group_id == group_id)
)
# 删除分组
s.delete(group)
s.commit()
return {"ok": True}
def reorder_groups(group_ids: List[int]) -> Dict:
"""重新排序分组"""
with get_session() as s:
for idx, gid in enumerate(group_ids):
group = s.get(WatchlistGroup, gid)
if group:
group.sort_order = idx
s.commit()
return {"ok": True}
def get_group_stocks(group_id: int, with_quotes: bool = True) -> Dict:
"""获取分组内的股票列表"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
items = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id)
.order_by(WatchlistItem.sort_order)
).scalars().all()
codes = [item.code for item in items]
# 获取实时行情
stocks = []
if with_quotes and codes:
quotes_data = svc.get_watchlist(codes)
quotes_map = {s["code"]: s for s in quotes_data.get("list", [])}
for item in items:
quote = quotes_map.get(item.code, {})
stocks.append({
"id": item.id,
"code": item.code,
"name": item.name or quote.get("name", ""),
"price": quote.get("price", 0),
"pct": quote.get("pct", 0),
"change": quote.get("change", 0),
"amount": quote.get("amount", 0),
"note": item.note,
"added_at": item.added_at.strftime("%Y-%m-%d")
})
else:
for item in items:
stocks.append({
"id": item.id,
"code": item.code,
"name": item.name,
"note": item.note,
"added_at": item.added_at.strftime("%Y-%m-%d")
})
return {
"ok": True,
"group": {
"id": group.id,
"name": group.name,
"description": group.description,
"color": group.color
},
"stocks": stocks
}
def add_stock_to_group(group_id: int, code: str, note: str = "") -> Dict:
"""添加股票到分组"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
# 检查是否已存在
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
).scalar_one_or_none()
if exists:
return {"ok": False, "msg": "该股票已在分组中"}
# 获取股票名称
sec = s.get(Security, code)
name = sec.name if sec else code
# 获取当前最大排序号
max_order = s.execute(
select(func.max(WatchlistItem.sort_order))
.where(WatchlistItem.group_id == group_id)
).scalar() or 0
item = WatchlistItem(
group_id=group_id,
code=code,
name=name,
note=note,
sort_order=max_order + 1
)
s.add(item)
s.commit()
return {"ok": True, "id": item.id}
def remove_stock_from_group(item_id: int) -> Dict:
"""从分组中移除股票"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
s.delete(item)
s.commit()
return {"ok": True}
def move_stock_to_group(item_id: int, target_group_id: int) -> Dict:
"""将股票移动到另一个分组"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
target_group = s.get(WatchlistGroup, target_group_id)
if not target_group:
return {"ok": False, "msg": "目标分组不存在"}
# 检查目标分组是否已有该股票
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == target_group_id,
WatchlistItem.code == item.code)
).scalar_one_or_none()
if exists:
return {"ok": False, "msg": "目标分组已有该股票"}
item.group_id = target_group_id
s.commit()
return {"ok": True}
def batch_add_stocks(group_id: int, codes: List[str]) -> Dict:
"""批量添加股票到分组"""
with get_session() as s:
group = s.get(WatchlistGroup, group_id)
if not group:
return {"ok": False, "msg": "分组不存在"}
added = 0
skipped = 0
for code in codes:
# 检查是否已存在
exists = s.execute(
select(WatchlistItem)
.where(WatchlistItem.group_id == group_id, WatchlistItem.code == code)
).scalar_one_or_none()
if exists:
skipped += 1
continue
# 获取股票名称
sec = s.get(Security, code)
name = sec.name if sec else code
item = WatchlistItem(
group_id=group_id,
code=code,
name=name,
sort_order=added
)
s.add(item)
added += 1
s.commit()
return {"ok": True, "added": added, "skipped": skipped}
def update_stock_note(item_id: int, note: str) -> Dict:
"""更新股票备注"""
with get_session() as s:
item = s.get(WatchlistItem, item_id)
if not item:
return {"ok": False, "msg": "股票不存在"}
item.note = note
s.commit()
return {"ok": True}
def reorder_stocks(item_ids: List[int]) -> Dict:
"""重新排序分组内的股票"""
with get_session() as s:
for idx, item_id in enumerate(item_ids):
item = s.get(WatchlistItem, item_id)
if item:
item.sort_order = idx
s.commit()
return {"ok": True}
def search_stocks_across_groups(keyword: str) -> List[Dict]:
"""跨分组搜索股票"""
with get_session() as s:
items = s.execute(
select(WatchlistItem, WatchlistGroup)
.join(WatchlistGroup, WatchlistItem.group_id == WatchlistGroup.id)
.where(
(WatchlistItem.code.like(f"%{keyword}%")) |
(WatchlistItem.name.like(f"%{keyword}%"))
)
).all()
result = []
for item, group in items:
result.append({
"id": item.id,
"code": item.code,
"name": item.name,
"group_id": group.id,
"group_name": group.name,
"group_color": group.color,
"note": item.note
})
return result