File size: 4,334 Bytes
6aaa5e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import copy
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, List
@dataclass
class StepResult:
    """Outcome of a single ``FlowDebugEnv.step`` call."""

    # Observation dict describing the case after the action was processed.
    obs: Dict[str, Any]
    # Scalar reward assigned to the action.
    reward: float
    # True once the episode has ended (solved or out of attempts).
    done: bool
    # Extra diagnostics such as the result label and the case id.
    info: Dict[str, Any]
class FlowDebugEnv:
    """A simple single-action debugging environment made for OpenEnv.

    Each episode presents one failed flow run (a "case") as a JSON-like
    observation dict. The only supported action is ``patch_step``, which
    overwrites ``inputs.expression`` on a named step. The episode is solved
    when the patch's step, field and value all exactly equal the case's
    ``gold_fix``.

    A case is a dict with the keys ``case_id``, ``error``, ``failed_step``,
    ``steps`` (a list of step dicts, each with a ``name``) and ``gold_fix``
    (a dict with ``step``, ``field`` and ``value``).

    Every call to ``step`` consumes one attempt, including invalid actions.

    Rewards:
        * ``1.0``: exact fix; the episode ends.
        * ``-0.1``: invalid action (the episode ends if that was the last
          attempt), or a valid but wrong patch with attempts remaining.
        * ``-0.2``: valid but wrong patch that used up the last attempt.
    """

    # Private sentinel meaning "use the value stored on the current case".
    # A dedicated object is required because ``None`` is a real value here:
    # a successful run reports no error and no failed step.
    _KEEP = object()

    def __init__(self, cases: List[Dict[str, Any]], max_attempts: int = 3, seed: Optional[int] = None):
        """Create an environment over ``cases``.

        Args:
            cases: Case dicts to sample from on each ``reset``.
            max_attempts: Number of ``step`` calls allowed per episode.
            seed: Seed for the private random generator that picks cases.
        """
        self.cases = cases
        self.max_attempts = max_attempts
        self.rng = random.Random(seed)
        self.current_case: Optional[Dict[str, Any]] = None
        self.attempts_left = max_attempts
        # Set once an episode has ended so further steps are rejected.
        self._done = False

    @classmethod
    def from_json(cls, cases_json_path: str, max_attempts: int = 3, seed: Optional[int] = None):
        """Build an environment from a JSON file holding a list of cases.

        Raises:
            ValueError: If the file's top-level JSON value is not a list.
        """
        path = Path(cases_json_path)
        with open(path, "r", encoding="utf-8") as f:
            cases = json.load(f)
        if not isinstance(cases, list):
            # Fail here rather than with an obscure error inside reset().
            raise ValueError(
                f"Expected a JSON list of cases in {path}, got {type(cases).__name__}."
            )
        return cls(cases=cases, max_attempts=max_attempts, seed=seed)

    def reset(self) -> Dict[str, Any]:
        """Start a new episode on a randomly chosen case.

        Returns:
            The initial observation.

        Raises:
            IndexError: If the environment has no cases to choose from.
        """
        if not self.cases:
            # Same exception type random.choice raises on an empty sequence,
            # but with a message that names the actual problem.
            raise IndexError("Cannot reset: the environment has no cases.")
        # Work on a private copy so patches never touch the case library.
        self.current_case = copy.deepcopy(self.rng.choice(self.cases))
        self.attempts_left = self.max_attempts
        self._done = False
        return self._make_observation()

    def step(self, action: Dict[str, Any]) -> StepResult:
        """Apply one action and report the outcome.

        Args:
            action: Expected shape is ``{"action": "patch_step", "step": ...,
                "field": "inputs.expression", "value": ...}``. Anything else
                is reported as an invalid action.

        Returns:
            A ``StepResult`` carrying the new observation, reward, done flag
            and an ``info`` dict with ``result`` and ``case_id``.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has already ended.
        """
        if self.current_case is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            # Without this guard attempts_left would go negative and a solved
            # episode could keep handing out rewards.
            raise RuntimeError("Episode is finished. Call reset() before step().")

        # Every call costs an attempt, valid or not.
        self.attempts_left -= 1

        # Tolerate non-mapping actions (e.g. None) instead of crashing.
        getter = getattr(action, "get", None)
        if not callable(getter) or getter("action") != "patch_step":
            return self._invalid_action("Unsupported action type")

        step_name = getter("step")
        field = getter("field")
        value = getter("value")

        if not self._apply_patch(step_name, field, value):
            return self._invalid_action("Patch failed (step/field not found)")

        case_id = self.current_case["case_id"]
        gold = self.current_case["gold_fix"]
        solved = (
            step_name == gold["step"]
            and field == gold["field"]
            and value == gold["value"]
        )

        if solved:
            self._mark_success()
            self._done = True
            obs = self._make_observation(run_status="Succeeded", error=None, failed_step=None)
            return StepResult(obs=obs, reward=1.0, done=True,
                              info={"result": "success", "case_id": case_id})

        if self.attempts_left <= 0:
            self._done = True
            return StepResult(obs=self._make_observation(), reward=-0.2, done=True,
                              info={"result": "out_of_attempts", "case_id": case_id})

        return StepResult(obs=self._make_observation(), reward=-0.1, done=False,
                          info={"result": "still_failed", "case_id": case_id})

    # --------- helpers ----------
    def _apply_patch(self, step_name: Any, field: Any, value: Any) -> bool:
        """Write ``value`` into ``inputs.expression`` of the first step named ``step_name``.

        Returns:
            True if a step was patched; False if the field is unsupported or
            no step has that name.
        """
        # Only this one field is patchable, so check it once up front.
        if field != "inputs.expression":
            return False
        for step in self.current_case["steps"]:
            # .get so a step without a "name" key is skipped, not a KeyError.
            if step.get("name") == step_name:
                step.setdefault("inputs", {})["expression"] = value
                return True
        return False

    def _mark_success(self):
        """Mark every step of the current case as ``"Succeeded"``."""
        for step in self.current_case["steps"]:
            step["status"] = "Succeeded"

    def _make_observation(self, run_status="Failed", error=_KEEP, failed_step=_KEEP):
        """Build an observation snapshot of the current case.

        Args:
            run_status: Value reported under ``run_status``.
            error: Error to report; by default the case's own ``error``.
            failed_step: Failed step to report; by default the case's own
                ``failed_step``.

        Returns:
            A dict with ``case_id``, ``run_status``, ``failed_step``,
            ``error``, ``steps`` and ``attempts_left``. Mutable parts are deep
            copies, so later patches do not rewrite earlier observations and
            callers cannot corrupt the environment through them.
        """
        case = self.current_case
        err_obj = case["error"] if error is self._KEEP else error
        failed = case["failed_step"] if failed_step is self._KEEP else failed_step
        return {
            "case_id": case["case_id"],
            "run_status": run_status,
            "failed_step": failed,
            "error": copy.deepcopy(err_obj),
            "steps": copy.deepcopy(case["steps"]),
            "attempts_left": self.attempts_left,
        }

    def _invalid_action(self, msg: str) -> StepResult:
        """Build the result for a rejected action.

        The reward is ``-0.1``. The episode ends if no attempts remain.
        """
        self._done = self.attempts_left <= 0
        return StepResult(
            obs=self._make_observation(),
            reward=-0.1,
            done=self._done,
            info={"result": "invalid_action", "message": msg, "case_id": self.current_case["case_id"]},
        )
|