Update main.py
Browse files
main.py
CHANGED
|
@@ -19,7 +19,11 @@ import tempfile
|
|
| 19 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
| 20 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
|
| 21 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
|
| 22 |
-
os.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Silence all transformers and huggingface logging
|
| 25 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
@@ -30,19 +34,22 @@ app = Flask(__name__)
|
|
| 30 |
CORS(app)
|
| 31 |
|
| 32 |
# ========== Load Whisper Model (quantized) ==========
|
| 33 |
-
def load_whisper_model(model_size="small"):
|
|
|
|
| 34 |
model_name = f"openai/whisper-{model_size}"
|
| 35 |
-
processor = WhisperProcessor.from_pretrained(model_name)
|
| 36 |
-
model = WhisperForConditionalGeneration.from_pretrained(model_name)
|
| 37 |
model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 38 |
model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
return processor, model
|
| 40 |
|
|
|
|
| 41 |
# ========== Load Grammar Correction Model (quantized) ==========
|
| 42 |
-
def load_grammar_model():
|
|
|
|
| 43 |
model_name = "prithivida/grammar_error_correcter_v1"
|
| 44 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 45 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 46 |
model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 47 |
grammar_pipeline = pipeline(
|
| 48 |
"text2text-generation",
|
|
|
|
| 19 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
| 20 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
|
| 21 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
|
| 22 |
+
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
|
| 23 |
+
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"
|
| 24 |
+
|
| 25 |
+
for path in os.environ.values():
|
| 26 |
+
os.makedirs(path, exist_ok=True)
|
| 27 |
|
| 28 |
# Silence all transformers and huggingface logging
|
| 29 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
|
|
| 34 |
CORS(app)
|
| 35 |
|
| 36 |
# ========== Load Whisper Model (quantized) ==========
|
| 37 |
+
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
|
| 38 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 39 |
model_name = f"openai/whisper-{model_size}"
|
| 40 |
+
processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
|
| 41 |
+
model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
|
| 42 |
model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 43 |
model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
return processor, model
|
| 45 |
|
| 46 |
+
|
| 47 |
# ========== Load Grammar Correction Model (quantized) ==========
|
| 48 |
+
def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
|
| 49 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 50 |
model_name = "prithivida/grammar_error_correcter_v1"
|
| 51 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
|
| 52 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
|
| 53 |
model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 54 |
grammar_pipeline = pipeline(
|
| 55 |
"text2text-generation",
|