| | from __future__ import annotations |
| |
|
| | import argparse |
| | import logging |
| | import os |
| | from typing import Any, Dict, Optional |
| |
|
| | from tqdm import tqdm |
| |
|
| | from edgeeda.config import load_config, Config |
| | from edgeeda.utils import seed_everything, ensure_dir |
| | from edgeeda.store import TrialStore, TrialRecord |
| | from edgeeda.orfs.runner import ORFSRunner |
| | from edgeeda.orfs.metrics import find_best_metadata_json, load_json |
| | from edgeeda.reward import compute_reward |
| | from edgeeda.viz import export_trials, make_plots |
| |
|
| | from edgeeda.agents.random_search import RandomSearchAgent |
| | from edgeeda.agents.successive_halving import SuccessiveHalvingAgent |
| | from edgeeda.agents.surrogate_ucb import SurrogateUCBAgent |
| |
|
| |
|
# Registry mapping the `tuning.agent` config value to its agent class.
AGENTS = {
    "random": RandomSearchAgent,
    "successive_halving": SuccessiveHalvingAgent,
    "surrogate_ucb": SurrogateUCBAgent,
}
| |
|
| |
|
def _select_agent(cfg: Config):
    """Instantiate the tuning agent named by ``cfg.tuning.agent``.

    Raises:
        ValueError: if the configured agent name is not in the registry.
    """
    agent_name = cfg.tuning.agent
    try:
        agent_cls = AGENTS[agent_name]
    except KeyError:
        raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENTS.keys())}")
    return agent_cls(cfg)
| |
|
| |
|
def _setup_logging(cfg: Config) -> None:
    """Configure root logging to write to both a log file and the console."""
    out_dir = cfg.experiment.out_dir
    ensure_dir(out_dir)
    log_path = os.path.join(out_dir, "tuning.log")

    # One handler per destination: persistent file plus stream (console).
    file_handler = logging.FileHandler(log_path)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    logging.info(f"Logging initialized. Log file: {log_path}")
