# TrainAI / app.py
# Author: szili2011
# Last commit: 0987ee1 "Update app.py" (verified)
# --- Standard Library Imports ---
import os
import time
import traceback
import tempfile
import json
import math
import collections
import collections.abc # For Gradio compatibility with newer Python versions
# --- UI Framework ---
import gradio as gr
# --- Data Handling & Numerical Ops ---
import pandas as pd
import numpy as np
# --- Core Machine Learning (Scikit-learn) ---
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.svm import SVC, SVR
from sklearn.metrics import accuracy_score, classification_report, mean_squared_error, r2_score
from sklearn.datasets import make_classification, make_regression
import joblib
# --- ONNX Support for Model Interoperability ---
import skl2onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType, StringTensorType
# --- Visualization ---
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for server environments
import matplotlib.pyplot as plt
# --- Optional ONNX Runtime ---
# onnxruntime is only used to sanity-check exported .onnx files. On some hosts
# (e.g. Hugging Face Spaces) importing it fails with a system-level ImportError,
# so the app records its absence and keeps running instead of crashing.
try:
    import onnxruntime as rt
except ImportError:
    ONNX_RUNTIME_AVAILABLE = False
    print("Warning: onnxruntime could not be imported. ONNX model validation will be skipped.")
else:
    ONNX_RUNTIME_AVAILABLE = True
# --- End of Imports ---
# --- Global Variables & Constants ---
# Output directory (relative to the working directory) for generated datasets
# and trained models. It is created at import time so later writes can assume it exists.
TEMP_DIR = "temp_outputs"
os.makedirs(TEMP_DIR, exist_ok=True)

# Upper bounds applied to the synthetic dataset size requested in the UI.
MAX_GENERATED_ROWS = 50_000
MAX_GENERATED_COLS = 100
# --- Helper Functions ---
def get_temp_filepath(filename_base, extension, directory=None):
    """Build a unique path for a temporary output file.

    The file name is ``{filename_base}_{YYYYmmdd-HHMMSS}_{token}.{ext}``. The
    random token guarantees distinct paths for calls made within the same
    second. Without it, a second dataset or model produced in quick succession
    would silently overwrite the first.

    Args:
        filename_base: Prefix of the file name.
        extension: File extension, with or without a leading dot.
        directory: Target directory; defaults to ``TEMP_DIR``.

    Returns:
        The path as a string. The file itself is not created.
    """
    import uuid  # local import keeps this helper self-contained

    if directory is None:
        directory = TEMP_DIR
    clean_extension = extension.lstrip('.')
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    token = uuid.uuid4().hex[:8]
    return os.path.join(directory, f"{filename_base}_{timestamp}_{token}.{clean_extension}")
