oggata commited on
Commit
0324a51
·
verified ·
1 Parent(s): 687de38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -8,12 +8,17 @@ model_path = "sbintuitions/sarashina2-vision-8b"
8
 
9
  print("モデルを読み込んでいます...")
10
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
 
 
 
 
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_path,
13
- device_map="auto",
14
- torch_dtype=torch.float16,
15
  trust_remote_code=True,
16
  )
 
17
  print("モデルの読み込みが完了しました!")
18
 
19
  def describe_image(image):
 
8
 
9
  print("モデルを読み込んでいます...")
10
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
11
+
12
+ # デバイスの設定
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"使用デバイス: {device}")
15
+
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_path,
18
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
 
19
  trust_remote_code=True,
20
  )
21
+ model = model.to(device)
22
  print("モデルの読み込みが完了しました!")
23
 
24
  def describe_image(image):