|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
REPL Environment Implementation. |
|
|
|
|
|
A Python REPL environment for training language models on code execution tasks, |
|
|
based on the Recursive Language Models (RLM) paradigm. |
|
|
|
|
|
References: |
|
|
- RLM Paper: https://arxiv.org/abs/2512.24601 |
|
|
- Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm |
|
|
- Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/ |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
from collections.abc import Callable |
|
|
from typing import Any, Dict, List, Optional |
|
|
from uuid import uuid4 |
|
|
|
|
|
|
|
|
try: |
|
|
from openenv.core.env_server.interfaces import Environment |
|
|
from openenv.core.env_server.types import EnvironmentMetadata |
|
|
except ImportError: |
|
|
from openenv.core.env_server.interfaces import Environment |
|
|
from openenv.core.env_server.types import EnvironmentMetadata |
|
|
|
|
|
try: |
|
|
from ..models import REPLAction, REPLObservation, REPLState, CodeBlockResult |
|
|
except ImportError: |
|
|
from models import REPLAction, REPLObservation, REPLState, CodeBlockResult |
|
|
|
|
|
try: |
|
|
from .python_executor import PythonExecutor |
|
|
except ImportError: |
|
|
from python_executor import PythonExecutor |
|
|
|
|
|
|
|
|
class REPLEnvironment(Environment):
    """
    A REPL environment for training language models to use code execution.

    Based on the Recursive Language Models (RLM) paradigm, this environment allows
    language models to:
    - Execute Python code in a sandboxed REPL
    - Work with large contexts loaded as variables
    - Finalize answers via FINAL(), FINAL_VAR(), or answer dict pattern
    - Optionally make recursive LLM calls via llm_query() / llm_query_batched()

    Supports two finalization patterns:
    1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)')
    2. Prime Intellect style: answer = {"content": "...", "ready": True}

    Example:
        >>> env = REPLEnvironment(context="Hello World", task_prompt="Count chars")
        >>> obs = env.reset()
        >>> print(obs.context_preview)  # "Hello World"
        >>>
        >>> obs = env.step(REPLAction(code="result = len(context)"))
        >>> print(obs.result.success)  # True
        >>> print(obs.available_variables)  # ["context", "result", "answer"]
        >>>
        >>> obs = env.step(REPLAction(code="print(f'FINAL({result})')"))
        >>> print(obs.done)  # True
        >>> print(obs.metadata["final_answer"])  # "11"
    """

    # Class-level capability flag.
    # NOTE(review): presumably read by the hosting server to decide whether
    # several sessions may run concurrently -- confirm against the
    # Environment base class, which is not visible from this file.
    SUPPORTS_CONCURRENT_SESSIONS = True
