# VRM-Emotions / app.py
import pandas as pd
from sklearn.utils import shuffle
from tqdm.asyncio import tqdm_asyncio
from googletrans import Translator
from pathlib import Path
import asyncio
import random
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr
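# NOTE (environment assumptions): besides pandas, scikit-learn, tqdm, torch, transformers,
# and gradio, the `async with Translator()` usage below implies an async-capable googletrans
# build (the 4.x line), and reading the parquet file through an "hf://" URL requires
# huggingface_hub (plus pyarrow or fastparquet) to be installed.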
# --- Load dataset
data = pd.read_parquet("hf://datasets/boltuix/emotions-dataset/emotions_dataset.parquet")
groups = {
"neutral": "neutral",
"anger": "angry",
"love": "joy",
"happiness": "fun",
"sadness": "sorrow",
"surprise": "surprised",
"fear": "fear",
"disgust": "disgust"
}
data = data[data['Label'].isin(groups.keys())].copy()
data['Label'] = data['Label'].map(groups)
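# After filtering and remapping, the Label column holds the eight emotion classes used by the app:
# neutral, angry, joy, fun, sorrow, surprised, fear, disgust.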
seeds = [1, 2, 3, 4]
# --- Translate function (async batch with bounded concurrency + progress)
async def translate_all(seed, texts, language):
# mapping language → Google Translate target code
target_lang = "ja" if language == "Japanese" else "th"
    # Limit concurrent requests to avoid timeouts
semaphore = asyncio.Semaphore(8)
async def sem_translate_task(text, idx):
async with semaphore:
            for attempt in range(5):  # retry up to 5 times
try:
async with Translator() as translator:
result = await translator.translate(
text,
src="en",
dest=target_lang
)
return result.text, idx
                except Exception:
                    # Back off with a gradually increasing delay before retrying
                    await asyncio.sleep(1 + attempt * 0.5 + random.random() * 0.3)
            # If all 5 attempts fail, return the original text so the batch does not crash
            return text, idx
    # Create one translation task per sentence
tasks = [asyncio.create_task(sem_translate_task(text, idx))
for idx, text in enumerate(texts)]
translated = [None] * len(texts)
    # Async tqdm drives the progress bar as tasks complete
for coro in tqdm_asyncio.as_completed(tasks, total=len(tasks)):
result, index = await coro
translated[index] = result
return translated
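# Illustrative usage (hypothetical values; must be awaited inside an event loop):
#   translated = await translate_all(seed=1, texts=["I am happy"], language="Japanese")
#   # -> a list the same length as `texts`, with items that failed all retries returned unchanged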
# --- Sample Dataset
async def sample_all(language, progress=gr.Progress(track_tqdm=True)):
files = []
for seed in seeds:
try:
filename = f"./data/{language}_SampledData_{seed}.csv"
Path("./data").mkdir(parents=True, exist_ok=True)
if not os.path.exists(filename):
sampled = (
data.groupby('Label', group_keys=False)
.apply(lambda x: x.sample(n=1000, random_state=int(seed)))
)
sampled = shuffle(sampled).reset_index(drop=True)
texts = sampled["Sentence"].tolist()
translated = await translate_all(seed, texts, language)
sampled["Sentence"] = translated
sampled.to_csv(filename, index=False)
files.append(filename)
else:
files.append(filename)
        except Exception as e:
            raise gr.Error(str(e))
return files
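# Illustrative usage (hypothetical; Gradio awaits the coroutine itself when it is wired to a button):
#   files = await sample_all("Thai")  # -> ["./data/Thai_SampledData_1.csv", ...]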
# --- Dataset class
class Dataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels=None):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
if self.labels is not None:
item["labels"] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.encodings["input_ids"])
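# Sketch of how the Dataset wrapper is consumed (hypothetical values):
#   enc = tokenizer(["What a great day"], padding=True, truncation=True, max_length=512)
#   ds = Dataset(enc, labels=[2])
#   ds[0]  # -> {"input_ids": tensor([...]), "attention_mask": tensor([...]), "labels": tensor(2)}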
# --- Prepare dataset for Trainer
def prepare_dataset(df, tokenizer):
label_list = sorted(df["Label"].unique())
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}
X = list(df["Sentence"])
y = [label2id[label] for label in df["Label"]]
X_train, X_val, y_train, y_val = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42
)
X_train_tokenized = tokenizer(
X_train, padding=True, truncation=True, max_length=512
)
X_val_tokenized = tokenizer(
X_val, padding=True, truncation=True, max_length=512
)
train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
return train_dataset, val_dataset, label2id, id2label
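# prepare_dataset returns an 80/20 stratified split plus the label2id/id2label maps that are
# later passed to AutoModelForSequenceClassification, so the saved model keeps the
# human-readable emotion names in its config.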
# --- Compute metrics
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = logits.argmax(axis=-1)
accuracy = accuracy_score(labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(
labels, preds, average='weighted'
)
return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
# --- Train model per language and seed
def train_model(language):
    # Assumption: use the matching Geotrend Japanese checkpoint for Japanese and the Thai one for Thai
    model_name = "Geotrend/distilbert-base-ja-cased" if language == "Japanese" else "Geotrend/distilbert-base-th-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
metric_str_all = []
model_paths = []
for seed in seeds:
csv_path = f"./data/{language}_SampledData_{seed}.csv"
if not os.path.exists(csv_path):
return f"File {csv_path} not found! กรุณาเตรียม Dataset ก่อน.", None
df = pd.read_csv(csv_path)
train_dataset, val_dataset, label2id, id2label = prepare_dataset(df, tokenizer)
output_dir = f"./output/{language}/seed{seed}"
final_model_dir = os.path.join(output_dir, "final_model")
Path(final_model_dir).mkdir(parents=True, exist_ok=True)
        # Recent transformers saves model.safetensors by default, so accept either weights file
        if not (os.path.exists(os.path.join(final_model_dir, "pytorch_model.bin"))
                or os.path.exists(os.path.join(final_model_dir, "model.safetensors"))):
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
use_safetensors=True,
num_labels=len(label2id),
id2label=id2label,
label2id=label2id,
)
training_args = TrainingArguments(
output_dir=output_dir,
seed=seed,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
eval_strategy="epoch",
save_strategy="epoch",
num_train_epochs=5,
fp16=False,
logging_dir=f"./logs/{language}_seed{seed}",
logging_steps=100,
load_best_model_at_end=True
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
trainer.train()
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
model_paths.append(final_model_dir)
metrics = trainer.evaluate()
metric_str = f"Seed {seed} ({language}):\n" + "\n".join(
[f"{k}: {v:.4f}" for k, v in metrics.items()]
)
metric_str_all.append(metric_str)
else:
model_paths.append(final_model_dir)
    # Directory reserved for exported models for this language
    final_avg_dir = f"./VRM-Emotions/{language}"
    Path(final_avg_dir).mkdir(parents=True, exist_ok=True)
    # Two return values, matching the two Gradio outputs (metrics text, trained model paths)
    return "\n\n".join(metric_str_all), model_paths
# --- Async wrapper for training
async def train_model_async(language, progress=gr.Progress(track_tqdm=True)):
return await asyncio.to_thread(train_model, language)
# --- Gradio UI
with gr.Blocks() as demo:
# Tab 1: Options
with gr.Tab("Options"):
language_dropdown = gr.Dropdown(
choices=["Japanese", "Thai"],
label="Language",
value="Japanese"
)
# Tab 2: Prepare Dataset
with gr.Tab("Prepare Dataset"):
dataset_files = gr.Files(label="CSV Files")
sample_btn = gr.Button("Get Datasets")
sample_btn.click(sample_all, inputs=language_dropdown, outputs=dataset_files)
# Tab 3: Train Model
with gr.Tab("Train Model"):
train_results = gr.TextArea(label="Metrics", interactive=False)
models_files = gr.Files(label="Trained Model")
train_btn = gr.Button("Train All")
train_btn.click(train_model_async, inputs=language_dropdown, outputs=[train_results, models_files])
demo.launch(debug=True)