59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
"""
|
|
WebSocket connection manager — broadcasts real-time stock quotes to subscribers.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
from collections import defaultdict
|
|
from fastapi import WebSocket
|
|
from loguru import logger
|
|
|
|
|
|
class ConnectionManager:
    """Track accepted WebSocket clients and push JSON messages to them.

    Two indexes are maintained: ``_all`` holds every accepted connection and
    ``_subs`` maps a symbol to the connections subscribed to that symbol.
    """

    def __init__(self):
        # symbol -> set of WebSocket connections
        self._subs: dict[str, set[WebSocket]] = defaultdict(set)
        # every accepted connection, regardless of subscription
        self._all: set[WebSocket] = set()

    async def connect(self, ws: WebSocket, symbol: str | None = None):
        """Accept *ws* and register it, optionally as a subscriber of *symbol*.

        Args:
            ws: The incoming WebSocket; ``accept()`` is awaited on it first.
            symbol: When truthy, *ws* is also added to that symbol's
                subscriber set.
        """
        await ws.accept()
        self._all.add(ws)
        if symbol:
            self._subs[symbol].add(ws)
        logger.info(f"WS connected symbol={symbol}, total={len(self._all)}")

    def disconnect(self, ws: WebSocket, symbol: str | None = None):
        """Unregister *ws*.

        Args:
            ws: The connection to remove. Unknown connections are ignored.
            symbol: When truthy, *ws* is removed from that symbol's subscriber
                set only; otherwise it is removed from every symbol.
        """
        self._all.discard(ws)
        symbols = [symbol] if symbol else list(self._subs)
        for sym in symbols:
            # .get() instead of indexing: indexing a defaultdict would insert
            # an empty entry for a symbol nobody ever subscribed to.
            subscribers = self._subs.get(sym)
            if subscribers is None:
                continue
            subscribers.discard(ws)
            if not subscribers:
                # Drop emptied sets so the mapping does not accumulate a key
                # for every symbol that was ever subscribed to.
                del self._subs[sym]
        logger.info(f"WS disconnected, total={len(self._all)}")

    async def _send_to(self, targets, message: str) -> None:
        """Send *message* to each connection in *targets*; unregister failures."""
        dead = []
        # Iterate over a snapshot: the underlying set may change while a send
        # is awaited.
        for ws in list(targets):
            try:
                await ws.send_text(message)
            except Exception as exc:
                # A failed send is treated as a dead connection; record why
                # instead of swallowing the error silently.
                logger.debug(f"WS send failed, dropping connection: {exc!r}")
                dead.append(ws)
        for ws in dead:
            # No symbol argument: a dead socket is removed from every
            # subscription, not only the one being broadcast.
            self.disconnect(ws)

    async def broadcast_quote(self, symbol: str, data: dict):
        """Send quote update to all subscribers of a specific symbol."""
        message = json.dumps({"type": "quote", "symbol": symbol, "data": data})
        # .get() so that broadcasting an unsubscribed symbol creates no entry.
        await self._send_to(self._subs.get(symbol, ()), message)

    async def broadcast_all(self, data: list[dict]):
        """Broadcast heatmap snapshot to all connected clients."""
        message = json.dumps({"type": "heatmap", "data": data})
        await self._send_to(self._all, message)
|
|
|
|
|
|
# Module-level ConnectionManager instance, created when this module is imported.
manager = ConnectionManager()
|