#!/usr/bin/env python3
"""
Concurrency test for benchmark environment using WebSockets.

Each WebSocket connection gets its own dedicated environment session,
enabling true concurrent execution across multiple sessions.

Run the server first:
    cd benchmark && uvicorn server.app:app --port 8000

Then run this script:
    python test_concurrency.py --requests 100 --wait 1.0
    python test_concurrency.py -n 100 -w 1 --url wss://your-server.hf.space
"""

import argparse
import asyncio
import json
import sys
import time
from dataclasses import dataclass

try:
    import websockets
except ImportError:
    print("Install websockets: pip install websockets")
    raise


@dataclass
class RequestResult:
    """Result from a single WebSocket session (reset + step + close)."""

    request_id: int  # index of the session within the batch
    wait_requested: float  # wait_seconds sent in the step message
    waited_seconds: float  # observation "waited_seconds" from the server (0.0 if absent)
    elapsed: float  # client wall time from before connecting until "close" was sent (s)
    pid: int  # observation "pid" (0 if absent)
    session_hash: str  # observation "session_hash" ("" if absent)
    host_url: str  # observation "host_url" ("" if absent)


def convert_to_ws_url(url: str) -> str:
    """Convert a server URL to the server's WebSocket endpoint URL.

    ``http://`` becomes ``ws://`` and ``https://`` becomes ``wss://``; URLs
    already using ``ws://``/``wss://`` keep their scheme, and a URL with no
    scheme is treated as ``ws://``. The ``/ws`` path is appended only when the
    URL does not already end with it, regardless of the input scheme.
    """
    url = url.rstrip("/")
    if url.startswith("http://"):
        url = "ws://" + url[len("http://"):]
    elif url.startswith("https://"):
        url = "wss://" + url[len("https://"):]
    elif not url.startswith(("ws://", "wss://")):
        url = "ws://" + url
    return url if url.endswith("/ws") else url + "/ws"


async def _send_and_receive(ws, message: dict, timeout: float, label: str) -> dict:
    """Send one JSON message on *ws* and return the decoded JSON reply.

    Raises:
        RuntimeError: if the reply has ``"type": "error"``.
        asyncio.TimeoutError: if no reply arrives within *timeout* seconds.
    """
    await ws.send(json.dumps(message))
    response = json.loads(await asyncio.wait_for(ws.recv(), timeout))
    if response.get("type") == "error":
        raise RuntimeError(f"{label} error: {response}")
    return response


async def ws_session(
    ws_url: str,
    request_id: int,
    wait_seconds: float,
    timeout: float = 60.0,
) -> RequestResult:
    """
    Run a complete WebSocket session: connect, reset, step, close.

    Each session gets its own environment instance on the server.

    Args:
        ws_url: WebSocket endpoint URL (see ``convert_to_ws_url``).
        request_id: Identifier recorded in the returned result.
        wait_seconds: Value sent as ``data.wait_seconds`` in the step message.
        timeout: Seconds allowed for the opening handshake and for each reply.

    Returns:
        A ``RequestResult`` built from the step reply's ``data.observation``;
        fields missing from the reply fall back to zero/empty defaults.

    Raises:
        RuntimeError: if the server answers reset or step with an error message.
    """
    start = time.perf_counter()
    async with websockets.connect(ws_url, open_timeout=timeout) as ws:
        # Reset to initialize this session's environment.
        await _send_and_receive(ws, {"type": "reset", "data": {}}, timeout, "Reset")

        # Step with the requested wait time.
        step_response = await _send_and_receive(
            ws,
            {"type": "step", "data": {"wait_seconds": wait_seconds}},
            timeout,
            "Step",
        )

        # Close cleanly.
        await ws.send(json.dumps({"type": "close"}))
        elapsed = time.perf_counter() - start

    # `or {}` also tolerates explicit nulls, which `.get(key, {})` would return as-is.
    obs = (step_response.get("data") or {}).get("observation") or {}
    return RequestResult(
        request_id=request_id,
        wait_requested=wait_seconds,
        waited_seconds=obs.get("waited_seconds", 0.0),
        elapsed=elapsed,
        pid=obs.get("pid", 0),
        session_hash=obs.get("session_hash", ""),
        host_url=obs.get("host_url", ""),
    )


