-
Notifications
You must be signed in to change notification settings - Fork 6.6k
/
Copy pathws.py
96 lines (81 loc) · 3.78 KB
/
ws.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# api/ws.py
import asyncio
import json
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger

from ...datamodel import Run, RunStatus
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager
# Router that holds this module's WebSocket endpoints (see `run_websocket` below).
# NOTE(review): presumably included into the FastAPI app elsewhere — confirm the mount prefix.
router = APIRouter()
# Strong references to fire-and-forget stream tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes. Entries are removed when each task completes.
_background_tasks: set[asyncio.Task] = set()


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Yields the same text as ``datetime.utcnow().isoformat()`` (no UTC offset
    suffix) without calling the deprecated ``datetime.utcnow``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _spawn_stream_task(coro, run_id: int) -> asyncio.Task:
    """Schedule *coro* as a background task, keep it alive, and log any failure.

    Args:
        coro: The coroutine to run (here, ``ws_manager.start_stream(...)``).
        run_id: Run identifier, used only in the failure log message.

    Returns:
        The created ``asyncio.Task``.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)

    def _on_done(done: asyncio.Task) -> None:
        _background_tasks.discard(done)
        # `exception()` raises CancelledError on a cancelled task, so check first.
        if done.cancelled():
            return
        exc = done.exception()
        if exc is not None:
            logger.opt(exception=exc).error(f"Stream task for run {run_id} failed: {exc}")

    task.add_done_callback(_on_done)
    return task


@router.websocket("/runs/{run_id}")
async def run_websocket(
    websocket: WebSocket,
    run_id: int,
    ws_manager: WebSocketManager = Depends(get_websocket_manager),
    db=Depends(get_db),
):
    """WebSocket endpoint for run communication.

    Checks that run ``run_id`` exists and is CREATED or ACTIVE, registers the
    socket with ``ws_manager``, then handles JSON messages from the client
    until it sends ``stop`` or disconnects. Recognised ``type`` values:

    - ``start``: needs truthy ``task`` and ``team_config``. Launches
      ``ws_manager.start_stream`` in the background, so this loop keeps
      receiving messages.
    - ``stop``: calls ``ws_manager.stop_run`` with the optional ``reason``,
      then leaves the loop.
    - ``ping``: answered with a ``pong`` message.
    - ``input_response``: forwards ``response`` (if not None) to
      ``ws_manager.handle_input_response``.

    Malformed messages (invalid JSON, non-object JSON, or a bad ``start``)
    are answered with an ``error`` message, and the connection stays open.
    ``ws_manager.disconnect`` is always awaited once the loop ends.

    Close codes sent before the loop starts: 4004 (run not found), 4003 (run
    not in a valid state), 4002 (``ws_manager.connect`` returned falsy).
    """
    # Verify run exists and is in valid state
    run_response = db.get(Run, filters={"id": run_id}, return_json=False)
    if not run_response.status or not run_response.data:
        logger.warning(f"Run not found: {run_id}")
        await websocket.close(code=4004, reason="Run not found")
        return

    run = run_response.data[0]
    if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]:
        await websocket.close(code=4003, reason="Run not in valid state")
        return

    # Connect websocket
    connected = await ws_manager.connect(websocket, run_id)
    if not connected:
        await websocket.close(code=4002, reason="Failed to establish connection")
        return

    try:
        logger.info(f"WebSocket connection established for run {run_id}")
        while True:
            try:
                raw_message = await websocket.receive_text()
                message = json.loads(raw_message)

                if not isinstance(message, dict):
                    # Valid JSON that is not an object (list, string, number...):
                    # `.get` below would raise AttributeError and drop the socket.
                    logger.warning(f"Non-object message received for run {run_id}")
                    await websocket.send_json(
                        {"type": "error", "error": "Invalid message format", "timestamp": _utc_timestamp()}
                    )
                    continue

                message_type = message.get("type")

                if message_type == "start":
                    # Handle start message
                    logger.info(f"Received start request for run {run_id}")
                    task = message.get("task")
                    team_config = message.get("team_config")
                    if task and team_config:
                        # Not awaited: the loop must stay free to receive
                        # `stop` / `input_response` while the run streams.
                        _spawn_stream_task(ws_manager.start_stream(run_id, task, team_config), run_id)
                    else:
                        logger.warning(f"Invalid start message format for run {run_id}")
                        await websocket.send_json(
                            {
                                "type": "error",
                                "error": "Invalid start message format",
                                "timestamp": _utc_timestamp(),
                            }
                        )

                elif message_type == "stop":
                    logger.info(f"Received stop request for run {run_id}")
                    reason = message.get("reason") or "User requested stop/cancellation"
                    await ws_manager.stop_run(run_id, reason=reason)
                    break

                elif message_type == "ping":
                    await websocket.send_json({"type": "pong", "timestamp": _utc_timestamp()})

                elif message_type == "input_response":
                    # Handle input response from client
                    response = message.get("response")
                    if response is not None:
                        await ws_manager.handle_input_response(run_id, response)
                    else:
                        logger.warning(f"Invalid input response format for run {run_id}")

                else:
                    logger.debug(f"Ignoring message with unknown type {message_type!r} for run {run_id}")

            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON received: {raw_message}")
                await websocket.send_json(
                    {"type": "error", "error": "Invalid message format", "timestamp": _utc_timestamp()}
                )

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for run {run_id}")
    except Exception as e:
        # `exception` (not `error`) so the traceback is kept in the log.
        logger.exception(f"WebSocket error: {str(e)}")
    finally:
        await ws_manager.disconnect(run_id)