jssp_openenv / run.py
Wauplin's picture
Wauplin HF Staff
there was a bug...
6e3f176 verified
import os
from enum import Enum
import typer
from openai import OpenAI
from jssp_openenv.client import JSSPEnvClient
from jssp_openenv.gantt import gantt_chart
from jssp_openenv.policy import JSSPEnvPolicy, JSSPFifoPolicy, JSSPLLMPolicy, JSSPMaxMinPolicy
from jssp_openenv.solver import solve_jssp
SERVER_URL = "http://localhost:8000"
MAX_STEPS = 1000 # Maximum number of steps per instance
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
cli = typer.Typer()
class PolicyName(str, Enum):
FIFO = "fifo"
LLM = "llm"
MAX_MIN = "maxmin"
@cli.command()
def solve(
policy: PolicyName = typer.Argument(help="The policy to use"),
server_url: str = typer.Option(SERVER_URL, help="The URL of the JSSP server"),
max_steps: int = typer.Option(MAX_STEPS, help="The maximum number of steps per instance"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Whether to print verbose output"),
model_id: str = typer.Option(None, "--model-id", "-m", help="The ID of the model to use"),
):
"""Solve a JSSP instance using the given policy."""
env_client = JSSPEnvClient(base_url=server_url)
policy_obj: JSSPEnvPolicy
match policy:
case PolicyName.FIFO:
policy_obj = JSSPFifoPolicy()
title = "FIFO Policy"
filename = "gantt_fifo_policy.png"
case PolicyName.LLM:
if not model_id:
raise ValueError("You must set --model-id to use the LLM policy")
api_key = os.getenv("HF_TOKEN")
if not api_key:
raise ValueError("You must set the HF_TOKEN environment variable to use the LLM policy")
client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=api_key)
policy_obj = JSSPLLMPolicy(client=client, model_id=model_id)
title = f"LLM Policy ({model_id})"
filename = f"gantt_llm_policy_{model_id.replace('/', '_').replace(':', '_').replace('-', '_').replace(' ', '_')}.png"
case PolicyName.MAX_MIN:
policy_obj = JSSPMaxMinPolicy()
title = "Max-Min Policy"
filename = "gantt_maxmin_policy.png"
makespan, scheduled_events = solve_jssp(env_client, policy_obj, max_steps, verbose)
if verbose:
print("Schedule events:")
for event in scheduled_events:
print(
f"[{event.start_time}] Scheduling job {event.job_id} on machine {event.machine_id} for {event.end_time - event.start_time} minute(s)"
)
print(f"Solved in {makespan} steps")
filepath = os.path.join(OUTPUT_DIR, filename)
gantt_chart(scheduled_events, title=title, makespan=makespan, save_to=filepath)
print(f"Saved Gantt chart to {filepath}")
if __name__ == "__main__":
cli()