399 lines
13 KiB
Python
399 lines
13 KiB
Python
"""数据修正与回填增强:数据修正、断点续传、完整性检查、质量报告"""
|
||
import datetime as dt
|
||
import json
|
||
import os
|
||
from typing import List, Optional, Dict
|
||
from sqlalchemy import select, func, and_, delete
|
||
from db import get_session
|
||
from models import DailyQuote, StockMetric, Security, IndexDaily, SectorDaily, JobRun
|
||
import ingest
|
||
|
||
# JSON file that records resumable refill progress; it is kept in the same
# directory as this module.
PROGRESS_FILE = os.path.join(
    os.path.dirname(__file__),
    ".refill_progress.json",
)
|
||
|
||
|
||
# ============ 数据修正 ============
|
||
|
||
def delete_quote(code: str, date: str) -> Dict:
    """Delete the daily quote of one stock on one date.

    Args:
        code: stock code.
        date: ISO date string (``YYYY-MM-DD``). It is parsed with
            ``date.fromisoformat`` before any DB access, so a malformed value
            raises ValueError.

    Returns:
        ``{"ok": True, "msg": ...}`` when a row was deleted, otherwise
        ``{"ok": False, "msg": ...}``.
    """
    target_day = dt.date.fromisoformat(date)
    with get_session() as session:
        stmt = select(DailyQuote).where(
            DailyQuote.code == code,
            DailyQuote.date == target_day,
        )
        quote = session.execute(stmt).scalar_one_or_none()
        if not quote:
            return {"ok": False, "msg": f"{code} {date} 无此数据"}
        session.delete(quote)
        session.commit()
    return {"ok": True, "msg": f"已删除 {code} {date} 日线"}
|
||
|
||
|
||
def update_quote(code: str, date: str, fields: Dict) -> Dict:
    """Overwrite selected price/volume columns of one stock's daily quote.

    Args:
        code: stock code.
        date: ISO date string (``YYYY-MM-DD``).
        fields: column -> new value. Only open/high/low/close/volume/amount
            are applied; any other key is ignored.

    Returns:
        ``{"ok": False, "msg": ...}`` when no editable field is given or the
        row does not exist; otherwise ``{"ok": True, "updated": {...}}`` with
        the fields that were written.
    """
    editable = frozenset(("open", "high", "low", "close", "volume", "amount"))
    changes = {}
    for name, value in fields.items():
        if name in editable:
            changes[name] = value
    # Field validation happens before the date is parsed, so an empty change
    # set short-circuits even for a malformed date.
    if not changes:
        return {"ok": False, "msg": "无有效修正字段"}

    target_day = dt.date.fromisoformat(date)
    with get_session() as session:
        stmt = select(DailyQuote).where(
            DailyQuote.code == code,
            DailyQuote.date == target_day,
        )
        quote = session.execute(stmt).scalar_one_or_none()
        if not quote:
            return {"ok": False, "msg": f"{code} {date} 无此数据"}
        for name, value in changes.items():
            setattr(quote, name, value)
        session.commit()
    return {"ok": True, "updated": changes}
|
||
|
||
|
||
def delete_quotes_range(code: str, start: str, end: str) -> Dict:
    """Delete one stock's daily quotes whose date lies in ``[start, end]``.

    Args:
        code: stock code.
        start: inclusive ISO start date (``YYYY-MM-DD``).
        end: inclusive ISO end date (``YYYY-MM-DD``). A start after end simply
            matches nothing.

    Returns:
        ``{"ok": True, "deleted": <rows removed>, "range": "start ~ end"}``.

    Raises:
        ValueError: if either date string is not a valid ISO date.
    """
    d_start = dt.date.fromisoformat(start)
    d_end = dt.date.fromisoformat(end)
    with get_session() as s:
        # One bulk DELETE instead of loading every row and deleting it
        # individually (which cost one statement per row).
        # NOTE(review): a bulk delete bypasses ORM-level cascades/events on
        # DailyQuote; confirm the model defines none.
        result = s.execute(
            delete(DailyQuote).where(
                DailyQuote.code == code,
                DailyQuote.date >= d_start,
                DailyQuote.date <= d_end,
            )
        )
        count = result.rowcount
        s.commit()
    return {"ok": True, "deleted": count, "range": f"{start} ~ {end}"}
