Files
stock_cursor_v0/backend/data_manager.py
2026-06-15 01:26:39 +08:00

399 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据修正与回填增强:数据修正、断点续传、完整性检查、质量报告"""
import datetime as dt
import json
import os
from typing import List, Optional, Dict
from sqlalchemy import select, func, and_, delete
from db import get_session
from models import DailyQuote, StockMetric, Security, IndexDaily, SectorDaily, JobRun
import ingest
# JSON file recording per-task refill progress; kept in this module's directory.
PROGRESS_FILE = os.path.join(
    os.path.dirname(__file__),
    ".refill_progress.json",
)
# ============ 数据修正 ============
def delete_quote(code: str, date: str) -> Dict:
    """Remove the daily quote row of one stock on one date.

    Args:
        code: Stock code.
        date: ISO-format date string (YYYY-MM-DD); a malformed value raises
            ``ValueError`` before any database access.

    Returns:
        ``{"ok": True, "msg": ...}`` after deleting the row, or
        ``{"ok": False, "msg": ...}`` when no such row exists.
    """
    target_day = dt.date.fromisoformat(date)
    with get_session() as session:
        stmt = select(DailyQuote).where(
            DailyQuote.code == code, DailyQuote.date == target_day
        )
        quote = session.execute(stmt).scalar_one_or_none()
        if quote is None:
            return {"ok": False, "msg": f"{code} {date} 无此数据"}
        session.delete(quote)
        session.commit()
    return {"ok": True, "msg": f"已删除 {code} {date} 日线"}
def update_quote(code: str, date: str, fields: Dict) -> Dict:
    """Overwrite selected price/volume columns of one stock's daily quote.

    Only the keys ``open``, ``high``, ``low``, ``close``, ``volume`` and
    ``amount`` of *fields* are applied; every other key is ignored. The field
    check happens before *date* is parsed, so an empty/irrelevant *fields*
    returns the "no valid field" result even for a malformed date.

    Returns:
        ``{"ok": True, "updated": {...}}`` with the applied columns, or
        ``{"ok": False, "msg": ...}`` when no usable field was given or the
        row does not exist.

    Raises:
        ValueError: if *date* is not an ISO date (only reached when at least
            one usable field is present).
    """
    editable = ("open", "high", "low", "close", "volume", "amount")
    changes = {}
    for key, value in fields.items():
        if key in editable:
            changes[key] = value
    if not changes:
        return {"ok": False, "msg": "无有效修正字段"}

    target_day = dt.date.fromisoformat(date)
    with get_session() as session:
        quote = session.execute(
            select(DailyQuote).where(
                DailyQuote.code == code, DailyQuote.date == target_day
            )
        ).scalar_one_or_none()
        if quote is None:
            return {"ok": False, "msg": f"{code} {date} 无此数据"}
        for column, value in changes.items():
            setattr(quote, column, value)
        session.commit()
    return {"ok": True, "updated": changes}
def delete_quotes_range(code: str, start: str, end: str) -> Dict:
    """Delete one stock's daily quotes whose date lies in ``[start, end]``.

    Args:
        code: Stock code.
        start: Inclusive lower bound, ISO date string (YYYY-MM-DD).
        end: Inclusive upper bound, ISO date string (YYYY-MM-DD).

    Returns:
        ``{"ok": True, "deleted": <rows removed>, "range": "<start> ~ <end>"}``.

    Raises:
        ValueError: if either date string is not ISO formatted (raised before
            any database access).
    """
    d_start = dt.date.fromisoformat(start)
    d_end = dt.date.fromisoformat(end)
    with get_session() as s:
        # One bulk DELETE instead of loading every row into the session and
        # deleting them individually.
        # NOTE(review): a bulk DELETE bypasses ORM-level cascades/events; this
        # assumes DailyQuote defines none — confirm in models.py.
        result = s.execute(
            delete(DailyQuote).where(
                DailyQuote.code == code,
                DailyQuote.date >= d_start,
                DailyQuote.date <= d_end,
            )
        )
        count = result.rowcount
        s.commit()
    return {"ok": True, "deleted": count, "range": f"{start} ~ {end}"}
def refetch_quote(code: str, days: int = 30) -> Dict:
    """Re-download the last *days* of daily quotes for *code* and re-ingest them.

    ``ingest.fetch_daily`` is called first; if it returns nothing the call
    fails without ingesting. Otherwise ``ingest.ingest_quotes`` performs the
    write (intended as an overwriting update).

    Returns:
        ``{"ok": False, "msg": ...}`` when the fetch yields no rows, otherwise
        ``{"ok": True, "code": ..., "rows": <rows from the fetch>, "msg": ...}``.
    """
    fetched = ingest.fetch_daily(code, days)
    if not fetched:
        return {"ok": False, "msg": f"抓取 {code} 数据失败"}
    # NOTE(review): ingest_quotes presumably downloads the same data again, so
    # each call likely hits the data source twice — confirm against ingest.py.
    ingest.ingest_quotes([code], days=days)
    row_total = len(fetched)
    return {
        "ok": True,
        "code": code,
        "rows": row_total,
        "msg": f"已更新 {row_total} 条日线",
    }
# ============ 数据完整性检查 ============
def check_data_integrity(codes: Optional[List[str]] = None, days: int = 30) -> Dict:
    """Find stocks whose recent daily-quote history has gaps.

    The window is ``[latest - days, latest]``, where ``latest`` is the newest
    quote date in the table. The expected row count per stock is the number
    of distinct quote dates present in that window across the whole table
    (the trading days actually stored). Each stock is therefore judged
    against the full market, even when only a few *codes* are checked.

    Args:
        codes: Stocks to check. When falsy, up to 200 stocks (sorted by code)
            that have at least one quote in the window are checked.
        days: Window length in calendar days.

    Returns:
        ``{"ok": False, "msg": ...}`` if the table has no quotes. Otherwise a
        dict with ``check_range``, ``checked``, ``expected_days``,
        ``normal_count``, ``missing_count`` and up to 50 ``missing_stocks``,
        sorted by number of missing rows, largest first. A stock counts as
        missing when it has fewer than 80% of the expected rows.
    """
    with get_session() as s:
        latest = s.execute(select(func.max(DailyQuote.date))).scalar()
        if not latest:
            return {"ok": False, "msg": "数据库无日线数据"}
        start = latest - dt.timedelta(days=days)

        # Per-stock row counts inside the window, in one grouped query
        # (previously one COUNT query per stock).
        window_counts = {
            code: n
            for code, n in s.execute(
                select(DailyQuote.code, func.count())
                .where(DailyQuote.date >= start)
                .group_by(DailyQuote.code)
            )
        }

        if codes:
            check_codes = codes
        else:
            # Default: stocks with records in the window, capped at 200.
            # Sorted so the capped subset is the same on every run.
            check_codes = sorted(window_counts)[:200]
        code_counts = {code: window_counts.get(code, 0) for code in check_codes}

        # Baseline: trading days present in the window across the whole table.
        # The old baseline (max count among the checked stocks only) meant an
        # explicit single-stock check could never report a gap.
        expected = s.execute(
            select(func.count(DailyQuote.date.distinct())).where(
                DailyQuote.date >= start
            )
        ).scalar() or 0

        missing = []
        normal = []
        for code, count in code_counts.items():
            ratio = count / expected if expected > 0 else 0
            if ratio < 0.8:  # more than 20% of the expected rows are absent
                # Reuse the open session rather than opening one per stock.
                sec = s.get(Security, code)
                name = (sec.name if sec else None) or code
                missing.append({
                    "code": code,
                    "name": name,
                    "actual": count,
                    "expected": expected,
                    "missing": expected - count,
                    "missing_pct": round((1 - ratio) * 100, 1),
                })
            else:
                normal.append(code)

    missing.sort(key=lambda x: x["missing"], reverse=True)
    return {
        "ok": True,
        "check_range": f"{start.isoformat()} ~ {latest.isoformat()}",
        "checked": len(check_codes),
        "expected_days": expected,
        "normal_count": len(normal),
        "missing_count": len(missing),
        "missing_stocks": missing[:50],
    }
def auto_fix_missing(limit: int = 50) -> Dict:
    """Re-fetch stocks reported as incomplete by ``check_data_integrity``.

    Runs the integrity check over the last 30 days. It then re-ingests 60 days
    of quotes for up to *limit* of the reported stocks, tracking progress in a
    ``JobRun`` row (``job="auto_fix"``). The check reports at most 50 stocks,
    so a *limit* above 50 has no extra effect.

    Returns:
        - ``{"ok": False, "msg": ..., "fixed": 0, "failed": []}`` when the
          integrity check itself failed (e.g. no quotes stored).
        - ``{"ok": True, "msg": ..., "fixed": 0}`` when nothing is missing.
        - Otherwise ``{"ok": <job succeeded>, "fixed": n, "failed": [codes],
          "msg": ...}``.
    """
    result = check_data_integrity(days=30)
    if not result["ok"]:
        # The check could not run; report that instead of claiming the data
        # is complete.
        return {"ok": False, "msg": result.get("msg", ""), "fixed": 0, "failed": []}
    if result["missing_count"] == 0:
        return {"ok": True, "msg": "数据完整,无需修复", "fixed": 0}

    codes = [item["code"] for item in result["missing_stocks"][:limit]]

    with get_session() as s:
        job = JobRun(job="auto_fix", status="running", message=f"0/{len(codes)}")
        s.add(job)
        s.commit()
        job_id = job.id

    fixed = 0
    failed = []
    try:
        for i, code in enumerate(codes, start=1):
            # NOTE(review): fetch_daily acts as an availability probe and
            # ingest_quotes presumably downloads the same data again — confirm.
            rows = ingest.fetch_daily(code, days=60)
            if rows:
                ingest.ingest_quotes([code], days=60)
                fixed += 1
            else:
                failed.append(code)
            if i % 10 == 0:  # periodic progress update on the job row
                with get_session() as s:
                    j = s.get(JobRun, job_id)
                    j.message = f"{i}/{len(codes)}"
                    s.commit()
        status = "success"
        msg = f"修复 {fixed}/{len(codes)},失败 {len(failed)}"
    except Exception as e:
        status = "error"
        msg = f"修复中断: {repr(e)[:160]}"

    with get_session() as s:
        j = s.get(JobRun, job_id)
        j.status = status
        j.finished_at = dt.datetime.now()
        j.message = msg
        s.commit()
    # "ok" reflects whether the job finished, consistent with the refill job.
    return {"ok": status == "success", "fixed": fixed, "failed": failed, "msg": msg}
# ============ 断点续传回填 ============
def _load_progress() -> Dict:
"""加载回填进度"""
if os.path.exists(PROGRESS_FILE):
try:
with open(PROGRESS_FILE, "r") as f:
return json.load(f)
except Exception:
pass
return {}
def _save_progress(progress: Dict):
"""保存回填进度"""
with open(PROGRESS_FILE, "w") as f:
json.dump(progress, f)
def _clear_progress(task_id: str):
    """Drop *task_id*'s entry from the progress file.

    A missing entry is not an error; the file is rewritten either way.
    """
    saved = _load_progress()
    if task_id in saved:
        del saved[task_id]
    _save_progress(saved)
def start_refill_with_resume(days: int = 250, task_id: str = "default") -> Dict:
    """Backfill daily quotes for the whole universe, resumable per task.

    The universe is the codes in ``akshare_service._code_name_map()`` that
    start with 0, 3 or 6. They are ingested in batches of 50 via
    ``ingest.ingest_quotes``. After every batch the finished codes are saved
    in the progress file under *task_id*, so a later call with the same
    *task_id* skips them. The progress entry is cleared once all batches
    succeed and kept on failure. A ``JobRun`` row (``job="refill_resume"``)
    tracks the run.

    Args:
        days: History length passed to ``ingest.ingest_quotes``.
        task_id: Key of this task's entry in the progress file.

    Returns:
        ``{"ok": <succeeded>, "done": <codes finished>, "total": <universe size>,
        "msg": ...}``.
    """
    from akshare_service import _code_name_map

    cmap = _code_name_map()
    all_codes = [c for c in cmap.keys() if c[:1] in ("0", "3", "6")]
    total = len(all_codes)

    # Resume from whatever this task recorded previously.
    progress = _load_progress()
    task_progress = progress.get(task_id, {"done_codes": [], "days": days})
    # NOTE(review): the stored "days" is never compared with *days*; resuming
    # with a different value mixes history lengths across codes — confirm.
    done_codes = set(task_progress.get("done_codes", []))
    remaining = [c for c in all_codes if c not in done_codes]
    # Count finished codes against the current universe only: len(done_codes)
    # also counted stale codes (no longer in the map) and could exceed total.
    fixed = total - len(remaining)

    with get_session() as s:
        job = JobRun(
            job="refill_resume",
            status="running",
            message=f"续传: 已完成 {fixed}/{total},剩余 {len(remaining)}",
        )
        s.add(job)
        s.commit()
        job_id = job.id

    try:
        for i in range(0, len(remaining), 50):
            batch = remaining[i:i + 50]
            ingest.ingest_quotes(batch, days=days, with_metrics=True, cmap=cmap)
            fixed += len(batch)
            # Persist after every batch so an interruption resumes from here.
            done_codes.update(batch)
            progress[task_id] = {
                "done_codes": list(done_codes),
                "days": days,
                "updated_at": dt.datetime.now().isoformat(),
            }
            _save_progress(progress)
            with get_session() as s:
                j = s.get(JobRun, job_id)
                j.message = f"{fixed}/{total}"
                s.commit()
        # Everything ingested: the resume point is no longer needed.
        _clear_progress(task_id)
        status = "success"
        msg = f"完成 {fixed}/{total}"
    except Exception as e:
        status = "error"
        msg = f"中断于 {fixed}/{total} | {repr(e)[:160]}"
        # The progress entry is intentionally kept so the next call can resume.

    with get_session() as s:
        j = s.get(JobRun, job_id)
        j.status = status
        j.finished_at = dt.datetime.now()
        j.message = msg
        s.commit()
    return {"ok": status == "success", "done": fixed, "total": total, "msg": msg}
def get_refill_progress(task_id: str = "default") -> Dict:
    """Report how far the resumable refill for *task_id* has progressed.

    Returns:
        ``{"ok": True, "has_progress": False, "msg": ...}`` when the progress
        file has no entry for *task_id*. Otherwise the dict contains ``done``,
        ``total``, ``pct`` and ``updated_at``. ``total`` is the current
        universe (codes starting with 0, 3 or 6). ``done`` counts only
        recorded codes that are still in that universe, so ``pct`` never
        exceeds 100.
    """
    progress = _load_progress()
    task = progress.get(task_id)
    if not task:
        return {"ok": True, "has_progress": False, "msg": "无回填进度记录"}

    from akshare_service import _code_name_map

    cmap = _code_name_map()
    universe = {c for c in cmap.keys() if c[:1] in ("0", "3", "6")}
    total = len(universe)
    # Ignore stale codes (e.g. no longer in the map) so done <= total.
    done = len(universe.intersection(task.get("done_codes", [])))
    return {
        "ok": True,
        "has_progress": True,
        "task_id": task_id,
        "done": done,
        "total": total,
        "pct": round(done / total * 100, 1) if total > 0 else 0,
        "updated_at": task.get("updated_at", ""),
    }
def clear_refill_progress(task_id: str = "default") -> Dict:
    """Forget the saved progress of *task_id* so the next refill starts over.

    Returns:
        ``{"ok": True, "msg": ...}`` confirming which task was cleared.
    """
    _clear_progress(task_id)
    message = f"已清除任务 {task_id} 的进度"
    return {"ok": True, "msg": message}
# ============ 数据质量报告 ============
def get_data_quality_report() -> Dict:
    """Build a health report for the stored daily-quote data.

    Collects the following:

    * table-wide statistics;
    * coverage over the 30 days ending at the newest quote;
    * the number of rows with ``open == 0``;
    * the five most recent ``JobRun`` rows.

    From these it derives a health score that starts at 100 and is floored
    at 0:

    * minus ``min(20, zero_open // 100)`` when zero-open rows exist;
    * minus 30 when fewer than 100 distinct stocks are stored;
    * minus 20 when the newest quote is more than 7 days older than today.

    Returns:
        Dict with ``generated_at``, ``health_score``, ``issues``
        (human-readable strings), ``statistics`` and ``recent_jobs``.
    """
    with get_session() as s:
        # Table-wide statistics in a single query.
        total_quotes, total_stocks, latest_date, earliest_date = s.execute(
            select(
                func.count(),
                func.count(DailyQuote.code.distinct()),
                func.max(DailyQuote.date),
                func.min(DailyQuote.date),
            ).select_from(DailyQuote)
        ).one()
        total_quotes = total_quotes or 0
        total_stocks = total_stocks or 0

        # Data density over the 30 days ending at the newest quote.
        if latest_date:
            start30 = latest_date - dt.timedelta(days=30)
            recent_stocks, recent_dates = s.execute(
                select(
                    func.count(DailyQuote.code.distinct()),
                    func.count(DailyQuote.date.distinct()),
                ).where(DailyQuote.date >= start30)
            ).one()
            recent_stocks = recent_stocks or 0
            recent_dates = recent_dates or 0
        else:
            recent_stocks = 0
            recent_dates = 0

        # Anomaly check: rows whose opening price is 0.
        zero_open = s.execute(
            select(func.count()).select_from(DailyQuote).where(DailyQuote.open == 0)
        ).scalar() or 0

        # Five most recent jobs, summarised while the session is still open.
        recent_jobs = s.execute(
            select(JobRun).order_by(JobRun.id.desc()).limit(5)
        ).scalars().all()
        jobs_summary = [
            {
                "job": j.job,
                "status": j.status,
                "started": j.started_at.strftime("%m-%d %H:%M") if j.started_at else "",
                # Guard against a NULL message column.
                "message": (j.message or "")[:100],
            }
            for j in recent_jobs
        ]

    score = 100
    issues = []
    if zero_open > 0:
        score -= min(20, zero_open // 100)
        issues.append(f"存在 {zero_open} 条开盘价为0的异常数据")
    if total_stocks < 100:
        score -= 30
        issues.append(f"入库股票数量偏少({total_stocks}只)")
    if latest_date:
        # Compute the lag once so the threshold check and the message agree.
        lag_days = (dt.date.today() - latest_date).days
        if lag_days > 7:
            score -= 20
            issues.append(f"数据滞后 {lag_days} 天")

    return {
        "ok": True,
        "generated_at": dt.datetime.now().isoformat(),
        "health_score": max(0, score),
        "issues": issues,
        "statistics": {
            "total_quotes": total_quotes,
            "total_stocks": total_stocks,
            "latest_date": latest_date.isoformat() if latest_date else None,
            "earliest_date": earliest_date.isoformat() if earliest_date else None,
            "data_span_days": (latest_date - earliest_date).days if latest_date and earliest_date else 0,
            "recent_30d_stocks": recent_stocks,
            "recent_30d_dates": recent_dates,
            "zero_open_count": zero_open,
        },
        "recent_jobs": jobs_summary,
    }