| |
|
| |
|
def cmd_tune(args: argparse.Namespace) -> None:
    """Run the agentic tuning loop: propose actions, run ORFS make targets,
    score results, persist trial records, and export a summary CSV.

    Args:
        args: parsed CLI namespace with ``config``, ``budget``, ``timeout``.

    Raises:
        RuntimeError: if no ORFS flow directory is configured or exported.
    """
    cfg = load_config(args.config)
    # CLI --budget overrides the configured total action count.
    if args.budget is not None:
        cfg.tuning.budget.total_actions = int(args.budget)

    seed_everything(cfg.experiment.seed)
    ensure_dir(cfg.experiment.out_dir)
    _setup_logging(cfg)

    logging.info(f"Starting tuning experiment: {cfg.experiment.name}")
    logging.info(f"Agent: {cfg.tuning.agent}, Budget: {cfg.tuning.budget.total_actions} actions")
    logging.info(f"Platform: {cfg.design.platform}, Design: {cfg.design.design}")

    # Config value wins; fall back to the ORFS_FLOW_DIR environment variable.
    orfs_flow_dir = cfg.experiment.orfs_flow_dir or os.environ.get("ORFS_FLOW_DIR")
    if not orfs_flow_dir:
        raise RuntimeError("ORFS flow dir missing. Set experiment.orfs_flow_dir or export ORFS_FLOW_DIR=/path/to/ORFS/flow")

    logging.info(f"ORFS flow directory: {orfs_flow_dir}")

    runner = ORFSRunner(orfs_flow_dir)
    store = TrialStore(cfg.experiment.db_path)
    agent = _select_agent(cfg)

    # The last configured fidelity is treated as the "expensive" tier whose
    # usage is capped by tuning.budget.max_expensive.
    expensive_set = set(cfg.flow.fidelities[-1:])
    expensive_used = 0

    for i in tqdm(range(cfg.tuning.budget.total_actions), desc="actions"):
        action = agent.propose()
        fidelity = action.fidelity

        # Downgrade to the cheapest fidelity once the expensive budget is spent.
        if fidelity in expensive_set and expensive_used >= cfg.tuning.budget.max_expensive:
            fidelity = cfg.flow.fidelities[0]
            # Rebuild the action with the downgraded fidelity (action objects
            # appear immutable — reconstructed via their own type).
            action = type(action)(variant=action.variant, fidelity=fidelity, knobs=action.knobs)

        # Map the fidelity name to a make target; defaults to the name itself.
        make_target = cfg.flow.targets.get(fidelity, fidelity)
        logging.info(f"Action {i+1}/{cfg.tuning.budget.total_actions}: variant={action.variant}, "
                     f"fidelity={action.fidelity}, knobs={action.knobs}")

        logging.debug(f"Running: {make_target} for variant {action.variant}")
        # Knob values are stringified because they become make variable overrides.
        rr = runner.run_make(
            target=make_target,
            design_config=cfg.design.design_config,
            flow_variant=action.variant,
            overrides={k: str(v) for k, v in action.knobs.items()},
            timeout_sec=args.timeout,
        )

        ok = (rr.return_code == 0)
        if not ok:
            logging.warning(f"Trial {i+1} failed: variant={action.variant}, return_code={rr.return_code}")
            logging.debug(f"Command: {rr.cmd}")
            if rr.stderr:
                logging.debug(f"Stderr (last 500 chars): {rr.stderr[-500:]}")
        else:
            logging.info(f"Trial {i+1} succeeded: variant={action.variant}, runtime={rr.runtime_sec:.2f}s")

        # Expensive budget is consumed even when the run fails.
        if fidelity in expensive_set:
            expensive_used += 1

        # Resolve the metadata make target, accepting several config spellings;
        # the bare "metadata" placeholder is mapped to "metadata-generate".
        meta_target = (
            cfg.flow.targets.get("metadata_generate")
            or cfg.flow.targets.get("metadata-generate")
            or cfg.flow.targets.get("metadata", "metadata")
        )
        if meta_target == "metadata":
            meta_target = "metadata-generate"
        logging.debug(f"Generating metadata for variant {action.variant} using target={meta_target}")
        # Metadata generation is attempted even after a failed trial run.
        meta_result = runner.run_make(
            target=meta_target,
            design_config=cfg.design.design_config,
            flow_variant=action.variant,
            overrides={},
            timeout_sec=args.timeout,
        )
        if meta_result.return_code != 0:
            logging.warning(f"Metadata generation failed for variant {action.variant}: return_code={meta_result.return_code}")

        meta_path = find_best_metadata_json(
            orfs_flow_dir=orfs_flow_dir,
            platform=cfg.design.platform,
            design=cfg.design.design,
            variant=action.variant,
        )

        # reward/flat stay None when metadata is missing or scoring fails.
        reward = None
        flat = None

        if meta_path:
            logging.debug(f"Found metadata at: {meta_path}")
            try:
                mobj = load_json(meta_path)
                reward, comps, flat = compute_reward(
                    metrics_obj=mobj,
                    wns_candidates=cfg.reward.wns_candidates,
                    area_candidates=cfg.reward.area_candidates,
                    power_candidates=cfg.reward.power_candidates,
                    weights=cfg.reward.weights,
                )
                if reward is not None:
                    logging.info(f"Computed reward for variant {action.variant}: {reward:.4f} "
                                 f"(WNS={comps.wns}, area={comps.area}, power={comps.power})")
                else:
                    logging.warning(f"Reward computation returned None for variant {action.variant}")
            except Exception as e:
                # A scoring failure marks the trial as failed even if make succeeded.
                logging.error(f"Failed to compute reward for variant {action.variant}: {e}", exc_info=True)
                ok = False
        else:
            logging.warning(f"Metadata not found for variant {action.variant} at "
                            f"reports/{cfg.design.platform}/{cfg.design.design}/{action.variant}/")

        # Persist every trial (including failures) so they remain analyzable.
        store.add(
            TrialRecord(
                exp_name=cfg.experiment.name,
                platform=cfg.design.platform,
                design=cfg.design.design,
                variant=action.variant,
                fidelity=action.fidelity,
                knobs=action.knobs,
                make_cmd=rr.cmd,
                return_code=rr.return_code,
                runtime_sec=rr.runtime_sec,
                reward=reward,
                metrics=flat,
                metadata_path=meta_path,
            )
        )

        # Feed the outcome back so the agent can adapt its next proposal.
        agent.observe(action, ok=ok, reward=reward, metrics_flat=flat)

    store.close()

    # Export a flat CSV summary of all trials in the experiment database.
    logging.info("Exporting trial summary...")
    df = export_trials(cfg.experiment.db_path)
    out_csv = os.path.join(cfg.experiment.out_dir, "summary.csv")
    df.to_csv(out_csv, index=False)

    total_trials = len(df)
    successful = len(df[df['return_code'] == 0])
    with_rewards = len(df[df['reward'].notna()])
    logging.info(f"Experiment complete: {total_trials} trials, {successful} successful, {with_rewards} with rewards")

    print(f"[done] wrote {out_csv}")
| |
|
| |
|
def cmd_analyze(args: argparse.Namespace) -> None:
    """Export trial records from a SQLite DB to CSV and render plots."""
    trials_df = export_trials(args.db)
    ensure_dir(args.out)
    csv_path = os.path.join(args.out, "trials.csv")
    trials_df.to_csv(csv_path, index=False)
    make_plots(trials_df, args.out)
    print(f"[done] wrote plots to {args.out}")
| |
|
| |
|
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(prog="edgeeda")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # "tune" subcommand: run the agentic tuning loop.
    tune_parser = subparsers.add_parser("tune", help="Run agentic tuning loop on ORFS")
    tune_parser.add_argument("--config", required=True, help="YAML config")
    tune_parser.add_argument("--budget", type=int, default=None, help="Override total_actions")
    tune_parser.add_argument("--timeout", type=int, default=None, help="Timeout per make run (sec)")
    tune_parser.set_defaults(func=cmd_tune)

    # "analyze" subcommand: export results and plots from an existing DB.
    analyze_parser = subparsers.add_parser("analyze", help="Export CSV + plots")
    analyze_parser.add_argument("--db", required=True, help="SQLite db path")
    analyze_parser.add_argument("--out", required=True, help="Output directory for plots")
    analyze_parser.set_defaults(func=cmd_analyze)

    parsed = parser.parse_args()
    parsed.func(parsed)
| |
|
| |
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
| |
|