|
||
|
||
|
||
def refetch_quote(code: str, days: int = 30) -> Dict:
    """Re-fetch one stock's recent daily quotes and write them to the database.

    Args:
        code: stock code.
        days: how many days of history to request from ``ingest``.

    Returns:
        ``{"ok": False, "msg": ...}`` when ``ingest.fetch_daily`` returned no
        rows; otherwise ``{"ok": True, "code", "rows", "msg"}``. ``rows`` is the
        number of rows ``fetch_daily`` returned.
    """
    rows = ingest.fetch_daily(code, days)
    if not rows:
        return {"ok": False, "msg": f"抓取 {code} 数据失败"}
    # NOTE(review): ingest_quotes receives only the code list, so it presumably
    # fetches the data again itself. The rows above then act only as an
    # availability check and as the reported count. Confirm against ingest.
    ingest.ingest_quotes([code], days=days)
    return {"ok": True, "code": code, "rows": len(rows), "msg": f"已更新 {len(rows)} 条日线"}
|
||
|
||
|
||
# ============ 数据完整性检查 ============
|
||
|
||
def _count_quotes_since(s, codes: List[str], start: dt.date,
                        chunk_size: int = 500) -> Dict[str, int]:
    """Return {code: number of DailyQuote rows dated on/after ``start``}.

    Uses one grouped query per chunk of codes instead of one COUNT per code.
    Codes without rows map to 0. Keys follow the first-occurrence order of
    ``codes``.
    """
    unique = list(dict.fromkeys(codes))
    found: Dict[str, int] = {}
    # Chunk the IN list so very long ``codes`` inputs stay within driver
    # bind-parameter limits.
    for i in range(0, len(unique), chunk_size):
        chunk = unique[i:i + chunk_size]
        rows = s.execute(
            select(DailyQuote.code, func.count())
            .where(DailyQuote.code.in_(chunk), DailyQuote.date >= start)
            .group_by(DailyQuote.code)
        ).all()
        for code, cnt in rows:
            found[code] = cnt
    return {code: found.get(code, 0) for code in unique}


def check_data_integrity(codes: Optional[List[str]] = None, days: int = 30,
                         max_codes: int = 200) -> Dict:
    """Find stocks whose recent daily-quote history looks incomplete.

    The window is ``[latest - days, latest]``, where ``latest`` is the newest
    DailyQuote date in the database. The stock with the most rows in the window
    sets the expected count. Any stock below 80% of that count is reported as
    missing data.

    Args:
        codes: stock codes to check. When empty/None, every code with at least
            one row in the window is checked, capped at ``max_codes``. These
            codes are taken in ascending order, so repeated runs check the same
            subset.
        days: window length in calendar days, counted back from ``latest``.
        max_codes: cap on auto-discovered codes; ignored when ``codes`` is given.

    Returns:
        ``{"ok": False, "msg": ...}`` when there are no quotes at all.
        Otherwise a dict with ``ok`` set to True and these keys:
        ``check_range``, ``checked``, ``expected_days``, ``normal_count``,
        ``missing_count`` and ``missing_stocks``. ``missing_stocks`` holds at
        most 50 entries, sorted by number of missing rows, largest first.
    """
    with get_session() as s:
        latest = s.execute(select(func.max(DailyQuote.date))).scalar()
        if not latest:
            return {"ok": False, "msg": "数据库无日线数据"}

        start = latest - dt.timedelta(days=days)

        if codes:
            check_codes = list(codes)
        else:
            all_codes = s.execute(
                select(DailyQuote.code)
                .where(DailyQuote.date >= start)
                .distinct()
                .order_by(DailyQuote.code)  # deterministic subset under the cap
            ).scalars().all()
            check_codes = list(all_codes)[:max_codes]

        code_counts = _count_quotes_since(s, check_codes, start)

        # The best-covered stock approximates the number of trading days.
        expected = max(code_counts.values()) if code_counts else 0

        missing = []
        normal = []
        for code, count in code_counts.items():
            ratio = count / expected if expected > 0 else 0
            if ratio < 0.8:  # more than 20% of the expected rows are absent
                # Reuse the open session instead of opening one per stock.
                sec = s.get(Security, code)
                name = sec.name if sec else code
                missing.append({
                    "code": code,
                    "name": name,
                    "actual": count,
                    "expected": expected,
                    "missing": expected - count,
                    "missing_pct": round((1 - ratio) * 100, 1),
                })
            else:
                normal.append(code)

        missing.sort(key=lambda x: x["missing"], reverse=True)

        return {
            "ok": True,
            "check_range": f"{start.isoformat()} ~ {latest.isoformat()}",
            "checked": len(check_codes),
            "expected_days": expected,
            "normal_count": len(normal),
            "missing_count": len(missing),
            "missing_stocks": missing[:50],
        }
