oggata's picture
Update app.py
0324a51 verified
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# モデルとプロセッサの読み込み
model_path = "sbintuitions/sarashina2-vision-8b"
print("モデルを読み込んでいます...")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# デバイスの設定
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用デバイス: {device}")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
trust_remote_code=True,
)
model = model.to(device)
print("モデルの読み込みが完了しました!")
def describe_image(image):
"""
画像を受け取り、日本語で説明を生成する関数
"""
if image is None:
return "画像をアップロードしてください。"
# 画像をPIL形式に変換(GradioはすでにPIL.Imageとして渡してくれる)
if not isinstance(image, Image.Image):
image = Image.fromarray(image).convert("RGB")
else:
image = image.convert("RGB")
# プロンプトの作成
message = [{"role": "user", "content": "この画像について詳しく説明してください。"}]
text_prompt = processor.apply_chat_template(message, add_generation_prompt=True)
# 入力の準備
inputs = processor(
text=[text_prompt],
images=[image],
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
# 停止条件の設定
stopping_criteria = processor.get_stopping_criteria(["\n###"])
# 推論の実行
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
do_sample=True,
stopping_criteria=stopping_criteria,
)
# 生成されたテキストの抽出
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_text = processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
return output_text[0]
# Gradioインターフェースの作成
demo = gr.Interface(
fn=describe_image,
inputs=gr.Image(type="pil", label="画像をアップロード"),
outputs=gr.Textbox(label="画像の説明", lines=10),
title="Sarashina2-Vision 画像説明ツール",
description="画像をアップロードすると、Sarashina2-Vision-8Bが日本語で詳しく説明します。",
examples=[
# サンプル画像のパスを指定する場合はここに追加
],
theme=gr.themes.Soft(),
)
if __name__ == "__main__":
demo.launch()