atsyplikhin commited on
Commit
7390e09
·
1 Parent(s): 88df8d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -8,7 +8,10 @@ def produce_art(prompt):
8
  import torch
9
 
10
  model_id = "atsyplikhin/rita_sd_model"
11
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)#.to("cuda")
 
 
 
12
 
13
  bs = 1
14
  images = pipe([prompt]*bs, num_inference_steps=50, guidance_scale=7.5)
 
8
  import torch
9
 
10
  model_id = "atsyplikhin/rita_sd_model"
11
+ if torch.cuda.is_available():
12
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
13
+ else:
14
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
15
 
16
  bs = 1
17
  images = pipe([prompt]*bs, num_inference_steps=50, guidance_scale=7.5)