|
|
---
license: apache-2.0
---
|
|
|
|
|
<center> <div style="text-align: center;"> <img src="https://raw.githubusercontent.com/ZHZisZZ/dllm/main/assets/logo.gif" width="400" />
</div> </center>
|
|
|
|
|
# ModernBERT-large-chat-v0.1
|
|
|
|
|
ModernBERT-large-chat-v0.1 is a diffusion-based language model adapted from [ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large) with [MDLM](https://arxiv.org/abs/2406.07524) (masked diffusion language modeling), trained using the [dLLM](https://github.com/ZHZisZZ/dllm) framework.
|
|
|
|
|
## Model Overview
|
|
|
|
|
ModernBERT-large-chat-v0.1 has the following features:

- **Method**: [Masked Diffusion Language Modeling (MDLM)](https://arxiv.org/abs/2406.07524)
- **Framework**: [dLLM](https://github.com/ZHZisZZ/dllm)
- **Base Model**: [ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)
- **Datasets**: [tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture), [smoltalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk)
|
|
|
|
|
For training details, see the [W&B report](https://wandb.ai/asap-zzhou/dllm/reports/dLLM-BERT--VmlldzoxNDg0MzExNg).
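
For intuition, MDLM trains the bidirectional encoder as a denoiser: tokens are masked at a random ratio and the model is supervised to reconstruct them, with a schedule-dependent weight on the cross-entropy. Below is a minimal, illustrative sketch of that objective; the helper name `mdlm_loss`, the linear masking schedule, and the exact weighting/normalization are assumptions for illustration, not the dLLM trainer's implementation.

```python
import torch
import torch.nn.functional as F

def mdlm_loss(model, input_ids, mask_token_id):
    """Sketch of a masked-diffusion (MDLM-style) training objective.

    Each sequence gets a random masking ratio t; tokens are masked
    independently with probability t, and the model is trained to
    recover them with a 1/t-weighted cross-entropy (linear schedule).
    """
    b, l = input_ids.shape
    t = torch.rand(b, 1, device=input_ids.device).clamp(min=1e-3)   # masking ratio per sequence
    mask = torch.rand(b, l, device=input_ids.device) < t            # which positions to mask
    noisy = torch.where(mask, torch.full_like(input_ids, mask_token_id), input_ids)
    logits = model(input_ids=noisy).logits                          # (b, l, vocab)
    ce = F.cross_entropy(
        logits.view(-1, logits.size(-1)), input_ids.view(-1), reduction="none"
    ).view(b, l)
    # Average the 1/t-weighted loss over masked positions only
    # (normalization conventions vary across implementations).
    return (ce * mask / t).sum() / mask.sum().clamp(min=1)
```

In a chat SFT recipe, supervision would typically be restricted to the assistant-response tokens rather than the whole sequence.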
|
|
|
|
|
## Installation

```shell
pip install torch transformers accelerate
```
|
|
|
|
|
## Quick Start
|
|
|
|
|
```python
import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM


def add_gumbel_noise(logits, temperature):
    # Gumbel-max style sampling: taking the argmax of exp(logits) / (-log u)**temperature
    # draws from the (temperature-scaled) softmax; temperature == 0 keeps greedy decoding.
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    # Spread the masked positions evenly across denoising steps;
    # the first `remainder` steps unmask one extra token each.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens


@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=64, temperature=0.0, cfg_scale=0., remasking='random'):
    mask_id = tokenizer.mask_token_id  # relies on the globally loaded `tokenizer` below
    # Start from the prompt followed by a fully masked completion.
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks  # denoising steps per block

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: contrast conditional and prompt-masked (unconditional) logits.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            # Score each predicted token to decide which positions to commit this step.
            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never unmask positions beyond the current block.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Keep the top-k highest-scoring predictions for this step; the rest stay masked.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


device = 'cuda'
model = AutoModelForMaskedLM.from_pretrained('dllm-collection/ModernBERT-large-chat-v0.1', dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('dllm-collection/ModernBERT-large-chat-v0.1')

prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"
m = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": prompt}
]
prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

input_ids = tokenizer(prompt)['input_ids']
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

text = generate(model, input_ids, steps=128, gen_length=128, block_length=64, temperature=0.0, cfg_scale=0.0, remasking='random')
print(tokenizer.batch_decode(text[:, input_ids.shape[1]:], skip_special_tokens=False)[0])
```
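
The decoding loop above is semi-autoregressive: the completion is split into blocks of `block_length` tokens, and within each block the model unmasks a fixed budget of positions per step, committing the highest-scoring predictions and keeping the rest masked for later steps. With `cfg_scale > 0`, a second, prompt-masked forward pass is used for classifier-free guidance.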
|
|
|
|
|
## Generation Parameters

| Parameter      | Description                                                                         | Default  |
| -------------- | ----------------------------------------------------------------------------------- | -------- |
| `gen_length`   | Number of tokens to generate                                                         | 128      |
| `steps`        | Number of diffusion denoising iterations                                             | 128      |
| `temperature`  | Sampling temperature; set to `0.0` for deterministic generation                      | 0.0      |
| `block_length` | Token block size used during blockwise iterative denoising                           | 64       |
| `cfg_scale`    | Classifier-free guidance scale; higher values enforce stronger prompt conditioning   | 0.0      |
| `remasking`    | Strategy for re-masking at each denoising step (`random` or `low_confidence`)        | `random` |
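
As a usage sketch built on the Quick Start snippet above (the specific values are illustrative, not recommended defaults):

```python
# Longer completion with low-confidence remasking, sampling, and mild guidance.
# Note: gen_length must be divisible by block_length, and steps by the number of blocks.
out = generate(
    model, input_ids,
    steps=256,
    gen_length=256,
    block_length=64,
    temperature=0.7,             # Gumbel-noise sampling instead of greedy decoding
    cfg_scale=1.0,               # strengthen prompt conditioning via classifier-free guidance
    remasking='low_confidence',  # commit the most confident predictions first
)
print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
```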
|
|
|
|
|
## Command-Line Interface

Use the GitHub repo's demo script [examples/bert/chat.py](https://github.com/ZHZisZZ/dllm/blob/main/examples/bert/chat.py) for visualized generation:

```shell
python -u examples/bert/chat.py \
    --model_name_or_path dllm-collection/ModernBERT-large-chat-v0.1 \
    --chat True
```
|
|
|
|
|
## Evaluation

| Model                      | LAMBADA | GSM8K | CEval | BBH  | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|:---------------------------|:-------:|:-----:|:-----:|:----:|:----:|:----:|:----------:|:---------:|:-----:|
| ModernBERT-base-chat-v0.1  | 49.3    | 5.9   | 25.0  | 17.9 | 3.1  | 26.1 | 49.7       | 41.0      | 24.3  |
| ModernBERT-large-chat-v0.1 | 46.3    | 17.1  | 24.6  | 25.1 | 3.8  | 33.5 | 53.1       | 45.0      | 27.5  |
|
|
|
|
|
All results are evaluated with [dLLM](https://github.com/ZHZisZZ/dllm)'s evaluation script [examples/bert/eval.sh](https://github.com/ZHZisZZ/dllm/blob/main/examples/bert/eval.sh).
|
|
|
|
|
To automatically evaluate ModernBERT-large-chat-v0.1 on all benchmarks, run:

```shell
bash examples/bert/eval.sh \
    --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0.1"
```
|
|
|
|
|
|
|
|
## Citation

If you use ModernBERT-large-chat-v0.1 or dLLM, please cite:

```bibtex
@misc{dllm,
  author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
  title = {dLLM: Simple Diffusion Language Modeling},
  year = {2025},
  howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
```