Spaces:
Sleeping
Sleeping
| """Test that progress updates are properly isolated between WebSocket clients.""" | |
| import json | |
| import pytest | |
| import time | |
| import threading | |
| import uuid | |
| import websocket | |
| from typing import List, Dict, Any | |
| from comfy_execution.graph_utils import GraphBuilder | |
| from tests.execution.test_execution import ComfyClient | |
class ProgressTracker:
    """Collect the ``progress_state`` messages delivered to one WebSocket client.

    Each public method holds ``self.lock`` while it touches the message list,
    so one thread can record messages while another queries them.
    """

    def __init__(self, client_id: str):
        # Identifier of the client whose messages this tracker stores.
        self.client_id = client_id
        # Messages in the order add_message() received them.
        self.progress_messages: List[Dict[str, Any]] = []
        # Serializes access to progress_messages.
        self.lock = threading.Lock()

    @staticmethod
    def _prompt_id_of(msg: Dict[str, Any]):
        """Return ``msg['data']['prompt_id']``, or None when either key is missing."""
        payload = msg.get('data', {})
        return payload.get('prompt_id')

    def add_message(self, message: Dict[str, Any]):
        """Record ``message`` while holding the lock."""
        with self.lock:
            self.progress_messages.append(message)

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Return a new list of the stored messages whose prompt_id equals ``prompt_id``."""
        with self.lock:
            return [
                candidate
                for candidate in self.progress_messages
                if self._prompt_id_of(candidate) == prompt_id
            ]

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Report whether any stored message carries a prompt_id other than ``own_prompt_id``.

        Messages without a (truthy) prompt_id are ignored.
        """
        with self.lock:
            return any(
                bool(pid) and pid != own_prompt_id
                for pid in map(self._prompt_id_of, self.progress_messages)
            )
