
EARL: The Promise of RL for Autoregressive Image Editing

Official model for the paper The Promise of RL for Autoregressive Image Editing.

Abstract

We explore three strategies to enhance performance on a wide range of image editing tasks: supervised fine-tuning (SFT), reinforcement learning (RL), and Chain-of-Thought (CoT) reasoning. In order to study all these components in one consistent framework, we adopt an autoregressive multimodal model that processes textual and visual tokens in a unified manner. We find RL combined with a large multi-modal LLM verifier to be the most effective of these strategies. As a result, we release EARL: Editing with Autoregression and RL, a strong RL-based image editing model that performs competitively on a diverse range of edits compared to strong baselines, despite using much less training data. Thus, EARL pushes the frontier of autoregressive multimodal models on image editing. We release our code, training data, and trained models at this https URL.

Overview

EARL (Editing with Autoregression and RL) is an image editing model built on an autoregressive multimodal model, which processes textual and visual tokens in a single token sequence. It is trained with reinforcement learning. During training, a large multimodal LLM acts as the verifier that judges the model's edits. EARL performs competitively with strong baselines across a diverse range of edit types while using substantially less training data.
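The card does not spell out the RL objective. The released model ID (Image-editing/imged_rl_grpo_sft.s_rl.sc) contains "grpo", which suggests Group Relative Policy Optimization (GRPO). Under that assumption, and assuming the verifier returns one scalar score per sampled edit, the step that turns verifier scores into a training signal is a per-group normalization. Each edit's advantage is its score relative to the other edits sampled for the same instruction. The function name and example scores below are hypothetical. The sketch follows the commonly described GRPO recipe and is not EARL's training code.

```python
from statistics import fmean, pstdev


def group_relative_advantages(scores, eps=1e-6):
    """GRPO-style advantages for a group of edits sampled from one instruction.

    scores: one scalar verifier score per sampled edit (assumed format).
    Returns (score - group mean) / (group std + eps) for each edit.
    """
    mean = fmean(scores)
    std = pstdev(scores)  # population standard deviation over the group
    return [(s - mean) / (std + eps) for s in scores]


if __name__ == "__main__":
    # Hypothetical verifier scores for four edits of the same source image.
    print(group_relative_advantages([7.0, 9.0, 4.0, 8.0]))
```

Edits the verifier rates above the group average get positive advantages and are reinforced. A group in which every edit receives the same score yields all-zero advantages and contributes no reward signal.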

Usage

You can quickly try the model using vLLM for inference.

First, clone the official repository and install the prerequisites:

git clone https://github.com/saba96/EARL.git
cd EARL
python -m venv /path/to/envs/EARL
. /path/to/envs/EARL/bin/activate
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install vllm==0.8.4
pip install flash-attn==2.7.4.post1 --no-build-isolation
pip install -r requirements.txt
export PYTHONPATH=$(pwd)
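The setup pins exact versions, and the registry patch below edits a specific vLLM file. Before continuing, you may want to confirm that the pins took effect. The following is a minimal sketch that uses only the standard library. The EXPECTED table and find_mismatches are illustrative and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Version pins copied from the pip commands above.
EXPECTED = {
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "vllm": "0.8.4",
    "flash_attn": "2.7.4.post1",  # distribution name of the flash-attn package
}


def find_mismatches(expected):
    """Return {package: installed version or None} for every unsatisfied pin."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # CUDA builds carry a local tag such as "2.6.0+cu124"; compare only the public part.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(EXPECTED)
    if problems:
        for name, installed in problems.items():
            print(f"{name}: expected {EXPECTED[name]}, found {installed}")
    else:
        print("All pinned packages match.")
```

Run it with the EARL virtual environment activated. A package reported as None is not installed in that environment.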

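Patch vLLM to support Emu3 (required): vLLM must be told how to load the Emu3ForCausalLM architecture. The patch below maps it to vLLM's built-in Llama implementation. To apply it, edit registry.py in your vLLM installation: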

# adjust python3.10 to match the Python version of the virtual environment
vim /path/to/envs/EARL/lib/python3.10/site-packages/vllm/model_executor/models/registry.py
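The site-packages path depends on where the virtual environment lives and on its Python version. One way to find registry.py without hardcoding either is to ask the import system where vllm is installed. importlib.util.find_spec locates a top-level package without importing it, which avoids vLLM's slow import. The following is a small sketch, and vllm_registry_path is an illustrative name.

```python
import importlib.util
import os


def vllm_registry_path(package="vllm"):
    """Return the path where registry.py should live, or None if the package is not installed."""
    spec = importlib.util.find_spec(package)  # locates the package without importing it
    if spec is None or spec.origin is None:
        return None
    package_dir = os.path.dirname(spec.origin)  # directory containing the package's __init__.py
    return os.path.join(package_dir, "model_executor", "models", "registry.py")


if __name__ == "__main__":
    path = vllm_registry_path()
    print(path if path else "vllm is not installed in the active environment")
```

Run it with the EARL environment activated, then open the printed path in your editor.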

Add the following line to the _MULTIMODAL_MODELS dictionary (around line 166 in the pinned vLLM 0.8.4):

_MULTIMODAL_MODELS = {    
    # add this line
    "Emu3ForCausalLM": ("llama", "LlamaForCausalLM"), 
    # end of adding
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), # already exists
    # ... other models
}
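Editing site-packages by hand is error-prone, and the change is lost whenever vLLM is reinstalled. As an alternative, a script can apply the same one-line change. The sketch below inserts the Emu3 entry directly after the opening line of _MULTIMODAL_MODELS and does nothing if the entry is already present. It assumes the dictionary opens on a line that starts with _MULTIMODAL_MODELS, as in the excerpt above. The function names are hypothetical.

```python
import re

EMU3_ENTRY = '    "Emu3ForCausalLM": ("llama", "LlamaForCausalLM"),\n'

# Matches the opening line of the dictionary, e.g. "_MULTIMODAL_MODELS = {"
DICT_OPEN = re.compile(r"^_MULTIMODAL_MODELS\b[^\n]*=\s*\{[^\n]*\n", re.MULTILINE)


def add_emu3_entry(source):
    """Return registry.py source with the Emu3 entry added (idempotent)."""
    if '"Emu3ForCausalLM"' in source:
        return source  # already patched
    match = DICT_OPEN.search(source)
    if match is None:
        raise ValueError("could not find the _MULTIMODAL_MODELS dictionary")
    return source[: match.end()] + EMU3_ENTRY + source[match.end():]


def patch_registry_file(path):
    """Rewrite registry.py in place, keeping a .bak copy of the original."""
    with open(path, encoding="utf-8") as f:
        original = f.read()
    patched = add_emu3_entry(original)
    if patched != original:
        with open(path + ".bak", "w", encoding="utf-8") as f:
            f.write(original)
        with open(path, "w", encoding="utf-8") as f:
            f.write(patched)
```

For example, call patch_registry_file("/path/to/envs/EARL/lib/python3.10/site-packages/vllm/model_executor/models/registry.py"), and run it again after any vLLM reinstall.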

Then, run inference using the following Python code snippet. Ensure you have an image file ready (e.g., ./examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png from the original repository).

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# vLLM must know the Emu3ForCausalLM architecture before the model is loaded;
# the registry.py patch above maps it to vLLM's Llama implementation.
# If loading still fails, register the class shipped with the cloned repository
# (importable from emu3.model.modeling_emu3_vllm once PYTHONPATH is set):
#   from vllm import ModelRegistry
#   from emu3.model.modeling_emu3_vllm import Emu3ForCausalLM
#   ModelRegistry.register_model("Emu3ForCausalLM", Emu3ForCausalLM)


# --- Helper functions from original repo for image preprocessing ---
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
# -------------------------------------------------------------------

# Load the model with vLLM
path = 'Image-editing/imged_rl_grpo_sft.s_rl.sc' # Model ID from Hugging Face Hub
llm = LLM(
    model=path,
    trust_remote_code=True,
    dtype="auto", # or torch.bfloat16 if supported by your hardware
    gpu_memory_utilization=0.9,
    # Additional vLLM specific arguments if needed
)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Prepare inputs
image_path = './examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png'  # replace with a path to your image
# load_image tiles and normalizes the image into a (num_tiles, 3, 448, 448) tensor.
# Note: pixel_values is NOT passed to llm.generate below.
pixel_values = load_image(image_path, max_num=6).to(torch.bfloat16).cuda()

# Format the prompt (explicit "\n" escapes; a plain string literal cannot span lines)
question = "Edit the image: change the color of the car to red."
prompt = (
    "A chat between a curious user and an AI assistant.\n"
    f"USER: <image>\n{question} ASSISTANT:"
)

sampling_params = SamplingParams(max_tokens=512, temperature=0.7)  # adjust as needed

# llm.generate receives only the text prompt here, so "<image>" is a literal
# placeholder and the image itself never reaches the model. For full
# image-conditioned inference with vLLM, refer to the repository's script:
# https://github.com/saba96/EARL/blob/main/emu3/train_image_editing/vllm_inference.py
outputs = llm.generate([prompt], sampling_params)  # vLLM takes a list of prompts

response = outputs[0].outputs[0].text
print(f"User: {question}\nAssistant: {response}")
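The helpers above determine how many 448×448 tiles an image becomes. First, dynamic_preprocess chooses the column×row grid with at most max_num tiles whose aspect ratio is closest to the image's. On a tie, it prefers the later, larger grid only when the image has enough pixels. It then resizes the image to that grid, crops it into tiles, and appends one full-image thumbnail when there is more than one tile. The following dependency-free sketch mirrors that selection logic, so you can predict the tile count (the first dimension of pixel_values) for a given image size. plan_tiles is a hypothetical name, not a repository function.

```python
def plan_tiles(width, height, image_size=448, min_num=1, max_num=12, use_thumbnail=True):
    """Mirror the grid choice of find_closest_aspect_ratio/dynamic_preprocess without PIL."""
    aspect_ratio = width / height
    # All (cols, rows) grids whose tile count lies in [min_num, max_num], smallest first.
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1) for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda r: r[0] * r[1],
    )
    best_diff, best = float("inf"), (1, 1)
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff, best = diff, (cols, rows)
        # Tie: switch to the larger grid only if the image has enough pixels for it.
        elif diff == best_diff and width * height > 0.5 * image_size ** 2 * cols * rows:
            best = (cols, rows)
    cols, rows = best
    tiles = cols * rows
    if use_thumbnail and tiles != 1:
        tiles += 1  # extra full-image thumbnail
    return {"resize_to": (cols * image_size, rows * image_size), "grid": best, "tiles": tiles}


if __name__ == "__main__":
    # A 1024x768 image with the max_num=6 used in the inference example
    print(plan_tiles(1024, 768, max_num=6))
    # -> {'resize_to': (1344, 896), 'grid': (3, 2), 'tiles': 7}
```

For such an image, load_image(..., max_num=6) therefore produces pixel_values of shape (7, 3, 448, 448): six crops from the 3×2 grid plus the thumbnail.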

Citation

If you find our work helpful or inspiring, please feel free to cite it.

@article{saba2025earl,
  title={The Promise of RL for Autoregressive Image Editing},
  author={Saba, Daniel and Tang, Sifei and Huang, Yifan and Liu, Meng and Ma, Jinxin and Liu, Zhian and Fu, Ruifeng and Zhu, Lei and Han, Jun and Zhang, Shang-Wen and Liu, Jing},
  journal={arXiv preprint arXiv:2508.01119},
  year={2025}
}