akhaliq HF Staff commited on
Commit
5fe7560
·
verified ·
1 Parent(s): 462c8e0

Update app.py from anycoder

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -3,12 +3,12 @@ import torch
3
  import gradio as gr
4
  import spaces
5
 
6
- # Load models
7
  model2 = torch.hub.load(
8
  "AK391/animegan2-pytorch:main",
9
  "generator",
10
  pretrained=True,
11
- device="cpu", # Load on CPU initially
12
  progress=False
13
  )
14
 
@@ -16,29 +16,32 @@ model1 = torch.hub.load(
16
  "AK391/animegan2-pytorch:main",
17
  "generator",
18
  pretrained="face_paint_512_v1",
19
- device="cpu" # Load on CPU initially
20
  )
21
 
22
  face2paint = torch.hub.load(
23
  'AK391/animegan2-pytorch:main',
24
  'face2paint',
25
  size=512,
26
- device="cpu", # Load on CPU initially
27
  side_by_side=False
28
  )
29
 
30
  @spaces.GPU # Zero GPU decorator - moves to GPU only when function runs
31
  def inference(img, ver):
32
  """Convert portrait to anime style"""
33
- # Move models to GPU when function is called
34
  if ver == 'Version 2':
35
- model2.to('cuda')
36
- out = face2paint(model2, img)
37
- model2.to('cpu') # Move back to CPU to free GPU
38
  else:
39
- model1.to('cuda')
40
- out = face2paint(model1, img)
41
- model1.to('cpu') # Move back to CPU to free GPU
 
 
 
 
 
42
  return out
43
 
44
  # Custom CSS for modern, mobile-friendly design
 
3
  import gradio as gr
4
  import spaces
5
 
6
+ # Load models on CPU initially
7
  model2 = torch.hub.load(
8
  "AK391/animegan2-pytorch:main",
9
  "generator",
10
  pretrained=True,
11
+ device="cpu",
12
  progress=False
13
  )
14
 
 
16
  "AK391/animegan2-pytorch:main",
17
  "generator",
18
  pretrained="face_paint_512_v1",
19
+ device="cpu"
20
  )
21
 
22
  face2paint = torch.hub.load(
23
  'AK391/animegan2-pytorch:main',
24
  'face2paint',
25
  size=512,
26
+ device="cpu",
27
  side_by_side=False
28
  )
29
 
30
  @spaces.GPU # Zero GPU decorator - moves to GPU only when function runs
31
  def inference(img, ver):
32
  """Convert portrait to anime style"""
33
+ # Select model based on version
34
  if ver == 'Version 2':
35
+ model = model2
 
 
36
  else:
37
+ model = model1
38
+
39
+ # Move model to GPU (spaces.GPU handles the device)
40
+ model = model.to('cuda')
41
+
42
+ # Process image - face2paint will handle the device internally
43
+ out = face2paint(model, img, device='cuda')
44
+
45
  return out
46
 
47
  # Custom CSS for modern, mobile-friendly design