|
|
|
|
|
def __init__(
    self,
    context: Optional[str] = None,
    task_prompt: Optional[str] = None,
    max_iterations: int = 30,
    max_output_length: int = 8192,
    context_preview_length: int = 500,
    reward_on_success: float = 1.0,
    reward_on_iteration: float = 0.0,
    reward_on_failure: float = -0.1,
    reward_on_error: float = -0.05,
    llm_query_fn: Optional[Callable[[str], str]] = None,
    llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
):
    """Store the configuration for a REPL environment.

    No executor or episode state is created here; both stay ``None`` until
    ``reset()`` is called.

    Args:
        context: Default context for episodes. When falsy, the
            ``REPL_CONTEXT`` environment variable is used (default ``""``).
        task_prompt: Default task description. When falsy, the
            ``REPL_TASK_PROMPT`` environment variable is used (default ``""``).
        max_iterations: Step budget per episode (default 30). If the
            ``REPL_MAX_ITERATIONS`` environment variable is set, it takes
            precedence over this argument and is converted with ``int()``.
        max_output_length: Max chars for stdout/stderr per turn (default 8192).
        context_preview_length: Chars shown in the context preview (default 500).
        reward_on_success: Reward when a final answer is submitted (default 1.0).
        reward_on_iteration: Reward per ordinary step (default 0.0).
        reward_on_failure: Reward when the step budget runs out (default -0.1).
        reward_on_error: Extra reward added when executed code fails
            (default -0.05).
        llm_query_fn: Optional callable backing ``llm_query()``.
        llm_batch_fn: Optional callable backing ``llm_query_batched()``.

    Raises:
        ValueError: If ``REPL_MAX_ITERATIONS`` is set to a non-integer string.
    """
    env_vars = os.environ

    # Episode defaults: a truthy argument wins, otherwise the env var, otherwise "".
    if not context:
        context = env_vars.get("REPL_CONTEXT", "")
    self.initial_context = context

    if not task_prompt:
        task_prompt = env_vars.get("REPL_TASK_PROMPT", "")
    self.initial_task_prompt = task_prompt

    # The env var (a string) overrides the argument whenever it is present.
    iteration_budget = env_vars.get("REPL_MAX_ITERATIONS", max_iterations)
    self.max_iterations = int(iteration_budget)

    # Output / preview limits.
    self.max_output_length = max_output_length
    self.context_preview_length = context_preview_length

    # Reward shaping.
    self.reward_on_success = reward_on_success
    self.reward_on_iteration = reward_on_iteration
    self.reward_on_failure = reward_on_failure
    self.reward_on_error = reward_on_error

    # Optional hooks exposed to executed code for recursive LLM calls.
    self.llm_query_fn = llm_query_fn
    self.llm_batch_fn = llm_batch_fn

    # Per-episode objects, populated by reset().
    self._state: Optional[REPLState] = None
    self._executor: Optional[PythonExecutor] = None
|
|
|
|
|
def _create_llm_functions(
    self,
    hf_token: str,
    llm_model: Optional[str] = None,
) -> None:
    """Build ``llm_query`` / ``llm_query_batched`` from a caller-supplied token.

    On success the two closures are stored on ``self.llm_query_fn`` and
    ``self.llm_batch_fn``. If ``huggingface_hub`` cannot be imported the
    method returns without changing anything.

    Security: the token is passed only to ``InferenceClient``; this method
    does not write it to ``self`` or log it. The client object itself stays
    referenced by the closures stored on ``self``.

    Args:
        hf_token: HuggingFace API token (not logged or persisted here).
        llm_model: Model name. When falsy, the ``LLM_MODEL`` environment
            variable is used, defaulting to
            ``Qwen/Qwen3-Coder-480B-A35B-Instruct``.
    """
    from concurrent.futures import ThreadPoolExecutor

    try:
        from huggingface_hub import InferenceClient
    except ImportError:
        # Optional dependency missing: leave the existing hooks untouched.
        return

    if llm_model:
        model_name = llm_model
    else:
        model_name = os.environ.get(
            "LLM_MODEL", "Qwen/Qwen3-Coder-480B-A35B-Instruct"
        )
    hf_client = InferenceClient(model=model_name, token=hf_token)

    def llm_query(prompt: str) -> str:
        """Send one user prompt; return the reply text or an error string."""
        try:
            completion = hf_client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=2048,
                temperature=0.7,
            )
            reply = completion.choices[0].message.content
            return reply if reply else ""
        except Exception as e:
            # Failures are reported in-band so executed code keeps running.
            return f"Error calling LLM: {e}"

    def llm_query_batched(prompts: List[str]) -> List[str]:
        """Run ``llm_query`` over *prompts* in parallel, preserving order."""
        if not prompts:
            return []

        pool_size = min(len(prompts), 8)
        answers: List[str] = [""] * len(prompts)

        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            # Submit everything first, then collect in submission order so
            # answers[i] always corresponds to prompts[i].
            pending = [pool.submit(llm_query, prompt) for prompt in prompts]
            for slot, task in enumerate(pending):
                try:
                    answers[slot] = task.result()
                except Exception as e:
                    answers[slot] = f"Error: {e}"

        return answers

    self.llm_query_fn = llm_query
    self.llm_batch_fn = llm_query_batched
