---
license: apache-2.0
pipeline_tag: text-to-image
---
# AMD Nitro-AR
<img src="images/fig4.png" width="100%" alt="Nitro-AR" />
## Introduction
Nitro-AR is a compact masked autoregressive (AR) model for high-quality text-to-image generation. With just ~0.3B parameters, it combines continuous token prediction with efficient sampling strategies to match the image quality of diffusion-based counterparts at a fraction of the inference latency: on a single AMD Radeon™ RX 7900 XT GPU, Nitro-AR generates a 512×512 image in as little as 0.23 seconds (standard) or 74 ms (optimized). The release consists of:
* **Nitro-AR-512px-GAN.safetensors**: A masked autoregressive model with a diffusion MLP head, optimized via adversarial fine-tuning for few-step diffusion head sampling.
* **Nitro-AR-512px-Joint-GAN.safetensors**: A high-performance variant using **Joint Sampling** (transformer-based head) and adversarial fine-tuning, enabling single-step AR generation with superior consistency.
⚡️ **[Open-source code](https://github.com/AMD-AGI/Nitro-E/tree/main/Nitro-AR)!**
## Details
**Model architecture:** Nitro-AR is built on the **E-MMDiT** (Efficient Multimodal Diffusion Transformer) backbone, adapted to a masked modeling autoregressive framework. Unlike traditional AR models that predict discrete tokens, Nitro-AR employs a **diffusion prediction head** to sample continuous tokens, ensuring compatibility with VAE models.
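To make the decoding procedure concrete, here is a minimal sketch of masked AR sampling with a diffusion head; `backbone`, `diffusion_head`, and `vae` are hypothetical stand-ins rather than the actual Nitro-AR module names, and the mask schedule is illustrative only:
```python
import torch

@torch.no_grad()
def masked_ar_sample(backbone, diffusion_head, vae, text_emb,
                     num_tokens=1024, token_dim=16, ar_steps=8, device="cuda"):
    # All image tokens start out masked; they are revealed over `ar_steps` rounds.
    latents = torch.zeros(1, num_tokens, token_dim, device=device)        # continuous tokens
    unknown = torch.ones(1, num_tokens, dtype=torch.bool, device=device)  # mask state

    for step in range(ar_steps):
        # The transformer backbone conditions on the text and on tokens predicted so far.
        hidden = backbone(latents, mask=unknown, context=text_emb)

        # Cosine-style schedule: how many tokens should remain masked after this round.
        keep_masked = int(num_tokens * torch.cos(torch.tensor((step + 1) / ar_steps * torch.pi / 2)))
        masked_idx = unknown[0].nonzero(as_tuple=True)[0]
        reveal = masked_idx[torch.randperm(len(masked_idx), device=device)[: len(masked_idx) - keep_masked]]

        # The diffusion head samples *continuous* token values (a few denoising steps
        # after adversarial fine-tuning) instead of predicting discrete codebook indices.
        latents[0, reveal] = diffusion_head.sample(hidden[0, reveal])
        unknown[0, reveal] = False

    # Continuous tokens remain directly compatible with the VAE decoder.
    return vae.decode(latents)
```
Because the predicted tokens live in the VAE's continuous latent space, no vector quantization or codebook lookup is needed before decoding.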
To achieve high efficiency and quality, we introduce several key optimizations:
* **Joint Sampling:** Replaces the standard MLP head with a small transformer head to model token dependencies, significantly improving coherence in few-step generation.
* **Global Token:** Enhances diversity by using a global token as the initial state for masked tokens.
* **Adversarial Fine-Tuning:** Treats the diffusion prediction head as a generator trained against a discriminator, compressing the diffusion process to just 3–6 denoising steps; a schematic sketch of this objective follows the list.
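The adversarial objective can be pictured with the following sketch, which treats the few-step diffusion head as a GAN generator; `head`, `disc`, the optimizers, and the tensor shapes are hypothetical placeholders, and the loss shown is a standard non-saturating GAN loss rather than the exact training recipe:
```python
import torch
import torch.nn.functional as F

def adversarial_step(head, disc, opt_g, opt_d, hidden, real_tokens, num_steps=3):
    # Generator: sample continuous tokens with only a few denoising steps.
    fake_tokens = head.sample(hidden, num_steps=num_steps)

    # Discriminator update: distinguish reference tokens from few-step samples.
    d_loss = (F.softplus(-disc(real_tokens, hidden)).mean()
              + F.softplus(disc(fake_tokens.detach(), hidden)).mean())
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Generator (diffusion head) update: make few-step samples indistinguishable.
    g_loss = F.softplus(-disc(fake_tokens, hidden)).mean()
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
    return d_loss.item(), g_loss.item()
```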
**Dataset:** Our models were trained on ~25M images drawn from both real and synthetic sources that are openly available on the internet. We make use of the following datasets for training: Segment-Anything-1B and JourneyDB, plus DiffusionDB and DataComp as prompt sources for the generated data.
**Training cost:** The models were trained using **AMD Instinct™ MI325X GPUs**. The efficient architecture allows for rapid convergence comparable to the Nitro-E family.
## Quickstart
### Environment
When running on AMD Instinct™ GPUs, it is recommended to use the [public PyTorch ROCm images](https://hub.docker.com/r/rocm/pytorch-training/) for optimized performance.
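For example, a typical container launch looks roughly like the following; the image tag is illustrative (pick a current tag from the `rocm/pytorch-training` repository), and the device/IPC flags are the usual ROCm Docker options:
```bash
# Illustrative container launch; adjust the image tag and mounts to your setup.
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --group-add video --ipc=host --shm-size 16G \
  rocm/pytorch-training:latest
```
Inside the container (or in your own Python environment), install the remaining dependencies: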
```bash
pip install diffusers==0.32.2 transformers==4.49.0 accelerate==1.7.0 wandb pycocotools "torchmetrics[image]" mosaicml-streaming==0.11.0 beautifulsoup4 tabulate timm==0.9.1 pyarrow einops omegaconf sentencepiece==0.2.0 pandas==2.2.3 alive-progress ftfy peft safetensors
```
### Model Inference
Download the checkpoints (`Nitro-AR-512px-GAN.safetensors` and `Nitro-AR-512px-Joint-GAN.safetensors`) from [Hugging Face](https://huggingface.co/amd/Nitro-AR) and place them in the `ckpts/` directory.
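Alternatively, the checkpoints can be fetched programmatically with `huggingface_hub` (a small sketch; the filenames are the ones listed above):
```python
from huggingface_hub import hf_hub_download

# Download both released checkpoints into the local ckpts/ directory.
for filename in ["Nitro-AR-512px-GAN.safetensors", "Nitro-AR-512px-Joint-GAN.safetensors"]:
    hf_hub_download(repo_id="amd/Nitro-AR", filename=filename, local_dir="ckpts")
```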
We support two optimized model variants: `gan` and `joint_gan`.
```python
from core.tools.inference_pipe import init_pipe
import torch
device = torch.device('cuda')
dtype = torch.bfloat16
# Initialize pipeline for 'gan' model (Adversarial Fine-tuned, 3 steps)
pipe = init_pipe(device, dtype, model_type='gan')
# Run inference
prompt = "a photo of a dog, with a white background"
output = pipe(prompt=prompt)
output.images[0].save("output_gan.png")
# Initialize pipeline for 'joint_gan' model (Joint Sampling + Adversarial Fine-tuned, 6 steps)
pipe_joint = init_pipe(device, dtype, model_type='joint_gan')
# Run inference
output_joint = pipe_joint(prompt=prompt)
output_joint.images[0].save("output_joint_gan.png")
```
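To get a rough sense of latency on your own hardware, a simple synchronized timing loop like the one below can be wrapped around either pipeline; the exact numbers depend on the GPU, driver, and settings, so treat this as a measurement sketch rather than a benchmark script:
```python
import time
import torch

# Warm up first so one-time allocation/compilation does not skew the timing.
for _ in range(3):
    pipe(prompt=prompt)

torch.cuda.synchronize()
start = time.perf_counter()
runs = 10
for _ in range(runs):
    pipe(prompt=prompt)
torch.cuda.synchronize()
print(f"avg latency: {(time.perf_counter() - start) / runs:.3f} s / image")
```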
## License
Copyright (C) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
This project is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0).