# --- Dataset and Preprocessing Logic ---
def _classification_layout(n_features, n_classes, n_clusters_per_class=2):
    """Choose make_classification settings that pass its feasibility check.

    make_classification requires ``n_classes * n_clusters_per_class <=
    2 ** n_informative`` with ``n_informative <= n_features``. The default
    layout (half the features informative, 2 clusters per class) is kept
    whenever it is already valid, so previously generated datasets stay
    reproducible.

    Returns:
        (n_informative, n_clusters_per_class)

    Raises:
        ValueError: If no layout can fit ``n_classes`` into ``n_features``.
    """
    n_informative = max(1, n_features // 2)
    # First make more features informative (never more than actually exist)...
    while n_classes * n_clusters_per_class > 2 ** n_informative and n_informative < n_features:
        n_informative += 1
    # ...then fall back to a single cluster per class.
    if n_classes * n_clusters_per_class > 2 ** n_informative:
        n_clusters_per_class = 1
    if n_classes * n_clusters_per_class > 2 ** n_informative:
        raise ValueError(
            f"Cannot generate {n_classes} classes from {n_features} feature(s); "
            "increase the number of features or reduce the number of classes."
        )
    return n_informative, n_clusters_per_class


def generate_dataset_backend(task_type, n_samples, n_features, n_classes_or_informative, dataset_format):
    """Generate a synthetic tabular dataset and save it to a temporary file.

    Args:
        task_type: "Tabular Classification" or "Tabular Regression".
        n_samples: Requested rows; clamped to [10, MAX_GENERATED_ROWS].
        n_features: Requested feature columns; clamped to [1, MAX_GENERATED_COLS].
        n_classes_or_informative: Number of classes for classification (at
            least 2), or number of informative features for regression
            (clamped to [1, n_features]).
        dataset_format: ".csv", ".json" (JSON Lines) or ".parquet".

    Returns:
        (preview, df, logs, file_path): the first five rows, the full DataFrame
        (columns ``feature_0``... plus ``target``), the log text and the saved
        file's path. On any error the result is (None, None, logs, None), with
        the traceback appended to ``logs``.
    """
    logs = "\n--- Generating Dataset ---\n"
    try:
        # Coerce inside the try so empty/invalid UI values (e.g. int(None))
        # are reported in the log instead of crashing the event handler.
        n_samples = max(10, min(int(n_samples), MAX_GENERATED_ROWS))
        n_features = max(1, min(int(n_features), MAX_GENERATED_COLS))
        n_classes_or_informative = int(n_classes_or_informative)
        if dataset_format not in (".csv", ".json", ".parquet"):
            raise ValueError(f"Unsupported dataset format '{dataset_format}'.")

        if task_type == "Tabular Classification":
            n_classes = max(2, n_classes_or_informative)
            n_informative, n_clusters_per_class = _classification_layout(n_features, n_classes)
            X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_informative,
                                       n_redundant=0, n_classes=n_classes,
                                       n_clusters_per_class=n_clusters_per_class, random_state=42)
        elif task_type == "Tabular Regression":
            X, y = make_regression(n_samples=n_samples, n_features=n_features,
                                   n_informative=max(1, min(n_features, n_classes_or_informative)),
                                   noise=10, random_state=42)
        else:
            raise NotImplementedError(f"Dataset generation for '{task_type}' is not implemented.")

        df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
        df['target'] = y
        logs += f"Generated data with shape: {df.shape}\n"

        file_path = get_temp_filepath("generated_dataset", dataset_format)
        if dataset_format == ".csv":
            df.to_csv(file_path, index=False)
        elif dataset_format == ".json":
            df.to_json(file_path, orient='records', lines=True)
        else:  # ".parquet" -- the format was validated above
            df.to_parquet(file_path, index=False)
        logs += f"Dataset saved to temporary file: {os.path.basename(file_path)}\n"
        return df.head(), df, logs, file_path
    except Exception:
        error_msg = f"Error generating dataset: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return None, None, logs, None
# --- Core Training Functions ---
def _load_dataframe(data_input):
    """Return the training data as a DataFrame.

    ``data_input`` is either an in-memory DataFrame (e.g. the generator's
    ``gr.State`` value) or a file path. Paths are read by extension, matched
    case-insensitively: ``.csv``, ``.json`` (JSON Lines, the layout the
    generator writes) and ``.parquet``.

    Raises:
        ValueError: If a path has an unsupported extension.
    """
    if not isinstance(data_input, str):
        return data_input
    lowered = data_input.lower()
    if lowered.endswith('.csv'):
        return pd.read_csv(data_input)
    if lowered.endswith('.json'):
        return pd.read_json(data_input, lines=True)
    if lowered.endswith('.parquet'):
        return pd.read_parquet(data_input)
    raise ValueError("Unsupported file type for upload.")