|
||
|
||
|
||
def auto_fix_missing(limit: int = 50) -> Dict:
    """Re-fetch the stocks that ``check_data_integrity`` flags as incomplete.

    The function works in three steps:

    1. Run a 30-day integrity check.
    2. Re-fetch up to ``limit`` of the flagged stocks, 60 days each, via
       ``refetch_quote``.
    3. Track progress in a JobRun row whose ``job`` is ``"auto_fix"``. The row
       is updated every 10 stocks and finalised at the end.

    Args:
        limit: maximum number of stocks to repair. ``check_data_integrity``
            reports at most 50 missing stocks, so values above 50 have no
            extra effect.

    Returns:
        ``{"ok", "msg", "fixed"}`` when nothing was attempted. Otherwise
        ``{"ok", "fixed", "failed", "msg"}``, where ``failed`` lists codes for
        which no data could be fetched. ``ok`` is False when the integrity
        check itself failed or the repair loop was interrupted by an exception.
    """
    result = check_data_integrity(days=30)
    if not result["ok"]:
        # The check failed (e.g. no quotes at all); this is not "data complete".
        return {"ok": False, "msg": result["msg"], "fixed": 0}
    if result["missing_count"] == 0:
        return {"ok": True, "msg": "数据完整,无需修复", "fixed": 0}

    codes = [item["code"] for item in result["missing_stocks"][:limit]]
    total = len(codes)

    with get_session() as s:
        job = JobRun(job="auto_fix", status="running", message=f"0/{total}")
        s.add(job)
        s.commit()
        job_id = job.id

    fixed = 0
    failed = []
    try:
        for i, code in enumerate(codes, start=1):
            # Same fetch-then-ingest step as refetch_quote, with a 60-day window.
            if refetch_quote(code, days=60)["ok"]:
                fixed += 1
            else:
                failed.append(code)

            if i % 10 == 0:
                with get_session() as s:
                    j = s.get(JobRun, job_id)
                    j.message = f"{i}/{total}"
                    s.commit()

        status = "success"
        msg = f"修复 {fixed}/{total},失败 {len(failed)} 只"
    except Exception as e:
        # Record the failure on the job row instead of leaving it "running".
        status = "error"
        msg = f"修复中断: {repr(e)[:160]}"

    with get_session() as s:
        j = s.get(JobRun, job_id)
        j.status = status
        j.finished_at = dt.datetime.now()
        j.message = msg
        s.commit()

    return {"ok": status == "success", "fixed": fixed, "failed": failed, "msg": msg}
|
||
|
||
|
||
# ============ 断点续传回填 ============
|
||
|
||
def _load_progress() -> Dict:
|
||
"""加载回填进度"""
|
||
if os.path.exists(PROGRESS_FILE):
|
||
try:
|
||
with open(PROGRESS_FILE, "r") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
pass
|
||
return {}
|
||
|
||
|
||
def _save_progress(progress: Dict):
|
||
"""保存回填进度"""
|
||
with open(PROGRESS_FILE, "w") as f:
|
||
json.dump(progress, f)
|
||
|
||
|
||
def _clear_progress(task_id: str):
    """Drop ``task_id``'s entry from the progress file, keeping every other task.

    The file is rewritten even when the task had no entry.
    """
    saved = _load_progress()
    saved.pop(task_id, None)  # no-op when the task has no saved entry
    _save_progress(saved)
