File size: 4,334 Bytes
6aaa5e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import copy
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, List


@dataclass
class StepResult:
    """Bundle of values handed back by one environment step."""

    # Observation dictionary describing the environment after the step.
    obs: Dict[str, Any]
    # Scalar reward earned by the step.
    reward: float
    # True once the episode has ended.
    done: bool
    # Extra diagnostic details about the step outcome.
    info: Dict[str, Any]


class FlowDebugEnv:
    """
    Simple single-action debugging environment made for OpenEnv.

    Each episode presents one failed run ("case") as a JSON-like observation.
    The only supported action patches the ``inputs.expression`` field of a
    named step::

        {"action": "patch_step", "step": <step name>,
         "field": "inputs.expression", "value": <new expression>}

    The cases are described as targeting a 'Condition_Check' step, but any
    step name present in the current case is accepted by the patch logic.

    An episode is solved when the action's step, field and value are all
    exactly equal to the case's ``gold_fix`` entry.

    Rewards:
    - ``1.0``  solved (episode ends).
    - ``-0.1`` invalid action, or a wrong patch with attempts remaining.
    - ``-0.2`` wrong patch that uses up the last attempt (episode ends).

    Every call to ``step()`` consumes one attempt, including invalid actions.

    NOTE: ``step()`` does not refuse calls made after an episode has ended;
    callers should call ``reset()`` once ``done`` is True.
    """

    def __init__(self, cases: List[Dict[str, Any]], max_attempts: int = 3, seed: Optional[int] = None):
        """Store the case pool and create a seeded random generator.

        Args:
            cases: Case dictionaries. The code reads the keys ``case_id``,
                ``steps``, ``error``, ``failed_step`` and ``gold_fix``.
            max_attempts: Number of ``step()`` calls allowed per episode.
            seed: Seed for the private ``random.Random`` used to pick cases.
        """
        self.cases = cases
        self.max_attempts = max_attempts
        self.rng = random.Random(seed)
        self.current_case: Optional[Dict[str, Any]] = None
        self.attempts_left = max_attempts

    @classmethod
    def from_json(cls, cases_json_path: str, max_attempts: int = 3, seed: Optional[int] = None):
        """Build an environment from a UTF-8 JSON file containing the cases.

        Args:
            cases_json_path: Path to the JSON file.
            max_attempts: Forwarded to the constructor.
            seed: Forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        path = Path(cases_json_path)
        with open(path, "r", encoding="utf-8") as f:
            cases = json.load(f)
        return cls(cases=cases, max_attempts=max_attempts, seed=seed)

    def reset(self) -> Dict[str, Any]:
        """Start a new episode on a randomly chosen case.

        The chosen case is deep-copied so that patches never modify the
        case pool held in ``self.cases``.

        Returns:
            The initial observation.

        Raises:
            IndexError: If the case pool is empty (raised by ``rng.choice``).
        """
        self.current_case = copy.deepcopy(self.rng.choice(self.cases))
        self.attempts_left = self.max_attempts
        return self._make_observation()

    def step(self, action: Dict[str, Any]) -> StepResult:
        """Apply one action and report the outcome.

        Args:
            action: Mapping with the keys ``action``, ``step``, ``field`` and
                ``value``. Anything that is not a mapping, or whose
                ``action`` is not ``"patch_step"``, is an invalid action.

        Returns:
            A ``StepResult`` whose ``info["result"]`` is one of ``"success"``,
            ``"still_failed"``, ``"out_of_attempts"`` or ``"invalid_action"``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self.current_case is None:
            raise RuntimeError("Call reset() before step().")

        # Every call costs an attempt, valid or not.
        self.attempts_left -= 1

        # A non-mapping action (None, str, list, ...) has no usable ``get``;
        # report it as invalid instead of failing with AttributeError.
        getter = getattr(action, "get", None)
        if not callable(getter) or getter("action") != "patch_step":
            return self._invalid_action("Unsupported action type")

        step_name = getter("step")
        field = getter("field")
        value = getter("value")

        if not self._apply_patch(step_name, field, value):
            return self._invalid_action("Patch failed (step/field not found)")

        case_id = self.current_case["case_id"]
        gold = self.current_case["gold_fix"]
        solved = (step_name == gold["step"] and field == gold["field"] and value == gold["value"])

        if solved:
            self._mark_success()
            obs = self._make_observation(run_status="Succeeded", error=None, failed_step=None)
            return StepResult(obs=obs, reward=1.0, done=True,
                              info={"result": "success", "case_id": case_id})

        obs = self._make_observation()
        if self.attempts_left <= 0:
            return StepResult(obs=obs, reward=-0.2, done=True,
                              info={"result": "out_of_attempts", "case_id": case_id})

        return StepResult(obs=obs, reward=-0.1, done=False,
                          info={"result": "still_failed", "case_id": case_id})

    # --------- helpers ----------
    def _apply_patch(self, step_name: str, field: str, value: str) -> bool:
        """Write ``value`` into ``inputs.expression`` of the first matching step.

        Returns:
            True if a step named ``step_name`` exists and ``field`` is
            ``"inputs.expression"``; False otherwise (nothing is modified).
        """
        if field != "inputs.expression":
            return False
        for step in self.current_case["steps"]:
            # ``.get`` so a step entry without a name is skipped, not fatal.
            if step.get("name") == step_name:
                step.setdefault("inputs", {})["expression"] = value
                return True
        return False

    def _mark_success(self):
        """Set the status of every step in the current case to "Succeeded"."""
        for step in self.current_case["steps"]:
            step["status"] = "Succeeded"

    def _make_observation(self, run_status="Failed", error="keep", failed_step="keep"):
        """Build a snapshot observation of the current case.

        Args:
            run_status: Value reported under ``run_status``.
            error: Value reported under ``error``; the string ``"keep"``
                means "use the case's own error".
            failed_step: Value reported under ``failed_step``; the string
                ``"keep"`` means "use the case's own failed step".

        Returns:
            A dictionary that shares no mutable state with the environment,
            so later patches do not alter it and caller-side edits do not
            leak back into the episode.
        """
        case = self.current_case
        err_obj = case["error"] if error == "keep" else error
        failed = case["failed_step"] if failed_step == "keep" else failed_step

        return {
            "case_id": case["case_id"],
            "run_status": run_status,
            "failed_step": failed,
            # Deep copies: handing out the live objects would let earlier
            # observations change retroactively when a patch is applied.
            "error": copy.deepcopy(err_obj),
            "steps": copy.deepcopy(case["steps"]),
            "attempts_left": self.attempts_left
        }

    def _invalid_action(self, msg: str) -> StepResult:
        """Build the result for an invalid action.

        The reward is always -0.1; the episode ends only if no attempts
        remain.
        """
        obs = self._make_observation()
        done = (self.attempts_left <= 0)
        return StepResult(obs=obs, reward=-0.1, done=done,
                          info={"result": "invalid_action", "message": msg, "case_id": self.current_case["case_id"]})