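# Emotion-classification Space: sample the English emotions dataset, machine-translate it
# to Japanese or Thai, fine-tune a DistilBERT classifier per seed, and expose each step
# through a Gradio UI.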
import pandas as pd
from sklearn.utils import shuffle
from tqdm.asyncio import tqdm_asyncio
from googletrans import Translator
from pathlib import Path
import asyncio
import random
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr

# --- Load dataset
data = pd.read_parquet("hf://datasets/boltuix/emotions-dataset/emotions_dataset.parquet")
groups = {
    "neutral": "neutral",
    "anger": "angry",
    "love": "joy",
    "happiness": "fun",
    "sadness": "sorrow",
    "surprise": "surprised",
    "fear": "fear",
    "disgust": "disgust"
}
data = data[data['Label'].isin(groups.keys())].copy()
data['Label'] = data['Label'].map(groups)

seeds = [1, 2, 3, 4]
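# Each seed yields an independently sampled/translated CSV and a separately fine-tuned model.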

# --- Translate function (batch + executor + async + progress)
async def translate_all(seed, texts, language):
    # Map language → Google Translate target code
    target_lang = "ja" if language == "Japanese" else "th"
    # Limit concurrent requests to avoid timeouts
    semaphore = asyncio.Semaphore(8)

    async def sem_translate_task(text, idx):
        async with semaphore:
            for attempt in range(5):  # retry up to 5 times
                try:
                    async with Translator() as translator:
                        result = await translator.translate(
                            text,
                            src="en",
                            dest=target_lang
                        )
                        return result.text, idx
                except Exception:
                    # Backoff that grows with each attempt, plus jitter
                    await asyncio.sleep(1 + attempt * 0.5 + random.random() * 0.3)
            # If all 5 attempts fail → return the original text so the batch keeps going
            return text, idx

    # Create tasks
    tasks = [asyncio.create_task(sem_translate_task(text, idx))
             for idx, text in enumerate(texts)]
    translated = [None] * len(texts)

    # Async tqdm for the progress bar
    for coro in tqdm_asyncio.as_completed(tasks, total=len(tasks)):
        result, index = await coro
        translated[index] = result
    return translated
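# Example (hypothetical call, outside the Gradio flow):
#   translated = asyncio.run(translate_all(1, ["I am happy"], "Thai"))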

# --- Sample Dataset
async def sample_all(language, progress=gr.Progress(track_tqdm=True)):
    files = []
    for seed in seeds:
        try:
            filename = f"./data/{language}_SampledData_{seed}.csv"
            Path("./data").mkdir(parents=True, exist_ok=True)
            if not os.path.exists(filename):
                sampled = (
                    data.groupby('Label', group_keys=False)
                    .apply(lambda x: x.sample(n=1000, random_state=int(seed)))
                )
                sampled = shuffle(sampled).reset_index(drop=True)
                texts = sampled["Sentence"].tolist()
                translated = await translate_all(seed, texts, language)
                sampled["Sentence"] = translated
                sampled.to_csv(filename, index=False)
                files.append(filename)
            else:
                files.append(filename)
        except Exception as e:
            raise gr.Error(str(e))
    return files

# --- Dataset class
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
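# Dataset returns plain tensors per index, so Trainer's default data collator can batch them directly.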

# --- Prepare dataset for Trainer
def prepare_dataset(df, tokenizer):
    label_list = sorted(df["Label"].unique())
    label2id = {label: i for i, label in enumerate(label_list)}
    id2label = {i: label for label, i in label2id.items()}

    X = list(df["Sentence"])
    y = [label2id[label] for label in df["Label"]]
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    X_train_tokenized = tokenizer(
        X_train, padding=True, truncation=True, max_length=512
    )
    X_val_tokenized = tokenizer(
        X_val, padding=True, truncation=True, max_length=512
    )

    train_dataset = Dataset(X_train_tokenized, y_train)
    val_dataset = Dataset(X_val_tokenized, y_val)
    return train_dataset, val_dataset, label2id, id2label
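# prepare_dataset gives an 80/20 stratified split plus the label mappings used to size the classifier head.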

# --- Compute metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted'
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# --- Train model per language and seed
def train_model(language):
    # NOTE: the Japanese checkpoint is assumed to be the Geotrend counterpart of the Thai one
    model_name = "Geotrend/distilbert-base-ja-cased" if language == "Japanese" else "Geotrend/distilbert-base-th-cased"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    metric_str_all = []
    model_paths = []
    for seed in seeds:
        csv_path = f"./data/{language}_SampledData_{seed}.csv"
        if not os.path.exists(csv_path):
            return f"File {csv_path} not found! Please prepare the dataset first.", None
        df = pd.read_csv(csv_path)
        train_dataset, val_dataset, label2id, id2label = prepare_dataset(df, tokenizer)
        output_dir = f"./output/{language}/seed{seed}"
        final_model_dir = os.path.join(output_dir, "final_model")
        Path(final_model_dir).mkdir(parents=True, exist_ok=True)
        # Check for the safetensors file that save_model writes (the model is loaded with use_safetensors=True)
        if not os.path.exists(os.path.join(final_model_dir, "model.safetensors")):
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                use_safetensors=True,
                num_labels=len(label2id),
                id2label=id2label,
                label2id=label2id,
            )
            training_args = TrainingArguments(
                output_dir=output_dir,
                seed=seed,
                per_device_train_batch_size=8,
                per_device_eval_batch_size=8,
                eval_strategy="epoch",
                save_strategy="epoch",
                num_train_epochs=5,
                fp16=False,
                logging_dir=f"./logs/{language}_seed{seed}",
                logging_steps=100,
                load_best_model_at_end=True
            )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                tokenizer=tokenizer,
                compute_metrics=compute_metrics
            )
            trainer.train()
            trainer.save_model(final_model_dir)
            tokenizer.save_pretrained(final_model_dir)
            model_paths.append(final_model_dir)

            metrics = trainer.evaluate()
            metric_str = f"Seed {seed} ({language}):\n" + "\n".join(
                [f"{k}: {v:.4f}" for k, v in metrics.items()]
            )
            metric_str_all.append(metric_str)
        else:
            model_paths.append(final_model_dir)

    final_avg_dir = f"./VRM-Emotions/{language}"
    Path(final_avg_dir).mkdir(parents=True, exist_ok=True)
    # Return two values to match the two wired Gradio outputs (metrics text, model paths)
    return "\n\n".join(metric_str_all), model_paths

# --- Async wrapper for training
async def train_model_async(language, progress=gr.Progress(track_tqdm=True)):
    return await asyncio.to_thread(train_model, language)
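# Running the blocking Trainer loop in a worker thread keeps the Gradio event loop responsive.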

# --- Gradio UI
with gr.Blocks() as demo:
    # Tab 1: Options
    with gr.Tab("Options"):
        language_dropdown = gr.Dropdown(
            choices=["Japanese", "Thai"],
            label="Language",
            value="Japanese"
        )

    # Tab 2: Prepare Dataset
    with gr.Tab("Prepare Dataset"):
        dataset_files = gr.Files(label="CSV Files")
        sample_btn = gr.Button("Get Datasets")
        sample_btn.click(sample_all, inputs=language_dropdown, outputs=dataset_files)

    # Tab 3: Train Model
    with gr.Tab("Train Model"):
        train_results = gr.TextArea(label="Metrics", interactive=False)
        models_files = gr.Files(label="Trained Model")
        train_btn = gr.Button("Train All")
        train_btn.click(train_model_async, inputs=language_dropdown, outputs=[train_results, models_files])

demo.launch(debug=True)