async def run_concurrent_test(
    url: str,
    num_requests: int,
    wait_seconds: float,
    timeout: float = 120.0,
) -> dict:
    """Run concurrent WebSocket sessions and collect results.

    Returns:
        A summary dict (counts, timings, distinct PIDs/sessions/hosts, and
        ``errors``: ``repr`` of up to the first five failures), or
        ``{"error": "All requests failed"}`` when no session succeeded.
    """
    ws_url = convert_to_ws_url(url)
    print(f"WebSocket URL: {ws_url}")
    print(f"Running {num_requests} concurrent sessions, each waiting {wait_seconds}s...")
    print()

    start = time.perf_counter()
    # Launch all sessions concurrently. return_exceptions=True collects each
    # failure as a result instead of propagating the first one.
    results = await asyncio.gather(
        *(ws_session(ws_url, i, wait_seconds, timeout) for i in range(num_requests)),
        return_exceptions=True,
    )
    total_time = time.perf_counter() - start

    successful = [r for r in results if isinstance(r, RequestResult)]
    # Every non-result is a captured exception. Filtering on `Exception` would
    # silently drop BaseException results such as asyncio.CancelledError.
    failed = [r for r in results if not isinstance(r, RequestResult)]

    if not successful:
        print("All requests failed!")
        for i, err in enumerate(failed[:5]):
            print(f" Error {i}: {err}")
        return {"error": "All requests failed"}

    avg_time = sum(r.elapsed for r in successful) / len(successful)
    unique_pids = {r.pid for r in successful}
    unique_sessions = {r.session_hash for r in successful}
    unique_hosts = {r.host_url for r in successful}

    return {
        "num_requests": num_requests,
        "successful": len(successful),
        "failed": len(failed),
        "errors": [repr(err) for err in failed[:5]],
        "wait_seconds": wait_seconds,
        "total_time": total_time,
        "avg_time": avg_time,
        "unique_pids": len(unique_pids),
        "unique_sessions": len(unique_sessions),
        "unique_hosts": len(unique_hosts),
        "pids": sorted(unique_pids)[:10],  # lowest 10 PIDs, deterministic for display
    }


async def main() -> int:
    """Parse CLI arguments, run the test, and print a report.

    Returns:
        Process exit status: 1 if every session failed, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Test benchmark environment concurrency via WebSocket"
    )
    parser.add_argument(
        "--requests", "-n", type=int, default=10,
        help="Number of concurrent WebSocket sessions"
    )
    parser.add_argument(
        "--wait", "-w", type=float, default=1.0,
        help="Wait time per request (seconds)"
    )
    parser.add_argument(
        "--url", "-u", type=str, default="http://localhost:8000",
        help="Server URL (http/https/ws/wss)"
    )
    parser.add_argument(
        "--timeout", "-t", type=float, default=120.0,
        help="Timeout per request (seconds)"
    )
    args = parser.parse_args()
    if args.requests < 1:
        parser.error("--requests must be at least 1")
    if args.timeout <= 0:
        parser.error("--timeout must be positive")

    result = await run_concurrent_test(
        args.url, args.requests, args.wait, args.timeout
    )
    if "error" in result:
        return 1

    print(f"Successful: {result['successful']}/{result['num_requests']}")
    if result["failed"]:
        print(f"Failed: {result['failed']}")
        for i, err in enumerate(result["errors"]):
            print(f" Error {i}: {err}")
    print(f"Total time: {result['total_time']:.3f}s")
    print(f"Avg time: {result['avg_time']:.3f}s")
    print(f"Unique PIDs: {result['unique_pids']}")
    print(f"Unique sessions: {result['unique_sessions']}")
    print(f"Unique hosts: {result['unique_hosts']}")

    # Calculate concurrency metrics.
    ideal_time = args.wait
    # Only successful sessions performed their wait, so count those; using all
    # requested sessions would overstate concurrency when some failed.
    actual_concurrency = (result["successful"] * args.wait) / result["total_time"]
    print()
    print(f"Ideal time (full concurrency): {ideal_time:.3f}s")
    print(f"Effective concurrency: {actual_concurrency:.1f}x")
    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))