MeowSky49887 committed on
Commit
ac9eabd
·
verified ·
1 Parent(s): a068d5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
2
  from sklearn.utils import shuffle
3
  from sklearn.model_selection import train_test_split
4
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
5
- from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
6
  import torch
7
  import gradio as gr
8
  from pathlib import Path
@@ -30,24 +30,36 @@ data['Label'] = data['Label'].map(groups)
30
 
31
  seeds = [1, 2, 3, 4]
32
 
 
 
 
 
 
 
33
  # Translate function
34
- async def translate_all(seed, texts, language):
35
- dest_lang = "ja" if language == "Japanese" else "th"
36
- semaphore = asyncio.Semaphore(12)
 
 
37
 
38
  async def sem_translate_task(text, idx):
39
  async with semaphore:
40
- async with Translator() as translator:
41
- result = await translator.translate(text, src='en', dest=dest_lang)
42
- return result.text, idx
 
 
 
 
 
 
 
 
 
43
 
44
  tasks = [asyncio.create_task(sem_translate_task(text, idx)) for idx, text in enumerate(texts)]
45
- translated = [None] * len(texts)
46
-
47
- for coro in tqdm_asyncio.as_completed(tasks, total=len(tasks)):
48
- result, index = await coro
49
- translated[index] = result
50
-
51
  return translated
52
 
53
  # Sample Dataset
 
2
  from sklearn.utils import shuffle
3
  from sklearn.model_selection import train_test_split
4
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
5
+ from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
6
  import torch
7
  import gradio as gr
8
  from pathlib import Path
 
30
 
31
  seeds = [1, 2, 3, 4]
32
 
33
+ # Load the NLLB model and tokenizer
34
+ translation_model_name = "facebook/nllb-200-distilled-600M"
35
+ translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
36
+ translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
37
+ translation_model.eval()
38
+
39
  # Translate function
40
async def translate_all(seed, texts, language, progress=gr.Progress(track_tqdm=True)):
    """Translate ``texts`` into Japanese or Thai with the module-level NLLB model.

    Args:
        seed: Not used by this function; kept so existing callers keep working.
        texts: Sized iterable of source strings (``len()`` and iteration are
            both used).
        language: ``"Japanese"`` selects ``jpn_Jpan``; any other value selects
            ``tha_Thai``.
        progress: Gradio progress tracker. Gradio injects the live tracker
            through this default argument, so it must stay a default value.

    Returns:
        A list with one translated string per input, in input order.
    """
    dest_lang = "jpn_Jpan" if language == "Japanese" else "tha_Thai"

    # Resolve the target-language token once. `convert_tokens_to_ids` works on
    # both old and new tokenizer releases, whereas the `lang_code_to_id`
    # mapping is not available on recent ones.
    forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(dest_lang)

    # Send inputs to wherever the model actually lives. Moving them to "cuda"
    # just because CUDA is available breaks when the model is still on the CPU.
    device = translation_model.device

    total = len(texts)
    translated = [None] * total

    # `generate` is a blocking call, so fanning out asyncio tasks behind a
    # semaphore never overlapped any work. A plain loop is equivalent and
    # keeps results and progress strictly in input order.
    for idx, text in enumerate(texts):
        inputs = translation_tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)

        with torch.no_grad():
            translated_tokens = translation_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
            )

        translated[idx] = translation_tokenizer.decode(
            translated_tokens[0],
            skip_special_tokens=True,
        )

        # Gradio expects (completed, total) as a single tuple; a second
        # positional argument would be interpreted as the description text.
        progress((idx + 1, total))

        # Yield to the event loop between items so other coroutines can run.
        await asyncio.sleep(0)

    return translated
64
 
65
  # Sample Dataset