def _build_estimator(task_type, model_name):
    """Create a fresh, unfitted estimator for ``model_name``.

    Any ``task_type`` other than "Tabular Classification" is treated as
    regression. Only the requested estimator is instantiated.

    Raises:
        ValueError: If ``model_name`` is not offered for ``task_type``.
    """
    if task_type == "Tabular Classification":
        factories = {
            "Logistic Regression": lambda: LogisticRegression(max_iter=1000, random_state=42),
            "Random Forest Classifier": lambda: RandomForestClassifier(random_state=42),
            # probability=True enables predict_proba on the SVC.
            "Support Vector Machine (SVM) Classifier": lambda: SVC(random_state=42, probability=True),
        }
    else:  # Regression
        factories = {
            "Linear Regression": lambda: LinearRegression(),
            "Random Forest Regressor": lambda: RandomForestRegressor(random_state=42),
            "Support Vector Machine (SVR) Regressor": lambda: SVR(),
        }
    try:
        factory = factories[model_name]
    except KeyError:
        raise ValueError(
            f"Model '{model_name}' is not available for '{task_type}'. "
            f"Choose one of: {', '.join(factories)}."
        ) from None
    return factory()


def _export_onnx(pipeline, X, task_type, model_path):
    """Convert the fitted ``pipeline`` to ONNX and write it to ``model_path``.

    One ``[None, 1]`` graph input is declared per column of ``X``, named after
    the column. Inputs are float for numeric dtypes and string otherwise.
    """
    initial_types = [
        (col_name,
         FloatTensorType([None, 1]) if pd.api.types.is_numeric_dtype(X[col_name].dtype)
         else StringTensorType([None, 1]))
        for col_name in X.columns
    ]
    options = {}
    if task_type == "Tabular Classification":
        # skl2onnx looks converter options up by estimator id (or type). Keying
        # the option on the final classifier makes its probabilities a plain
        # tensor rather than a ZipMap.
        options = {id(pipeline.steps[-1][1]): {'zipmap': False}}
    onnx_model = convert_sklearn(pipeline, initial_types=initial_types, target_opset=12, options=options)
    with open(model_path, "wb") as f:
        f.write(onnx_model.SerializeToString())