|
|
|
|
|
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    context: Optional[str] = None,
    task_prompt: Optional[str] = None,
    hf_token: Optional[str] = None,
    llm_model: Optional[str] = None,
    **kwargs: Any,
) -> REPLObservation:
    """Start a new episode with a fresh executor and state.

    Args:
        seed: Accepted for interface compatibility; not used by this method.
        episode_id: Episode identifier; a random UUID4 string is generated
            when falsy.
        context: Context to load; falls back to ``initial_context`` when falsy.
        task_prompt: Task description; falls back to ``initial_task_prompt``
            when falsy.
        hf_token: Optional HuggingFace token for llm_query/llm_query_batched.
            Only consulted when no ``llm_query_fn`` is set yet; the
            ``HF_TOKEN`` environment variable is the fallback.
            The token is not written to the episode state here.
        llm_model: Optional model name forwarded to the LLM function factory.
        **kwargs: Additional reset parameters (ignored by this method).

    Returns:
        Initial REPLObservation (``done=False``, ``reward=0.0``) whose stdout
        carries the ready message.
    """
    episode_context = context if context else self.initial_context
    episode_task = task_prompt if task_prompt else self.initial_task_prompt

    # LLM hooks that already exist are reused as-is; otherwise try to build
    # them from the caller's token, then from the server-side HF_TOKEN.
    if not self.llm_query_fn:
        token = hf_token or os.environ.get("HF_TOKEN")
        if token:
            self._create_llm_functions(token, llm_model)

    self._state = REPLState(
        episode_id=episode_id or str(uuid4()),
        step_count=0,
        context=episode_context,
        task_prompt=episode_task,
        iteration=0,
        max_iterations=self.max_iterations,
        namespace_keys=[],
        final_answer=None,
        total_execution_time=0.0,
    )

    # Brand-new executor for every episode.
    sandbox = PythonExecutor(max_output_length=self.max_output_length)
    self._executor = sandbox

    # Prime Intellect style answer slot, initially empty and not ready.
    sandbox.set_variable("answer", {"content": "", "ready": False})

    if episode_context:
        sandbox.set_context(episode_context)

    # Expose the LLM hooks to executed code; the batch hook gets two names.
    if self.llm_query_fn:
        sandbox.inject_function("llm_query", self.llm_query_fn)
    if self.llm_batch_fn:
        for alias in ("llm_query_batched", "llm_batch"):
            sandbox.inject_function(alias, self.llm_batch_fn)

    def final_helper(value):
        """Helper that returns FINAL(value) string for detection."""
        return f"FINAL({value})"

    sandbox.inject_function("FINAL", final_helper)

    def final_var_helper(var_name: str):
        """Look up variable by name and return FINAL(value) for detection."""
        # Tolerate quoted names such as FINAL_VAR("result").
        cleaned = str(var_name).strip().strip("\"'")
        resolved = sandbox.get_variable(cleaned)
        if resolved is None:
            # Unknown (or None) variable: emit the unresolved marker instead.
            return f"FINAL_VAR({cleaned})"
        return f"FINAL({resolved})"

    sandbox.inject_function("FINAL_VAR", final_var_helper)

    self._state.namespace_keys = sandbox.list_variables()

    # Human-readable banner returned as the first observation's stdout.
    banner = ["REPL environment initialized."]
    if episode_context:
        banner.append(
            f"Context loaded ({len(episode_context)} chars). Use 'context' variable to access it."
        )
    if episode_task:
        banner.append(f"Task: {episode_task}")
    banner.append(
        "Use answer['content'] to store your answer, and set answer['ready'] = True when done."
    )

    if episode_context:
        preview = episode_context[: self.context_preview_length]
        context_size = len(episode_context)
    else:
        preview = None
        context_size = 0

    ready_result = CodeBlockResult(
        stdout="\n".join(banner),
        stderr="",
        locals_snapshot={},
        execution_time=0.0,
        success=True,
        exception=None,
    )

    return REPLObservation(
        result=ready_result,
        context_preview=preview,
        context_length=context_size,
        available_variables=self._state.namespace_keys,
        iteration=0,
        max_iterations=self.max_iterations,
        done=False,
        reward=0.0,
        metadata={
            "task_prompt": episode_task,
            "message": "Environment ready.",
        },
    )
