MohamedRashad commited on
Commit
8c225c3
·
1 Parent(s): ffcc265

Reorganize import statements in app.py for better readability

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -11,16 +11,6 @@ import torch
11
  import yaml
12
  from box import Box
13
 
14
- from src.data.datapath import Datapath
15
- from src.data.dataset import DatasetConfig, UniRigDatasetModule
16
- from src.data.extract import extract_builtin, get_files
17
- from src.data.transform import TransformConfig
18
- from src.inference.download import download
19
- from src.model.parse import get_model
20
- from src.system.parse import get_system, get_writer
21
- from src.tokenizer.parse import get_tokenizer
22
- from src.tokenizer.spec import TokenizerConfig
23
-
24
  # Get the PyTorch and CUDA versions
25
  torch_version = torch.__version__.split("+")[0] # Strips any "+cuXXX" suffix
26
  cuda_version = torch.version.cuda
@@ -35,6 +25,17 @@ else:
35
  subprocess.run(f'pip install spconv{spconv_version}', shell=True)
36
  subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Helper functions
39
  def validate_input_file(file_path: str) -> bool:
40
  """Validate if the input file format is supported."""
 
11
  import yaml
12
  from box import Box
13
 
 
 
 
 
 
 
 
 
 
 
14
  # Get the PyTorch and CUDA versions
15
  torch_version = torch.__version__.split("+")[0] # Strips any "+cuXXX" suffix
16
  cuda_version = torch.version.cuda
 
25
  subprocess.run(f'pip install spconv{spconv_version}', shell=True)
26
  subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
27
 
28
+ from src.data.datapath import Datapath
29
+ from src.data.dataset import DatasetConfig, UniRigDatasetModule
30
+ from src.data.extract import extract_builtin, get_files
31
+ from src.data.transform import TransformConfig
32
+ from src.inference.download import download
33
+ from src.model.parse import get_model
34
+ from src.system.parse import get_system, get_writer
35
+ from src.tokenizer.parse import get_tokenizer
36
+ from src.tokenizer.spec import TokenizerConfig
37
+
38
+
39
  # Helper functions
40
  def validate_input_file(file_path: str) -> bool:
41
  """Validate if the input file format is supported."""