import argparse
import json
import os

import torch
from PIL import Image
from transformers import AutoTokenizer

from .rope import precompute_freqs_cis
from .text import lm_head, text_decoder, text_encoder
from .vision import encode_image
from .weights import load_from_pt, load_from_safetensors
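
# Sample CLI for single-image inference: encode the image, splice its
# embeddings into the prompt, then generate an answer token by token.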
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", "-i", type=str, required=True)
    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--model", "-m", type=str, required=True)
    parser.add_argument("--config", "-c", type=str, default="{}")
    parser.add_argument("--max-tokens", "-t", type=int, default=200)
    parser.add_argument("--sampler", "-s", type=str, default="greedy")
    args = parser.parse_args()
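
    # Prefer CUDA, then Apple MPS; otherwise stay on the CPU default device.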
    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    # Load config.
    config = json.loads(args.config)
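    # text_n_heads is read up front because it sizes the head axis of the KV
    # cache allocated below.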
    text_n_heads = config.get("text_n_heads", 32)

    # Load model.
    model_path = args.model
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if model_path.endswith(".pt"):
        model = load_from_pt(model_path, **config)
    elif model_path.endswith(".safetensors"):
        model = load_from_safetensors(model_path, **config)
    else:
        raise ValueError(f"Invalid model format: {model_path}")

    # Encode image.
    image_path = args.image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    image = Image.open(image_path)
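    # 378x378 is presumably the fixed input resolution the vision encoder
    # expects; anything else gets resized here.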
    image = image.resize((378, 378))
    image_tensor = encode_image(image, model.vision)

    # Encode text, and create inputs_embeds.
    tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
    prompt = f"\n\nQuestion: {args.prompt}\n\nAnswer:"
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
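    # Prepend EOS, which doubles here as the start-of-sequence token.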
    input_ids = torch.cat(
        [torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1
    )
    inputs_embeds = text_encoder(input_ids, model.text)
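    # Splice the image embeddings in between the BOS embedding and the prompt
    # token embeddings.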
    inputs_embeds = torch.cat(
        [
            inputs_embeds[:, 0:1, :],
            image_tensor.unsqueeze(0),
            inputs_embeds[:, 1:, :],
        ],
        dim=1,
    )
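
    # Preallocate a fixed-size KV cache. Judging by how it is indexed below,
    # its axes are (layer, k/v, batch, head, position, head_dim), with a
    # 2048-position window; pos tracks how many positions are filled.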
    kv_cache = torch.empty(24, 2, 1, text_n_heads, 2048, 64, dtype=torch.float16)
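    # Rotary position tables; 32 presumably matches the model's rotary
    # dimension and 2048 its maximum context length.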
    freqs_cis = precompute_freqs_cis(32, 2048)
    pos = 0
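
    # Generation loop. The first iteration prefills the cache with the full
    # image+prompt sequence; after that, inputs_embeds holds a single token
    # and each step decodes one position.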
    for _ in range(args.max_tokens):
        with torch.no_grad():
            hidden, kv_cache_update = text_decoder(
                inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis
            )
            logits = lm_head(hidden, model.text)
            kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
                kv_cache_update
            )
            pos += kv_cache_update.size(-2)
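
        # Sample the next token from the logits at the last position.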
        if args.sampler == "multinomial":
            next_token = torch.multinomial(
                torch.softmax(logits, dim=-1), num_samples=1
            ).squeeze(0)
        elif args.sampler == "greedy":
            next_token = torch.argmax(logits, dim=-1)
        else:
            raise ValueError(f"Invalid sampler: {args.sampler}")

        if next_token == tokenizer.eos_token_id:
            print()
            break
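
        # Feed the sampled token back in as the next step's length-1 input.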
        input_ids = next_token.unsqueeze(0)
        inputs_embeds = text_encoder(input_ids, model.text)

        output_text = tokenizer.batch_decode(input_ids)[0]
        print(output_text, end="", flush=True)
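
# Example invocation (module and file names are illustrative; the relative
# imports above mean this file must be run as a module with `python -m`):
#   python -m <package>.sample -i demo.jpg -p "What is in this image?" -m model.safetensors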