"""HTTP and WebSocket front end built on aiohttp.

Defines :class:`WebSocketManager`, which tracks connected WebSocket clients
and broadcasts JSON messages to them, and :class:`WebServer`, which wires the
JSON API routes, the HTML pages and the static file directory into an
``aiohttp`` application.
"""

import asyncio
import json
import logging
import os

from aiohttp import web

from meshbot.database import Database

logger = logging.getLogger(__name__)

# "static" directory located one level above the directory containing this file.
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")


class WebSocketManager:
    """Registry of connected WebSocket clients with a broadcast helper."""

    def __init__(self):
        # Currently connected clients; WebServer._ws_handler adds and discards
        # entries as connections open and close.
        self.clients: set[web.WebSocketResponse] = set()

    async def broadcast(self, msg_type: str, data: dict | list):
        """Send one JSON message to every registered client.

        The payload is ``{"type": msg_type, "data": data}`` serialized once
        and sent to each client in turn. Clients whose send raises are
        removed from the registry afterwards.

        Args:
            msg_type: Value stored under the ``"type"`` key.
            data: JSON-serializable value stored under the ``"data"`` key.
        """
        message = json.dumps({"type": msg_type, "data": data})
        closed = set()
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, during which a connect/disconnect may mutate ``self.clients``.
        # Iterating the live set would then raise
        # "RuntimeError: Set changed size during iteration".
        for ws in tuple(self.clients):
            try:
                await ws.send_str(message)
            except Exception:
                # Best effort: a failing client must not block the others.
                logger.debug(
                    "Dropping WebSocket client after failed send", exc_info=True
                )
                closed.add(ws)
        self.clients.difference_update(closed)


class WebServer:
    """aiohttp application exposing the JSON API, HTML pages and WebSocket."""

    def __init__(self, db: Database, ws_manager: WebSocketManager):
        self.db = db
        self.ws_manager = ws_manager
        self.app = web.Application()
        self._setup_routes()

    def _setup_routes(self):
        """Register all HTTP, WebSocket and static routes on the application."""
        router = self.app.router
        router.add_get("/ws", self._ws_handler)
        router.add_get("/api/nodes", self._api_nodes)
        router.add_get("/api/messages", self._api_messages)
        router.add_get("/api/stats", self._api_stats)
        router.add_get("/map", self._serve_map)
        router.add_get("/", self._serve_index)
        router.add_static("/static", STATIC_DIR)

    async def _ws_handler(self, request: web.Request) -> web.WebSocketResponse:
        """Accept a WebSocket connection and keep it registered until it closes.

        After the handshake the client receives an ``"initial"`` message with
        all nodes followed by a ``"stats_update"`` message. Incoming messages
        are read and ignored.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        clients = self.ws_manager.clients
        clients.add(ws)
        logger.info("WebSocket client connected (%d total)", len(clients))

        try:
            # Send initial data
            nodes = await self.db.get_all_nodes()
            await ws.send_str(json.dumps({"type": "initial", "data": nodes}))

            stats = await self.db.get_stats()
            await ws.send_str(json.dumps({"type": "stats_update", "data": stats}))

            async for _ in ws:
                pass  # We only send, not receive
        finally:
            # Always unregister, even if the initial send or the read loop fails.
            clients.discard(ws)
            logger.info(
                "WebSocket client disconnected (%d remaining)", len(clients)
            )

        return ws

    async def _api_nodes(self, request: web.Request) -> web.Response:
        """Return all nodes from the database as JSON."""
        nodes = await self.db.get_all_nodes()
        return web.json_response(nodes)

    async def _api_messages(self, request: web.Request) -> web.Response:
        """Return the most recent messages as JSON.

        The optional ``limit`` query parameter (default ``50``) is passed to
        ``Database.get_recent_messages``.

        Raises:
            web.HTTPBadRequest: If ``limit`` is not an integer or is negative.
        """
        raw_limit = request.query.get("limit", "50")
        try:
            limit = int(raw_limit)
        except ValueError:
            # Client error, not a server fault: answer 400 instead of 500.
            raise web.HTTPBadRequest(text="'limit' must be an integer") from None
        if limit < 0:
            raise web.HTTPBadRequest(text="'limit' must not be negative")

        messages = await self.db.get_recent_messages(limit)
        return web.json_response(messages)

    async def _api_stats(self, request: web.Request) -> web.Response:
        """Return the database statistics as JSON."""
        stats = await self.db.get_stats()
        return web.json_response(stats)

    async def _serve_index(self, request: web.Request) -> web.FileResponse:
        """Serve ``index.html`` from the static directory."""
        return web.FileResponse(os.path.join(STATIC_DIR, "index.html"))

    async def _serve_map(self, request: web.Request) -> web.FileResponse:
        """Serve ``map.html`` from the static directory."""
        return web.FileResponse(os.path.join(STATIC_DIR, "map.html"))

    async def start(self, host: str, port: int) -> web.AppRunner:
        """Start serving the application on ``host``:``port``.

        Args:
            host: Interface to bind the TCP site to.
            port: TCP port to listen on.

        Returns:
            The ``AppRunner`` that owns the started site, so the caller keeps
            a handle to it.
        """
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, host, port)
        await site.start()
        logger.info("Webserver started at http://%s:%d", host, port)
        return runner