|
|
|
|
|
def step(
    self,
    action: REPLAction,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> REPLObservation:
    """Execute one action and return the resulting observation.

    Order of operations:
      1. Count the step.
      2. If the action is an explicit final submission, end the episode
         successfully without running any code.
      3. Otherwise run the code, provided the step is within the budget.
      4. If the output or namespace yields a final answer, end successfully.
      5. If this was the last step in the budget and no answer was produced,
         end with ``reward_on_failure``; any non-empty ``answer["content"]``
         is recorded as a partial final answer.

    The budget check deliberately happens after execution, so all
    ``max_iterations`` steps get to run their code. Checking before
    execution would silently discard the last action, including a
    ``FINAL(...)`` printed on the final allowed step.

    Args:
        action: REPLAction containing code to execute.
        timeout_s: Optional timeout in seconds (not currently used).
        **kwargs: Additional step parameters (ignored).

    Returns:
        REPLObservation with execution results, or a terminal observation
        when the episode ends by submission or budget exhaustion.

    Raises:
        RuntimeError: If ``reset()`` has not been called.
    """
    if self._state is None or self._executor is None:
        raise RuntimeError(
            "Environment not initialized. Call reset() first."
        )

    state = self._state
    state.step_count += 1
    state.iteration += 1

    # Explicit submission short-circuits everything else.
    if action.is_final:
        state.final_answer = action.final_answer or ""
        return self._create_final_observation(
            success=True,
            message="Final answer submitted.",
            reward=self.reward_on_success,
        )

    # Hard cap: never execute more than max_iterations actions, even if the
    # caller keeps stepping after the budget is spent (or the budget is <= 0).
    if state.iteration <= self.max_iterations:
        result = self._executor.execute(action.code)
        state.total_execution_time += result["execution_time"]
        state.namespace_keys = self._executor.list_variables()

        final_answer = self._extract_final_answer(result["stdout"])
        done = final_answer is not None
        budget_left = state.iteration < self.max_iterations

        if done or budget_left:
            # Ordinary step reward, with a penalty when the code failed.
            reward = self.reward_on_iteration
            if not result["success"]:
                reward += self.reward_on_error

            if done:
                state.final_answer = final_answer
                reward = self.reward_on_success

            return REPLObservation(
                result=CodeBlockResult(
                    stdout=result["stdout"],
                    stderr=result["stderr"],
                    locals_snapshot=result["locals_snapshot"],
                    execution_time=result["execution_time"],
                    success=result["success"],
                    exception=result["exception"],
                ),
                context_preview=(
                    state.context[: self.context_preview_length]
                    if state.context
                    else None
                ),
                context_length=len(state.context) if state.context else 0,
                available_variables=state.namespace_keys,
                iteration=state.iteration,
                max_iterations=self.max_iterations,
                done=done,
                reward=reward,
                metadata={
                    "task_prompt": state.task_prompt,
                    "final_answer": final_answer,
                    "execution_time": result["execution_time"],
                },
            )

    # Budget exhausted without a final answer: salvage partial content, if any.
    answer_var = self._executor.get_variable("answer")
    if isinstance(answer_var, dict) and answer_var.get("content"):
        state.final_answer = str(answer_var.get("content", ""))
    return self._create_final_observation(
        success=False,
        message=f"Maximum iterations ({self.max_iterations}) reached.",
        reward=self.reward_on_failure,
    )
|
|
|
|
|
def _extract_final_answer(self, stdout: str) -> Optional[str]:
    """Extract the final answer from step output or the REPL namespace.

    Patterns are tried in this order:
      1. RLM-style ``FINAL(answer)`` in stdout.
      2. RLM-style ``FINAL_VAR(variable_name)`` in stdout, resolved through
         the executor.
      3. Prime Intellect style ``answer = {"content": ..., "ready": True}``
         in the executor namespace.

    For pattern 1 the answer runs up to the parenthesis that balances the
    opening ``FINAL(``, so answers that contain parentheses themselves
    (``FINAL((1, 2))``, ``FINAL(f(x))``) are returned whole rather than
    being cut at the first ``)``. If the text never balances, the span up
    to the first ``)`` is used.

    Args:
        stdout: Standard output from code execution.

    Returns:
        Final answer string, or None if no pattern matched.
    """
    # Pattern 1: FINAL(...) with balanced-parenthesis matching.
    marker = "FINAL("
    marker_at = stdout.find(marker)
    if marker_at != -1:
        body_start = marker_at + len(marker)
        depth = 1
        first_close = -1
        for pos in range(body_start, len(stdout)):
            ch = stdout[pos]
            if ch == "(":
                depth += 1
            elif ch == ")":
                if first_close == -1:
                    first_close = pos
                depth -= 1
                if depth == 0:
                    return stdout[body_start:pos].strip()
        # Unbalanced text (e.g. a stray "(" inside the answer): fall back to
        # the first closing parenthesis.
        if first_close != -1:
            return stdout[body_start:first_close].strip()

    # Pattern 2: FINAL_VAR(name), resolved to the variable's current value.
    final_var_match = re.search(r"FINAL_VAR\((\w+)\)", stdout)
    if final_var_match and self._executor:
        value = self._executor.get_variable(final_var_match.group(1))
        if value is not None:
            return str(value)

    # Pattern 3: answer dict flagged as ready.
    if self._executor:
        answer_var = self._executor.get_variable("answer")
        if isinstance(answer_var, dict) and answer_var.get("ready", False):
            return str(answer_var.get("content", ""))

    return None
|
|
|
|
|
def _create_final_observation(
    self, success: bool, message: str, reward: float
) -> REPLObservation:
    """Build the terminal observation that closes an episode.

    The observation carries no context preview or variable list; the
    termination *message* is delivered as stdout, and episode totals are
    reported in ``metadata``. When no state exists, the totals default to
    zero and the final answer to None.

    Args:
        success: Whether the episode ended successfully.
        message: Termination message placed in stdout.
        reward: Final reward value.

    Returns:
        REPLObservation with ``done=True``.
    """
    episode = self._state
    if episode:
        iterations_used = episode.iteration
        answer = episode.final_answer
        elapsed = episode.total_execution_time
    else:
        iterations_used = 0
        answer = None
        elapsed = 0

    closing_result = CodeBlockResult(
        stdout=message,
        stderr="",
        locals_snapshot={},
        execution_time=0.0,
        success=success,
        exception=None,
    )
    summary = {
        "final_answer": answer,
        "total_execution_time": elapsed,
        "total_iterations": iterations_used,
    }

    return REPLObservation(
        result=closing_result,
        context_preview=None,
        context_length=0,
        available_variables=[],
        iteration=iterations_used,
        max_iterations=self.max_iterations,
        done=True,
        reward=reward,
        metadata=summary,
    )
|
|
|
|
|
@property
def state(self) -> REPLState:
    """Return the state object of the current episode.

    Returns:
        The REPLState created by the most recent ``reset()``.

    Raises:
        RuntimeError: If no state exists (before ``reset()`` or after
            ``close()``).
    """
    current = self._state
    if current is not None:
        return current
    raise RuntimeError("Environment not initialized. Call reset() first.")
|
|
|
|
|
def close(self) -> None:
    """Drop the executor and episode state so they can be garbage-collected.

    After this call ``step()`` and ``state`` raise until ``reset()`` is
    called again.
    """
    self._executor = self._state = None
|
|
|
|
|
def get_metadata(self) -> EnvironmentMetadata:
    """Describe this environment.

    Returns:
        EnvironmentMetadata carrying the fixed name, description and
        version of the REPL environment.
    """
    details = {
        "name": "repl_env",
        "description": "Python REPL environment for RLM-style code execution",
        "version": "0.1.0",
    }
    return EnvironmentMetadata(**details)
|
|
|