Qwen3-0.6B-diffusion-bd3lm-v0.1 is a diffusion-based language model adapted from Qwen3-0.6B using BD3LM (block diffusion) and trained with the dLLM framework. Rather than decoding strictly left to right, it generates text block by block, iteratively denoising masked tokens within each block while attending to all previously generated blocks.

For training details, see the W&B report.
```bash
pip install torch transformers accelerate
```
The script below defines the block-diffusion sampler end to end, then runs batched chat generation:

```python
import math
import copy

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

def add_gumbel_noise(logits, temperature):
    # Exponential-clock variant of the Gumbel-max trick: taking the argmax of
    # exp(logits) / (-log u)**temperature samples from the softmax when
    # temperature == 1 and reduces to greedy argmax as temperature -> 0.
    # float64 keeps the noise numerically stable.
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    g = (-torch.log(noise)) ** temperature
    return logits.exp() / g
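
# Sanity check (illustrative, not part of the card): temperature == 0 leaves
# the logits untouched, so downstream argmax decoding is deterministic, while
# temperature > 0 yields random positive scores whose argmax behaves like a
# (approximately temperature-scaled) softmax sample.
#
# logits = torch.tensor([[2.0, 1.0, 0.5]])
# assert torch.equal(add_gumbel_noise(logits, 0), logits)
# add_gumbel_noise(logits, 1.0)  # stochastic scores; argmax ~ softmax sample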

def get_num_transfer_tokens(mask_index, steps):
    # Split the masked-token budget evenly across denoising steps; the first
    # (mask_num % steps) steps commit one extra token each.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    rem = mask_num % steps
    out = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
    for i in range(mask_num.size(0)):
        out[i, : rem[i]] += 1
    return out
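
# Worked example (illustrative): 10 masked tokens spread over 4 steps gives
# the per-step schedule [3, 3, 2, 2].
#
# get_num_transfer_tokens(torch.ones(1, 10, dtype=torch.bool), 4)
# -> tensor([[3, 3, 2, 2]])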

def build_staircase_attention_mask(x, block_size, pad_id):
    # Block-diffusion "staircase" mask: a query in block b attends to every
    # key in blocks 0..b (bidirectional inside its own block, causal across
    # blocks). Padding positions are excluded on both the query and key side,
    # and position ids are counted over non-pad tokens only.
    B, T = x.shape
    device = x.device
    valid = x != pad_id
    pos_raw = torch.cumsum(valid.long(), dim=-1)
    position_ids = torch.where(valid, pos_raw - 1, torch.zeros_like(pos_raw)).long()
    col = torch.arange(T, device=device)
    block_ids = (col // block_size).view(1, T).expand(B, T)
    block_ids = torch.where(valid, block_ids, torch.full_like(block_ids, -1))
    q = block_ids.view(B, 1, T, 1)
    k = block_ids.view(B, 1, 1, T)
    attn = (k <= q) & (q >= 0) & (k >= 0)
    return attn, position_ids
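
# Toy check (illustrative values): with T = 4, block_size = 2 and no padding,
# each 2-token block sees itself plus all earlier blocks.
#
# attn, _ = build_staircase_attention_mask(torch.tensor([[5, 6, 7, 8]]), 2, pad_id=0)
# attn[0, 0].long()
# -> tensor([[1, 1, 0, 0],
#            [1, 1, 0, 0],
#            [1, 1, 1, 1],
#            [1, 1, 1, 1]])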

def diffusion_step_block(logits, x_block, mask_block, num_transfer, temperature, remasking):
    B, L, _ = logits.shape
    if not mask_block.any():
        return x_block
    noisy = add_gumbel_noise(logits, temperature)
    x0 = noisy.argmax(dim=-1)
    # Score each candidate so the least certain predictions can stay masked
    # for a later step.
    if remasking == "low_confidence":
        p = F.softmax(logits, dim=-1)
        conf = p.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == "random":
        conf = torch.rand((B, L), device=logits.device)
    else:
        raise ValueError(remasking)
    x0 = torch.where(mask_block, x0, x_block)
    neg_inf = torch.full_like(conf, -float("inf"))
    conf = torch.where(mask_block, conf, neg_inf)
    # Commit only the top-k most confident masked positions this step.
    commit = torch.zeros_like(x_block, dtype=torch.bool)
    for i in range(B):
        k = int(num_transfer[i].item())
        if k > 0:
            valid = (conf[i] > -float("inf")).sum().item()
            k = min(k, valid)
            _, idx = torch.topk(conf[i], k)
            commit[i, idx] = True
    out = x_block.clone()
    out[commit] = x0[commit]
    return out
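
# One illustrative step: with two masked positions and num_transfer = [1],
# exactly one position is committed to its argmax token and the other stays
# masked for the next iteration (assume mask_id == 9 for this toy run).
#
# logits = torch.randn(1, 2, 10)
# x_blk = torch.full((1, 2), 9)
# m_blk = torch.ones(1, 2, dtype=torch.bool)
# diffusion_step_block(logits, x_blk, m_blk, torch.tensor([1]), 0.0, "low_confidence")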

@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
):
    device = model.device
    mask_id = tokenizer.mask_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
    # Accept a tensor, a single token-id list, or a batch of token-id lists.
    if isinstance(prompt, torch.Tensor):
        x = prompt.to(device).long()
    else:
        if isinstance(prompt[0], (list, tuple)):
            max_len = max(len(p) for p in prompt)
            x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
            for i, p in enumerate(prompt):
                x[i, : len(p)] = torch.tensor(p, device=device)
        else:
            x = torch.tensor(prompt, device=device).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    B = x.size(0)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    num_blocks = math.ceil(max_new_tokens / block_size)
    steps_per_block = math.ceil(steps / num_blocks)
    generated = 0
    while generated < max_new_tokens:
        if finished.all():
            break
        # Keep new tokens aligned to the global block grid: if the prefix ends
        # mid-block, only fill out the remainder of that block this round.
        T_prefix = x.size(1)
        offset = T_prefix % block_size
        room = block_size if offset == 0 else block_size - offset
        cur_len = min(room, max_new_tokens - generated)
        if cur_len <= 0:
            break
        # Cache the prefix once per block; every denoising step below reuses it.
        attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
        out = model(x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
        cond_past = out.past_key_values
        if cfg_scale > 0:
            # Unconditional cache for classifier-free guidance: same shape, all masks.
            un_x = x.clone()
            un_x[:] = mask_id
            out_un = model(un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
            uncond_past = out_un.past_key_values
        else:
            uncond_past = None
        # Append a fully masked block (pad for already-finished sequences).
        block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
        block[finished] = pad_id
        x = torch.cat([x, block], dim=1)
        T_total = x.size(1)
        block_mask = x[:, -cur_len:] == mask_id
        num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
        eff_steps = num_transfer.size(1)
        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, T_prefix:T_total, :]
        pos_blk = full_pos[:, T_prefix:T_total]
        for t in range(eff_steps):
            x_blk = x[:, T_prefix:T_total]
            m_blk = x_blk == mask_id
            # deepcopy so the cached prefix is not mutated across steps.
            cond_logits = model(
                x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                past_key_values=copy.deepcopy(cond_past), use_cache=False
            ).logits
            logits = cond_logits
            if cfg_scale > 0:
                un_logits = model(
                    x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                    past_key_values=copy.deepcopy(uncond_past), use_cache=False
                ).logits
                logits = un_logits + (cfg_scale + 1.0) * (cond_logits - un_logits)
            x_blk_new = diffusion_step_block(
                logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
            )
            x[:, T_prefix:T_total] = x_blk_new
            if tokenizer.eos_token_id is not None:
                finished |= (x_blk_new == tokenizer.eos_token_id).any(dim=1)
                if finished.all():
                    break
        generated += cur_len
        if finished.all():
            break
    return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForMaskedLM.from_pretrained(
    "dllm-collection/Qwen3-0.6B-diffusion-bd3lm-v0.1",
    dtype=torch.bfloat16,
    trust_remote_code=True,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
    "dllm-collection/Qwen3-0.6B-diffusion-bd3lm-v0.1", trust_remote_code=True
)
prompts = [
    [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Implement a DFS traversal in Python with clear inline comments."},
    ],
    [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 10 hours?"},
    ],
]
encoded = [tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=True) for m in prompts]
prompt_lens = [len(e) for e in encoded]
max_len = max(prompt_lens)
pad_id = tokenizer.pad_token_id
if pad_id is None:
    pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
# Right-pad every prompt to the longest one in the batch.
input_ids = torch.full((len(encoded), max_len), pad_id, dtype=torch.long)
for i, ids in enumerate(encoded):
    input_ids[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
input_ids = input_ids.to(device)
max_new_tokens = 256
out_ids = generate(
    model,
    tokenizer,
    input_ids,
    steps=256,
    max_new_tokens=max_new_tokens,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
)
# Generation is appended after the padded prompt, so new tokens start at
# max_len for every sample (not at each sample's own prompt length).
new_tokens = [out_ids[i, max_len : max_len + max_new_tokens].tolist() for i in range(len(encoded))]
for idx, decoded in enumerate(tokenizer.batch_decode(new_tokens, skip_special_tokens=False)):
    print(f"\n[Sample {idx}]")
    print(decoded)
```
| Parameter | Description | Default |
|---|---|---|
| `max_new_tokens` | Number of tokens to generate | `256` |
| `steps` | Total number of diffusion denoising iterations, split evenly across blocks | `256` |
| `temperature` | Sampling temperature; `0.0` gives deterministic (greedy) generation | `0.0` |
| `block_size` | Token block size used during iterative denoising | `32` |
| `cfg_scale` | Classifier-free guidance scale; higher values push generations toward the prompt (`0.0` disables guidance) | `0.0` |
| `remasking` | Strategy for choosing which tokens stay masked at each denoising step (`random` or `low_confidence`) | `low_confidence` |
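
For example (illustrative values, not tuned defaults), you can trade determinism for diversity by sampling with a nonzero temperature and enabling classifier-free guidance:

```python
# Illustrative knobs: stochastic sampling with light classifier-free guidance.
out_ids = generate(
    model,
    tokenizer,
    input_ids,
    steps=256,
    max_new_tokens=256,
    block_size=32,
    temperature=0.7,     # > 0 switches argmax decoding to Gumbel sampling
    cfg_scale=1.5,       # > 0 adds the unconditional pass and extrapolates
    remasking="random",  # commit positions in random rather than confidence order
)
```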
Follow the GitHub repo's demo script `examples/a2d/bd3lm/chat.py` for visualized generation:

```bash
python -u examples/a2d/bd3lm/chat.py \
    --model_name_or_path dllm-collection/Qwen3-0.6B-diffusion-bd3lm-v0.1 \
    --chat_template True --block_size 32 --remasking low_confidence --steps 256 --max_new_tokens 256
```
| Model | GSM8K | MATH | BBH | MMLU-Pro | HellaSwag | MMLU | HumanEval | MBPP |
|---|---|---|---|---|---|---|---|---|
| Qwen3-0.6B-diffusion-bd3lm-v0.1 (evaluated) | 46.6 | 13.9 | 27.0 | 14.1 | 40.0 | 38.8 | 47.6 | 32.0 |
| Qwen3-0.6B-diffusion-mdlm-v0.1 (evaluated) | 29.8 | 8.8 | 27.0 | 17.6 | 42.1 | 40.0 | 30.5 | 29.2 |
| Qwen3-0.6B (reported) | 59.6 | 32.4 | 41.5 | 24.7 | 47.4 | 52.8 | 32.3 | 36.6 |
| Qwen2.5-0.5B (reported) | 41.6 | 19.5 | 20.3 | 15.7 | 52.1 | 47.5 | 30.5 | 39.3 |
To automatically evaluate Qwen3-0.6B-diffusion-bd3lm-v0.1 on all benchmarks, run:

```bash
bash examples/a2d/bd3lm/eval.sh \
    --model_name_or_path dllm-collection/Qwen3-0.6B-diffusion-bd3lm-v0.1
```
If you use Qwen3-0.6B-diffusion-bd3lm-v0.1 or dLLM, please cite:
```bibtex
@misc{dllm,
  author       = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
  title        = {dLLM: Simple Diffusion Language Modeling},
  year         = {2025},
  howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
```