AbdoIR commited on
Commit
11712e0
·
verified ·
1 Parent(s): 48d8acc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -7
main.py CHANGED
@@ -19,7 +19,11 @@ import tempfile
19
  os.environ["HF_HOME"] = "/tmp/huggingface"
20
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
21
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
22
- os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 
 
 
 
23
 
24
  # Silence all transformers and huggingface logging
25
  logging.getLogger("transformers").setLevel(logging.ERROR)
@@ -30,19 +34,22 @@ app = Flask(__name__)
30
  CORS(app)
31
 
32
  # ========== Load Whisper Model (quantized) ==========
33
- def load_whisper_model(model_size="small"):
 
34
  model_name = f"openai/whisper-{model_size}"
35
- processor = WhisperProcessor.from_pretrained(model_name)
36
- model = WhisperForConditionalGeneration.from_pretrained(model_name)
37
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
38
  model.to("cuda" if torch.cuda.is_available() else "cpu")
39
  return processor, model
40
 
 
41
  # ========== Load Grammar Correction Model (quantized) ==========
42
- def load_grammar_model():
 
43
  model_name = "prithivida/grammar_error_correcter_v1"
44
- tokenizer = AutoTokenizer.from_pretrained(model_name)
45
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
46
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
47
  grammar_pipeline = pipeline(
48
  "text2text-generation",
 
19
  os.environ["HF_HOME"] = "/tmp/huggingface"
20
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
21
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
22
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
23
+ os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"
24
+
25
+ for path in os.environ.values():
26
+ os.makedirs(path, exist_ok=True)
27
 
28
  # Silence all transformers and huggingface logging
29
  logging.getLogger("transformers").setLevel(logging.ERROR)
 
34
  CORS(app)
35
 
36
  # ========== Load Whisper Model (quantized) ==========
37
+ def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
38
+ os.makedirs(save_dir, exist_ok=True)
39
  model_name = f"openai/whisper-{model_size}"
40
+ processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
41
+ model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
42
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
43
  model.to("cuda" if torch.cuda.is_available() else "cpu")
44
  return processor, model
45
 
46
+
47
  # ========== Load Grammar Correction Model (quantized) ==========
48
+ def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
49
+ os.makedirs(save_dir, exist_ok=True)
50
  model_name = "prithivida/grammar_error_correcter_v1"
51
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
52
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
53
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
54
  grammar_pipeline = pipeline(
55
  "text2text-generation",