Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| # モデルとプロセッサの読み込み | |
| model_path = "sbintuitions/sarashina2-vision-8b" | |
| print("モデルを読み込んでいます...") | |
| processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) | |
| # デバイスの設定 | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"使用デバイス: {device}") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.float16 if device == "cuda" else torch.float32, | |
| trust_remote_code=True, | |
| ) | |
| model = model.to(device) | |
| print("モデルの読み込みが完了しました!") | |
| def describe_image(image): | |
| """ | |
| 画像を受け取り、日本語で説明を生成する関数 | |
| """ | |
| if image is None: | |
| return "画像をアップロードしてください。" | |
| # 画像をPIL形式に変換(GradioはすでにPIL.Imageとして渡してくれる) | |
| if not isinstance(image, Image.Image): | |
| image = Image.fromarray(image).convert("RGB") | |
| else: | |
| image = image.convert("RGB") | |
| # プロンプトの作成 | |
| message = [{"role": "user", "content": "この画像について詳しく説明してください。"}] | |
| text_prompt = processor.apply_chat_template(message, add_generation_prompt=True) | |
| # 入力の準備 | |
| inputs = processor( | |
| text=[text_prompt], | |
| images=[image], | |
| padding=True, | |
| return_tensors="pt", | |
| ) | |
| inputs = inputs.to(model.device) | |
| # 停止条件の設定 | |
| stopping_criteria = processor.get_stopping_criteria(["\n###"]) | |
| # 推論の実行 | |
| with torch.no_grad(): | |
| output_ids = model.generate( | |
| **inputs, | |
| max_new_tokens=256, | |
| temperature=0.7, | |
| do_sample=True, | |
| stopping_criteria=stopping_criteria, | |
| ) | |
| # 生成されたテキストの抽出 | |
| generated_ids = [ | |
| output_ids[len(input_ids):] | |
| for input_ids, output_ids in zip(inputs.input_ids, output_ids) | |
| ] | |
| output_text = processor.batch_decode( | |
| generated_ids, | |
| skip_special_tokens=True, | |
| clean_up_tokenization_spaces=True | |
| ) | |
| return output_text[0] | |
| # Gradioインターフェースの作成 | |
| demo = gr.Interface( | |
| fn=describe_image, | |
| inputs=gr.Image(type="pil", label="画像をアップロード"), | |
| outputs=gr.Textbox(label="画像の説明", lines=10), | |
| title="Sarashina2-Vision 画像説明ツール", | |
| description="画像をアップロードすると、Sarashina2-Vision-8Bが日本語で詳しく説明します。", | |
| examples=[ | |
| # サンプル画像のパスを指定する場合はここに追加 | |
| ], | |
| theme=gr.themes.Soft(), | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |