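# Gradio demo for "Large Multi-modal Models Can Interpret Features in Large
# Multi-modal Models" (arXiv:2411.14982). It hooks a sparse autoencoder (SAE)
# onto layer 24 of llama3-llava-next-8b, records the top-activating SAE
# features for an uploaded image, and renders per-patch activation masks.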
import gradio as gr
from sae_auto_interp.sae import Sae
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
from sae_auto_interp.features.features import upsample_mask
import torch
from transformers import AutoTokenizer
from PIL import Image

CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
    title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
    author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
    year={2024},
    eprint={2411.14982},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2411.14982},
}
"""

cached_tensor = None
topk_indices = None

sunglasses_file_path = "assets/sunglasses.jpg"
greedy_file_path = "assets/greedy.jpg"
railway_file_path = "assets/railway.jpg"
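# Bundled sample images; the same paths are reused in the gr.Examples gallery below.
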
def generate_activations(image):
    prompt = "<image>"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    global cached_tensor, topk_indices

    def hook(module: torch.nn.Module, _, outputs):
        global cached_tensor, topk_indices
        # Unpack tuple outputs if necessary
        if isinstance(outputs, tuple):
            unpack_outputs = list(outputs)
        else:
            unpack_outputs = [outputs]
        latents = sae.pre_acts(unpack_outputs[0])
        # With a llama tokenizer and an image-only prompt,
        # skip the leading BOS token
        if "llama" in tokenizer.name_or_path:
            latents = latents[:, 1:, :]
        topk = torch.topk(latents, k=sae.cfg.k, dim=-1)
        # Zero out everything except the top-k activations
        result = torch.zeros_like(latents)
        # result has shape (bs, seq, num_latents)
        result.scatter_(-1, topk.indices, topk.values)
        cached_tensor = result.detach().cpu()
        topk_indices = (
            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
        )
    handles = [hooked_module.register_forward_hook(hook)]
    try:
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"].to("cuda"),
                pixel_values=inputs["pixel_values"].to("cuda"),
                image_sizes=inputs["image_sizes"].to("cuda"),
                attention_mask=inputs["attention_mask"].to("cuda"),
            )
    finally:
        for handle in handles:
            handle.remove()

    print(cached_tensor.shape)
    torch.cuda.empty_cache()
    return topk_indices
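
# Illustrative sketch (not used by the demo): the hook above applies the usual
# TopK-SAE sparsification, keeping the k largest latents per token and zeroing
# the rest via scatter_. The helper below is a hypothetical name showing the
# same pattern in isolation:
def _topk_sparsify_example(latents: torch.Tensor, k: int) -> torch.Tensor:
    """Return a copy of `latents` with all but the top-k values per row set to zero."""
    topk = torch.topk(latents, k=k, dim=-1)
    sparse = torch.zeros_like(latents)
    sparse.scatter_(-1, topk.indices, topk.values)
    return sparse
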
def visualize_activations(image, feature_num):
    base_img_tokens = 576
    patch_size = 24
    # From the cached activations, select the feature_num-th feature
    # and keep only the first 576 base-image tokens
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_num].view(
        patch_size, patch_size
    )

    upsampled_image_mask = upsample_mask(base_image_activations, (336, 336))

    background = Image.new("L", (336, 336), 0).convert("RGB")

    # The llava-hf preprocessing does not use the padded image for the base
    # image features; it uses the simply resized image. This differs from the
    # original LLaVA, but we align with llava-hf here since that is the
    # implementation we use.
    resized_image = image.resize((336, 336))

    activation_images = Image.composite(background, resized_image, upsampled_image_mask).convert("RGB")
    return activation_images
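
# Note on the constants above (assuming the standard 336x336 ViT-L/14 vision
# tower used by llava-next): 336 / 14 = 24 patches per side, so the base image
# contributes 24 * 24 = 576 tokens. upsample_mask stretches the 24x24 activation
# grid back to 336x336 so it can be composited over the resized input image.
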
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
        [ArXiv Paper](https://arxiv.org/abs/2411.14982) | [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | [Huggingface Collections](https://huggingface.co/collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)
        """
    )
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(value=topk_indices, placeholder="Top 100 Features", label="Top 100 Features")
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    feature_num = gr.Slider(minimum=1, maximum=131072, value=1, step=1, label="Feature Number", interactive=True)
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])

            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    ["assets/sunglasses.jpg", 10, "Sunglasses"],
                    ["assets/greedy.jpg", 14, "Greedy eating"],
                    ["assets/railway.jpg", 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )
        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Accordion("Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")
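
# Running `python app.py` loads the tokenizer, the 131k-latent SAE trained on
# layer 24 (lmms-lab/llama3-llava-next-8b-hf-sae-131k), and the
# llava-hf/llama3-llava-next-8b-hf checkpoint, then launches the Gradio app.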
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
    model, processor = maybe_load_llava_model(
        "llava-hf/llama3-llava-next-8b-hf",
        rank=0,
        dtype=torch.bfloat16,
        hf_token=None,
    )
    hooked_module = model.language_model.get_submodule("model.layers.24")

    demo.launch()