|
||
|
||
|
||
def start_refill_with_resume(days: int = 250, task_id: str = "default",
                             batch_size: int = 50) -> Dict:
    """Refill the whole market in batches, resuming from saved progress.

    The universe is every code from ``_code_name_map()`` that starts with
    "0", "3" or "6" (presumably mainland A-share codes; confirm). Codes already
    recorded as done for ``task_id`` are skipped.

    Each remaining batch is ingested with ``with_metrics=True``, and the
    progress file is then updated. A JobRun row with job ``"refill_resume"``
    tracks the run. On success the task's progress is cleared; on an exception
    it is kept so the next call resumes.

    Args:
        days: history length passed to ``ingest.ingest_quotes``.
        task_id: key under which progress is stored in the progress file.
        batch_size: number of codes ingested per batch (must be >= 1).

    Returns:
        ``{"ok", "done", "total", "msg"}``. ``done`` counts codes of the current
        universe that have been ingested.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    from akshare_service import _code_name_map

    cmap = _code_name_map()
    all_codes = [c for c in cmap.keys() if c[:1] in ("0", "3", "6")]
    total = len(all_codes)

    progress = _load_progress()
    task_progress = progress.get(task_id, {"done_codes": [], "days": days})
    # NOTE(review): a stored "days" different from the current one is ignored;
    # codes done earlier keep whatever history length they were fetched with.
    done_codes = set(task_progress.get("done_codes", []))

    remaining = [c for c in all_codes if c not in done_codes]
    # Count done codes against the current universe. done_codes may hold codes
    # that have since left the list, which would otherwise push "done" above
    # "total".
    fixed = total - len(remaining)

    with get_session() as s:
        job = JobRun(
            job="refill_resume",
            status="running",
            message=f"续传: 已完成 {fixed}/{total},剩余 {len(remaining)}",
        )
        s.add(job)
        s.commit()
        job_id = job.id

    try:
        for i in range(0, len(remaining), batch_size):
            batch = remaining[i:i + batch_size]
            ingest.ingest_quotes(batch, days=days, with_metrics=True, cmap=cmap)
            fixed += len(batch)

            # Persist progress after every batch so an interruption loses at
            # most one batch of work.
            done_codes.update(batch)
            progress[task_id] = {
                "done_codes": list(done_codes),
                "days": days,
                "updated_at": dt.datetime.now().isoformat(),
            }
            _save_progress(progress)

            with get_session() as s:
                j = s.get(JobRun, job_id)
                j.message = f"{fixed}/{total}"
                s.commit()

        # Finished: drop the progress so the next call starts fresh.
        _clear_progress(task_id)
        status = "success"
        msg = f"完成 {fixed}/{total}"
    except Exception as e:
        status = "error"
        msg = f"中断于 {fixed}/{total} | {repr(e)[:160]}"
        # Progress is intentionally kept for the next resume.

    with get_session() as s:
        j = s.get(JobRun, job_id)
        j.status = status
        j.finished_at = dt.datetime.now()
        j.message = msg
        s.commit()

    return {"ok": status == "success", "done": fixed, "total": total, "msg": msg}
|
||
|
||
|
||
def get_refill_progress(task_id: str = "default") -> Dict:
    """Report how far a resumable refill task has progressed.

    Args:
        task_id: key under which the task's progress is stored.

    Returns:
        ``{"ok": True, "has_progress": False, "msg": ...}`` when nothing is
        saved. Otherwise a dict with ``ok`` and ``has_progress`` set to True and
        the keys ``task_id``, ``done``, ``total``, ``pct`` and ``updated_at``.
        ``total`` is the size of the current code universe: codes from
        ``_code_name_map()`` starting with "0", "3" or "6".
    """
    progress = _load_progress()
    task = progress.get(task_id)
    if not task:
        return {"ok": True, "has_progress": False, "msg": "无回填进度记录"}

    from akshare_service import _code_name_map

    cmap = _code_name_map()
    universe = {c for c in cmap.keys() if c[:1] in ("0", "3", "6")}
    total = len(universe)
    # Only count done codes still in the universe; stale entries (codes that
    # have left the list) would otherwise push pct above 100.
    done = len(universe.intersection(task.get("done_codes", [])))

    return {
        "ok": True,
        "has_progress": True,
        "task_id": task_id,
        "done": done,
        "total": total,
        "pct": round(done / total * 100, 1) if total > 0 else 0,
        "updated_at": task.get("updated_at", ""),
    }
|
||
|
||
|
||
def clear_refill_progress(task_id: str = "default") -> Dict:
    """Discard a refill task's saved progress so its next run starts from scratch.

    Args:
        task_id: key of the task whose progress is removed.

    Returns:
        ``{"ok": True, "msg": ...}`` confirming the removal.
    """
    _clear_progress(task_id)
    message = f"已清除任务 {task_id} 的进度"
    return {"ok": True, "msg": message}
|
||
|
||
|
||
# ============ 数据质量报告 ============
|
||
|
||
def get_data_quality_report() -> Dict:
    """Build a summary of the daily-quote data and a simple health score.

    The score starts at 100 and loses points for three conditions:

    - rows with ``open == 0``: 1 point per 100 such rows, capped at 20;
    - fewer than 100 distinct stocks: 30 points;
    - newest quote date more than 7 days before today: 20 points.

    The score never goes below 0.

    Returns:
        A dict with ``ok``, ``generated_at``, ``health_score`` and ``issues``
        (human-readable strings). It also includes ``statistics`` (totals, date
        span, last-30-day coverage, zero-open count) and ``recent_jobs`` (the
        last 5 JobRun rows).
    """
    with get_session() as s:
        total_quotes = s.execute(select(func.count()).select_from(DailyQuote)).scalar() or 0
        total_stocks = s.execute(
            select(func.count(DailyQuote.code.distinct()))
        ).scalar() or 0
        latest_date = s.execute(select(func.max(DailyQuote.date))).scalar()
        earliest_date = s.execute(select(func.min(DailyQuote.date))).scalar()

        # Coverage over the 30 calendar days ending at the newest quote date.
        if latest_date:
            start30 = latest_date - dt.timedelta(days=30)
            recent_stocks = s.execute(
                select(func.count(DailyQuote.code.distinct()))
                .where(DailyQuote.date >= start30)
            ).scalar() or 0
            recent_dates = s.execute(
                select(func.count(DailyQuote.date.distinct()))
                .where(DailyQuote.date >= start30)
            ).scalar() or 0
        else:
            recent_stocks = 0
            recent_dates = 0

        # Anomaly check: rows whose opening price is 0.
        zero_open = s.execute(
            select(func.count()).select_from(DailyQuote)
            .where(DailyQuote.open == 0)
        ).scalar() or 0

        recent_jobs = s.execute(
            select(JobRun).order_by(JobRun.id.desc()).limit(5)
        ).scalars().all()

        jobs_summary = [{
            "job": j.job,
            "status": j.status,
            "started": j.started_at.strftime("%m-%d %H:%M") if j.started_at else "",
            # Guard against a NULL message; slicing None raises TypeError.
            "message": (j.message or "")[:100],
        } for j in recent_jobs]

    score = 100
    issues = []

    if zero_open > 0:
        # 1 point per 100 bad rows, capped at 20. Fewer than 100 rows are
        # listed as an issue but cost no points.
        score -= min(20, zero_open // 100)
        issues.append(f"存在 {zero_open} 条开盘价为0的异常数据")

    if total_stocks < 100:
        score -= 30
        issues.append(f"入库股票数量偏少({total_stocks}只)")

    lag_days = (dt.date.today() - latest_date).days if latest_date else None
    if lag_days is not None and lag_days > 7:
        score -= 20
        issues.append(f"数据滞后 {lag_days} 天")

    return {
        "ok": True,
        "generated_at": dt.datetime.now().isoformat(),
        "health_score": max(0, score),
        "issues": issues,
        "statistics": {
            "total_quotes": total_quotes,
            "total_stocks": total_stocks,
            "latest_date": latest_date.isoformat() if latest_date else None,
            "earliest_date": earliest_date.isoformat() if earliest_date else None,
            "data_span_days": (latest_date - earliest_date).days if latest_date and earliest_date else 0,
            "recent_30d_stocks": recent_stocks,
            "recent_30d_dates": recent_dates,
            "zero_open_count": zero_open,
        },
        "recent_jobs": jobs_summary,
    }
|