def train_model_sklearn(data_input, target_column, task_type, model_name, model_output_format, logs=""):
    """Train, evaluate and optionally export a Scikit-learn pipeline.

    The pipeline has two stages:
    - A ColumnTransformer. Numeric columns are mean-imputed and standard-scaled;
      object columns are most-frequent-imputed and one-hot encoded.
    - The selected estimator.

    Data is split 80/20 with a fixed seed.

    Args:
        data_input: DataFrame, or path to a .csv/.json/.parquet file.
        target_column: Name of the label column.
        task_type: "Tabular Classification"; anything else is treated as regression.
        model_name: Display name of the estimator, as listed in the UI.
        model_output_format: ".pkl (Scikit-learn)" or ".onnx (ONNX)". Any other
            value skips saving.
        logs: Existing log text to append to.

    Returns:
        (logs, metrics_text, model_path). Every exception is caught. On failure,
        metrics_text holds the traceback message and model_path is None.
    """
    logs += f"\n--- Training Scikit-learn Model: {model_name} ---\n"
    try:
        df = _load_dataframe(data_input)
        if target_column not in df.columns:
            raise ValueError(f"Target column '{target_column}' not found.")

        # Preprocessing
        X = df.drop(columns=[target_column])
        y = df[target_column]
        numeric_features = X.select_dtypes(include=np.number).columns
        categorical_features = X.select_dtypes(include='object').columns
        # ColumnTransformer's default remainder='drop' discards every column
        # not listed above (e.g. bool or datetime). Report them instead of
        # dropping them silently.
        handled = set(numeric_features) | set(categorical_features)
        ignored = [str(col) for col in X.columns if col not in handled]
        if ignored:
            logs += f"Warning: ignoring column(s) with unsupported dtype: {', '.join(ignored)}\n"
        preprocessor = ColumnTransformer(transformers=[
            ('num', Pipeline([('imputer', SimpleImputer(strategy='mean')), ('scaler', StandardScaler())]), numeric_features),
            ('cat', Pipeline([('imputer', SimpleImputer(strategy='most_frequent')), ('onehot', OneHotEncoder(handle_unknown='ignore'))]), categorical_features)
        ])

        # Model selection; classification labels are integer-encoded first.
        if task_type == "Tabular Classification":
            y = LabelEncoder().fit_transform(y)
        model = _build_estimator(task_type, model_name)
        pipeline = Pipeline([('preprocessor', preprocessor), ('model', model)])
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        logs += f"Data split into training ({X_train.shape}) and testing ({X_test.shape}) sets.\n"

        # Training
        start_time = time.time()
        pipeline.fit(X_train, y_train)
        logs += f"Training completed in {time.time() - start_time:.2f}s.\n"

        # Evaluation
        y_pred = pipeline.predict(X_test)
        if task_type == "Tabular Classification":
            acc = accuracy_score(y_test, y_pred)
            report = classification_report(y_test, y_pred, zero_division=0)
            metrics = f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
        else:
            mse = mean_squared_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
            metrics = f"Mean Squared Error: {mse:.4f}\nR² Score: {r2:.4f}"
        logs += "\n--- Evaluation Metrics ---\n" + metrics + "\n"

        # Model saving
        model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"
        model_path = None
        if model_output_format == ".pkl (Scikit-learn)":
            model_path = get_temp_filepath(model_filename_base, "pkl")
            joblib.dump(pipeline, model_path)
            logs += f"Model pipeline saved to {os.path.basename(model_path)} as PKL.\n"
        elif model_output_format == ".onnx (ONNX)":
            model_path = get_temp_filepath(model_filename_base, "onnx")
            _export_onnx(pipeline, X, task_type, model_path)
            logs += f"Model pipeline saved to {os.path.basename(model_path)} as ONNX.\n"
            if ONNX_RUNTIME_AVAILABLE:
                # Loading the file is the validation step. onnxruntime >= 1.9
                # requires an explicit provider list on builds with several
                # execution providers; the CPU provider is always present.
                rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
                logs += "ONNX model successfully loaded and validated with onnxruntime.\n"
            else:
                logs += "ONNX model validation skipped because onnxruntime is not available in this environment.\n"
        return logs, metrics, model_path
    except Exception:
        error_msg = f"Scikit-learn training failed: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return logs, error_msg, None
# --- Main Training Dispatcher ---
def train_model_wrapper(data_input, target_column, task_type, model_family, model_specific,
                        model_output_format, logs):
    """Route a training request to the backend for the selected model family.

    Returns a 4-tuple (logs, metrics_text, model_path, plot) matching the Gradio
    outputs. The plot slot is always None because only the Scikit-learn family
    is wired up (PyTorch is a future placeholder).
    """
    # Guard: nothing has been generated yet (or the last generation failed).
    if data_input is None:
        logs += "ERROR: No dataset has been generated or uploaded. Please go to Tab 2.\n"
        return logs, "Error: No dataset available.", None, None

    # Guard: reject any family other than Scikit-learn.
    if not model_family == "Scikit-learn (Classical ML)":
        logs += f"The selected model family '{model_family}' is not supported yet.\n"
        return logs, "Error: Model family not supported.", None, None

    updated_logs, metrics_text, artifact_path = train_model_sklearn(
        data_input, target_column, task_type, model_specific, model_output_format, logs
    )
    return updated_logs, metrics_text, artifact_path, None
# --- Gradio UI Definition ---
def update_model_options(task_choice, model_family_choice):
    """Build a Gradio update listing the models offered for a task/family pair.

    The first model becomes the selected value. The dropdown is hidden when
    no model matches the combination.
    """
    sklearn_catalogue = (
        ("Tabular Classification",
         ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]),
        ("Tabular Regression",
         ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]),
    )
    available = []
    if model_family_choice == "Scikit-learn (Classical ML)":
        available = next((names for task, names in sklearn_catalogue if task_choice == task), [])
    default_model = available[0] if len(available) > 0 else None
    return gr.update(choices=available, value=default_model, visible=len(available) > 0)