class IsolatedClient(ComfyClient):
    """ComfyClient that records the WebSocket messages it receives.

    Every decoded JSON text frame is appended to ``all_messages``; frames whose
    ``type`` is ``'progress_state'`` are also handed to a per-client
    ``ProgressTracker`` that ``connect`` creates.
    """

    def __init__(self):
        super().__init__()
        # Created in connect(), once the client_id is known.
        self.progress_tracker = None
        # Every decoded JSON text frame, in arrival order.
        self.all_messages: List[Dict[str, Any]] = []
        # Exception that ended the latest listen_for_messages() call early, or None.
        self.listen_error = None

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect with ``client_id`` and set up progress tracking for it.

        Args:
            listen: Server host, passed through to ``ComfyClient.connect``.
            port: Server port, passed through to ``ComfyClient.connect``.
            client_id: Identifier to connect with; a fresh UUID4 string is
                generated when it is None.
        """
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Receive and record WebSocket messages for about ``duration`` seconds.

        Requires a prior ``connect`` (uses ``self.ws`` — presumably set by
        ``ComfyClient.connect`` — and ``self.progress_tracker``). Text frames
        are JSON-decoded and stored; non-text frames are ignored. A receive
        timeout only re-checks the deadline. Any other error stops listening
        early (best effort, messages gathered so far are kept) and is stored in
        ``self.listen_error`` instead of being discarded.

        Args:
            duration: Length of the listening window in seconds.
        """
        self.listen_error = None
        # monotonic() cannot jump with wall-clock adjustments, unlike time.time().
        deadline = time.monotonic() + duration
        # Short receive timeout so the loop regularly re-checks the deadline.
        self.ws.settimeout(0.5)
        while time.monotonic() < deadline:
            try:
                out = self.ws.recv()
                if isinstance(out, str):
                    message = json.loads(out)
                    self.all_messages.append(message)
                    # Only progress_state messages feed the isolation checks.
                    if message.get('type') == 'progress_state':
                        self.progress_tracker.add_message(message)
            except websocket.WebSocketTimeoutException:
                continue
            except Exception as exc:
                # Best effort: stop listening but keep the cause for the caller.
                self.listen_error = exc
                break
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    # autouse: every test in the class runs against this server;
    # scope="class": the server is started once for the whole class.
    @pytest.fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Run a ComfyUI server subprocess for the duration of the test class.

        ``args_pytest`` is a fixture defined outside this class (presumably in
        conftest) and must provide the ``output_dir``, ``listen`` and ``port``
        keys.
        """
        import subprocess
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        try:
            yield
        finally:
            p.kill()
            # Reap the killed child so it does not linger as a zombie process.
            p.wait()

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Create an IsolatedClient and connect it, retrying while the server starts.

        Sleeps 4 seconds before each of up to 5 connection attempts.

        Args:
            listen: Server host.
            port: Server port.
            client_id: Client identifier; None lets ``IsolatedClient.connect``
                generate one.

        Returns:
            The connected IsolatedClient.

        Raises:
            ConnectionRefusedError: if every attempt is refused (chained to the
                last refusal).
        """
        client = IsolatedClient()
        n_tries = 5
        last_error = None
        for i in range(n_tries):
            # Give the freshly launched server time to start accepting connections.
            time.sleep(4)
            try:
                client.connect(listen, port, client_id)
                return client
            except ConnectionRefusedError as e:
                last_error = e
                print(e)  # noqa: T201
                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") from last_error

    @staticmethod
    def _stub_preview_prompt(prefix: str, content: str, size: int):
        """Build and finalize a StubImage -> PreviewImage workflow of ``size`` x ``size``."""
        graph = GraphBuilder(prefix=prefix)
        image = graph.node("StubImage", content=content, height=size, width=size, batch_size=1)
        graph.node("PreviewImage", images=image.out(0))
        return graph.finalize()

    @staticmethod
    def _close_client(client):
        """Close ``client``'s WebSocket if the client exists and has one."""
        ws = getattr(client, 'ws', None)
        if ws is not None:
            ws.close()

    def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients.

        Two clients with distinct client_ids each queue their own workflow.
        Each must receive progress_state messages for its own prompt_id and
        none for the other client's prompt_id.
        """
        listen = args_pytest["listen"]
        port = args_pytest["port"]
        client_a_id = "client_a_" + str(uuid.uuid4())
        client_b_id = "client_b_" + str(uuid.uuid4())
        # Bound before the try so the finally block never hits an unbound name
        # (which would mask the real error) when connecting fails.
        client_a = None
        client_b = None
        try:
            client_a = self.start_client_with_retry(listen, port, client_a_id)
            client_b = self.start_client_with_retry(listen, port, client_b_id)

            prompt_a = self._stub_preview_prompt("client_a", "BLACK", 256)
            prompt_b = self._stub_preview_prompt("client_b", "WHITE", 256)
            prompt_id_a = client_a.queue_prompt(prompt_a)['prompt_id']
            prompt_id_b = client_b.queue_prompt(prompt_b)['prompt_id']

            # Listen on both sockets concurrently.
            listeners = [
                threading.Thread(target=client.listen_for_messages, kwargs={'duration': 10.0})
                for client in (client_a, client_b)
            ]
            for thread in listeners:
                thread.start()
            for thread in listeners:
                thread.join()

            tracker_a = client_a.progress_tracker
            tracker_b = client_b.progress_tracker

            # Each client may only see progress for its own prompt.
            assert not tracker_a.has_cross_contamination(prompt_id_a), \
                f"Client A received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_a}, but got messages for multiple prompts."
            assert not tracker_b.has_cross_contamination(prompt_id_b), \
                f"Client B received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_b}, but got messages for multiple prompts."

            # Each client must see progress for its own prompt.
            client_a_messages = tracker_a.get_messages_for_prompt(prompt_id_a)
            client_b_messages = tracker_b.get_messages_for_prompt(prompt_id_b)
            assert len(client_a_messages) > 0, \
                "Client A did not receive any progress updates for its own workflow"
            assert len(client_b_messages) > 0, \
                "Client B did not receive any progress updates for its own workflow"

            # Explicitly check the other client's prompt never leaked through.
            client_a_other = tracker_a.get_messages_for_prompt(prompt_id_b)
            client_b_other = tracker_b.get_messages_for_prompt(prompt_id_a)
            assert len(client_a_other) == 0, \
                f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
            assert len(client_b_other) == 0, \
                f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"
        finally:
            self._close_client(client_a)
            self._close_client(client_b)

    def test_progress_with_missing_client_id(self, args_pytest):
        """Test that a client connected without an explicit client_id still gets its progress.

        ``start_client_with_retry`` is called without a client_id, so
        ``IsolatedClient.connect`` generates one.
        """
        listen = args_pytest["listen"]
        port = args_pytest["port"]
        # Bound before the try so the finally block is safe if connecting fails.
        client = None
        try:
            client = self.start_client_with_retry(listen, port)
            prompt = self._stub_preview_prompt("test_missing_id", "BLACK", 128)
            prompt_id = client.queue_prompt(prompt)['prompt_id']

            client.listen_for_messages(duration=5.0)

            messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
            assert len(messages) > 0, \
                "Client did not receive progress updates even though it initiated the workflow"
        finally:
            self._close_client(client)