Spaces:
Sleeping
Sleeping
| import os | |
| from enum import Enum | |
| import typer | |
| from openai import OpenAI | |
| from jssp_openenv.client import JSSPEnvClient | |
| from jssp_openenv.gantt import gantt_chart | |
| from jssp_openenv.policy import JSSPEnvPolicy, JSSPFifoPolicy, JSSPLLMPolicy, JSSPMaxMinPolicy | |
| from jssp_openenv.solver import solve_jssp | |
| SERVER_URL = "http://localhost:8000" | |
| MAX_STEPS = 1000 # Maximum number of steps per instance | |
| OUTPUT_DIR = "output" | |
| os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| cli = typer.Typer() | |
| class PolicyName(str, Enum): | |
| FIFO = "fifo" | |
| LLM = "llm" | |
| MAX_MIN = "maxmin" | |
| def solve( | |
| policy: PolicyName = typer.Argument(help="The policy to use"), | |
| server_url: str = typer.Option(SERVER_URL, help="The URL of the JSSP server"), | |
| max_steps: int = typer.Option(MAX_STEPS, help="The maximum number of steps per instance"), | |
| verbose: bool = typer.Option(False, "--verbose", "-v", help="Whether to print verbose output"), | |
| model_id: str = typer.Option(None, "--model-id", "-m", help="The ID of the model to use"), | |
| ): | |
| """Solve a JSSP instance using the given policy.""" | |
| env_client = JSSPEnvClient(base_url=server_url) | |
| policy_obj: JSSPEnvPolicy | |
| match policy: | |
| case PolicyName.FIFO: | |
| policy_obj = JSSPFifoPolicy() | |
| title = "FIFO Policy" | |
| filename = "gantt_fifo_policy.png" | |
| case PolicyName.LLM: | |
| if not model_id: | |
| raise ValueError("You must set --model-id to use the LLM policy") | |
| api_key = os.getenv("HF_TOKEN") | |
| if not api_key: | |
| raise ValueError("You must set the HF_TOKEN environment variable to use the LLM policy") | |
| client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=api_key) | |
| policy_obj = JSSPLLMPolicy(client=client, model_id=model_id) | |
| title = f"LLM Policy ({model_id})" | |
| filename = f"gantt_llm_policy_{model_id.replace('/', '_').replace(':', '_').replace('-', '_').replace(' ', '_')}.png" | |
| case PolicyName.MAX_MIN: | |
| policy_obj = JSSPMaxMinPolicy() | |
| title = "Max-Min Policy" | |
| filename = "gantt_maxmin_policy.png" | |
| makespan, scheduled_events = solve_jssp(env_client, policy_obj, max_steps, verbose) | |
| if verbose: | |
| print("Schedule events:") | |
| for event in scheduled_events: | |
| print( | |
| f"[{event.start_time}] Scheduling job {event.job_id} on machine {event.machine_id} for {event.end_time - event.start_time} minute(s)" | |
| ) | |
| print(f"Solved in {makespan} steps") | |
| filepath = os.path.join(OUTPUT_DIR, filename) | |
| gantt_chart(scheduled_events, title=title, makespan=makespan, save_to=filepath) | |
| print(f"Saved Gantt chart to {filepath}") | |
| if __name__ == "__main__": | |
| cli() | |