def update_model_output_formats(model_family_choice):
    """Build a Gradio update with the export formats offered for a model family.

    Unknown families get an empty choice list and no selected value.
    """
    if model_family_choice == "Scikit-learn (Classical ML)":
        sklearn_formats = [".pkl (Scikit-learn)", ".onnx (ONNX)"]
        return gr.update(choices=sklearn_formats, value=sklearn_formats[0])
    return gr.update(choices=[], value=None)
# The Gradio App Layout
# Components appear on screen in the order they are created inside each
# context manager, so the statements below are order-sensitive.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange")) as demo:
    gr.Markdown("# 🧠 TrainAI ⚙️")
    gr.Markdown("A simple interface to create, train, and download machine learning models.")
    # State variables to hold data between interactions.
    # generated_data_state receives the DataFrame returned by
    # generate_dataset_backend (None after a failed generation) and is the
    # data input of train_model_wrapper.
    generated_data_state = gr.State(None)
    with gr.Tabs():
        with gr.TabItem("1. Define Task & Model"):
            with gr.Row():
                task_type_dd = gr.Dropdown(["Tabular Classification", "Tabular Regression"], label="Select Task Type", value="Tabular Classification")
                model_family_dd = gr.Dropdown(["Scikit-learn (Classical ML)"], label="Select Model Family", value="Scikit-learn (Classical ML)")
                # Initial choices match what update_model_options returns for the default task.
                model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"], value="Logistic Regression", interactive=True)
        with gr.TabItem("2. Configure Dataset"):
            with gr.Row():
                ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
                ds_gen_features_num = gr.Number(label="# Features", value=10, minimum=1, step=1)
                # Interpreted as the class count for classification and as the
                # informative-feature count for regression.
                ds_gen_classes_num = gr.Number(label="Classes (Classif) / Informative (Regr)", value=2, minimum=1, step=1)
            ds_gen_format_dd = gr.Dropdown([".csv", ".json", ".parquet"], label="Generated Dataset Format", value=".csv")
            generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")
            # Defaults to "target", the label column name written by generate_dataset_backend.
            target_column_name_txt = gr.Textbox(label="Target Column Name", value="target", interactive=True)
            # row_count=5 matches the 5-row df.head() preview (it replaced an earlier 'height' argument).
            dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, row_count=5)
            generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
        with gr.TabItem("3. Train Model & Get Results"):
            model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=[".pkl (Scikit-learn)", ".onnx (ONNX)"], value=".pkl (Scikit-learn)")
            train_model_btn = gr.Button("🚀 Train Model", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Training Progress & Results")
            training_log_txt = gr.Textbox(label="Training Log & Status", lines=15, interactive=False, max_lines=50)
            evaluation_metrics_txt = gr.Textbox(label="Evaluation Metrics", lines=7, interactive=False)
            download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
            loss_plot_img = gr.Plot(label="Training Loss Curve (PyTorch only)", visible=False) # Hide as PyTorch is not used

    # --- Event Handlers ---
    # Update model choices when task or family changes
    task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
    model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
    # Update output formats when family changes
    model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
    # Dataset generation button.
    # generate_dataset_backend starts a fresh log string, so this replaces
    # whatever the training log box previously showed.
    generate_dataset_btn.click(
        fn=generate_dataset_backend,
        inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_num, ds_gen_format_dd],
        outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file]
    )
    # Main training button.
    # training_log_txt is both an input and an output, so the training log
    # is appended to the existing text rather than replacing it.
    train_model_btn.click(
        fn=train_model_wrapper,
        inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd, model_output_format_dd, training_log_txt],
        outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file, loss_plot_img]
    )
# Launch the application only when this file is executed as a script, e.g.
# `python app.py`. Importing the module (for tests or reuse) then builds
# `demo` without starting a blocking web server.
if __name__ == "__main__":
    demo.queue().launch(debug=True, show_error=True)