import gradio as gr
from typing import Optional, Tuple, Generator, List, Any
from config import AppConfig
from engine import FunctionGemmaEngine
# --- Controller / Logic Layer ---
class UIController:
    """
    Handles the business logic and interaction with the Engine.

    Every method is a stateless ``@staticmethod`` operating on the Engine
    instance handed in by the caller. Handlers that need an engine return
    (or yield) a warning instead of raising when no engine is available.
    """

    # Shared status text for handlers invoked without an engine.
    _NO_ENGINE_MSG = "⚠️ Engine not initialized."

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine and compute the initial UI values.

        Args:
            profile: Profile of the signed-in user, or ``None`` when nobody
                is signed in.

        Returns:
            An 8-tuple: the new engine, its tools JSON, the configured model
            name, a "Ready" status line containing the session id, the three
            updates from ``update_hub_interactive`` (repo input, push button,
            zip button), and the username (``None`` when not signed in).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None
        # Calculate initial interactivity state
        repo_update, push_update, zip_update = UIController.update_hub_interactive(
            new_engine, username
        )
        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream the engine's training pipeline for the chosen model.

        Yields ``(warning, None)`` once and stops when there is no engine;
        otherwise stores the stripped ``model_name`` on the engine config and
        re-yields everything ``engine.run_training_pipeline`` produces.
        """
        if not engine:
            yield UIController._NO_ENGINE_MSG, None
            return
        engine.config.MODEL_NAME = model_name.strip()
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool,
                       model_name: str) -> Generator:
        """Stream the engine's evaluation run for the chosen model.

        Yields a single warning string and stops when there is no engine;
        otherwise stores the stripped ``model_name`` on the engine config and
        re-yields everything ``engine.run_evaluation`` produces.
        """
        if not engine:
            yield UIController._NO_ENGINE_MSG
            return
        engine.config.MODEL_NAME = model_name.strip()
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Store the selected model name and return ``engine.refresh_model()``.

        Returns the "not initialized" warning when there is no engine.
        """
        if not engine:
            return UIController._NO_ENGINE_MSG
        engine.config.MODEL_NAME = model_name.strip()
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Forward the edited tool JSON to the engine and return its result.

        Returns the "not initialized" warning when there is no engine.
        """
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Forward the uploaded file to ``engine.load_csv`` and return its result.

        Returns the "not initialized" warning when there is no engine.
        """
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop via ``engine.trigger_stop()``.

        Always returns ``None``; does nothing when there is no engine.
        """
        if engine:
            engine.trigger_stop()
        return None

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Return an update for the download widget.

        Shows the path from ``engine.get_zip_path()`` when one is returned;
        otherwise (no path, or no engine) clears and hides the widget.
        """
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str,
                     oauth_token: Optional[gr.OAuthToken]) -> str:
        """Validate the inputs and upload the model through the engine.

        Args:
            engine: Engine to upload from.
            repo_name: Target repository name; surrounding whitespace is
                ignored, and a blank name is rejected.
            oauth_token: Token of the signed-in user, or ``None``.

        Returns:
            An error/warning message when validation fails, otherwise the
            result of ``engine.upload_model_to_hub``.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        # Strip before testing so a whitespace-only name counts as missing and
        # the uploaded name matches what update_repo_preview displays.
        clean_repo = repo_name.strip() if repo_name else ""
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Build the Markdown line previewing the target repository path.

        Returns a sign-in hint when ``username`` is empty. A missing or
        whitespace-only ``repo_name`` is shown as ``...``.
        """
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # Apply the placeholder after stripping, so "   " does not render as
        # an empty repository segment.
        clean_repo = repo_name.strip() if repo_name else ""
        return f"Target Repository: **`{username}/{clean_repo or '...'}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine],
                               username: Optional[str] = None):
        """Compute interactivity updates for the export controls.

        Returns:
            A 3-tuple of updates, in order: repo-name input (interactive when
            a username is present), push button (username present and the
            engine reports ``has_model_tuned``), zip button (engine reports
            ``has_model_tuned``).
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)
        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )
# --- View / Layout Layer ---
def _render_header():
    """Render the page header: title, intro text, sign-in note and login button."""
    with gr.Column():
        gr.Markdown("# 🤖 FunctionGemma Tuning Lab: Fine-Tuning")
        # NOTE(review): the line-break token inside the literals below was lost
        # in the source (leaving unterminated strings). `<br>` is used here as
        # the explicit Markdown/HTML line break - confirm against the intended
        # rendering.
        gr.Markdown(
            "Fine-tune FunctionGemma to understand your custom functions.<br>"
            "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details."
        )
        gr.Markdown(
            "(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>"
            "⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress)."
        )
        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # Empty wide column placed next to the login button in the same row.
            with gr.Column(scale=3):
                gr.Markdown("")
def _render_dataset_tab(engine_state):
    """Render the "1. Preparing Dataset" tab.

    Args:
        engine_state: Accepted for a uniform signature with the other tab
            renderers; not used while building this tab.

    Returns:
        A dict of the controls needed for event wiring, keyed by
        ``tools_editor``, ``update_tools_btn``, ``tools_status``,
        ``import_file`` and ``import_status``.
    """
    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### 🛠️ Tool Schema & Data Import")
        # NOTE(review): the line-break token inside several literals in this
        # function was lost in the source (leaving unterminated strings).
        # `<br>` is used as the explicit line break - confirm against the
        # intended rendering.
        gr.Markdown(
            "**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>"
            "The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:"
        )
        gr.Markdown("* `sum(int a, int b)`\n* `query(string q)`")
        gr.Markdown("Ensure that all tools within this specific schema definition share a consistent parameter format.")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    "**Step 1: Define Functions**<br>"
                    "Edit the JSON schema below to define the tools the model should learn."
                )
                tools_editor = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                update_tools_btn = gr.Button("💾 Update Tool Schema")
                tools_status = gr.Markdown("")
            with gr.Column(scale=1):
                gr.Markdown(
                    "**Step 2: Upload Data (Optional)**<br>"
                    "To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling)."
                )
                gr.Markdown(
                    "**Example CSV Row:** No header required.<br>"
                    "Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```"
                )
                import_file = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                import_status = gr.Markdown("")
    # Return controls needed for wiring
    return {
        "tools_editor": tools_editor,
        "update_tools_btn": update_tools_btn,
        "tools_status": tools_status,
        "import_file": import_file,
        "import_status": import_status,
    }
def _render_training_tab(engine_state):
    """Lay out the "2. Training & Eval" tab and hand back its controls.

    ``engine_state`` is accepted for a uniform renderer signature and is not
    used while building the tab.

    The returned dict groups the widgets for event wiring:
    ``params`` (training inputs), ``eval_params`` (evaluation inputs),
    ``buttons`` (run, stop, reload, eval - in that order), ``outputs``
    (log box, metrics plot) and ``model_input`` (the base-model dropdown).
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")

        with gr.Group():
            gr.Markdown("**Hyperparameters**")

            # Row 1: base model and epoch count.
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=AppConfig().AVAILABLE_MODELS,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                epochs_slider = gr.Slider(
                    1, 20, value=5, step=1, label="Epochs", info="Total training passes"
                )

            # Row 2: optimizer and data-split settings.
            with gr.Row():
                lr_input = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                split_slider = gr.Slider(
                    0.1, 0.9, value=0.2, step=0.05,
                    label="Test Split", info="Validation ratio (0.2 = 20%)",
                )
                shuffle_box = gr.Checkbox(
                    value=True, label="Shuffle Data", info="Randomize before split"
                )

        # Action buttons; Stop is created hidden.
        with gr.Row():
            eval_button = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            stop_button = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            train_button = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            reload_button = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)

        # Log text and metrics plot in one row.
        with gr.Row():
            log_box = gr.Textbox(
                lines=20, label="Logs", value="Initializing...",
                interactive=False, autoscroll=True,
            )
            metrics_plot = gr.Plot(label="Training Metrics")

    training_inputs = [epochs_slider, lr_input, split_slider, shuffle_box, model_dropdown]
    evaluation_inputs = [split_slider, shuffle_box, model_dropdown]
    return {
        "params": training_inputs,
        "eval_params": evaluation_inputs,
        "buttons": [train_button, stop_button, reload_button, eval_button],
        "outputs": [log_box, metrics_plot],
        "model_input": model_dropdown,  # specifically needed for initialization
    }
def _render_export_tab(engine_state, username_state):
    """Lay out the "3. Export" tab and hand back its controls.

    Neither ``engine_state`` nor ``username_state`` is used while building
    the tab; they are accepted for a uniform renderer signature.

    The returned dict has two entries: ``zip_controls`` (prepare-ZIP button,
    download file widget) and ``hub_controls`` (repo-name textbox, push
    button, repo-path preview, upload status).
    """
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")

        with gr.Row():
            # Left column: local ZIP download.
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                prepare_zip_button = gr.Button(
                    "⬇️ Prepare Model ZIP", variant="secondary", interactive=False
                )
                archive_file = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")

            # Right column: publish to the Hugging Face Hub.
            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                repo_textbox = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                hub_push_button = gr.Button(
                    "Save to Hugging Face Hub", variant="secondary", interactive=False
                )
                repo_preview_md = gr.Markdown("Target Repository: (Waiting for input...)")
                upload_status_md = gr.Markdown("")

    zip_widgets = [prepare_zip_button, archive_file]
    hub_widgets = [repo_textbox, hub_push_button, repo_preview_md, upload_status_md]
    return {"zip_controls": zip_widgets, "hub_controls": hub_widgets}
# --- Main Build Function ---
def build_interface() -> gr.Blocks:
    """Build the complete app: layout plus all event wiring.

    Creates the per-session state holders (engine, username), renders the
    header and the three tabs, then connects every control to the matching
    ``UIController`` handler.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        engine_state = gr.State()
        username_state = gr.State()
        _render_header()
        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # Helpers for UI State
        # 'action_buttons' now ONLY contains buttons that should always be enabled after a process
        # Zip and Push buttons are excluded here because their state depends on has_model_tuned
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]
        repo_input = export_ui["hub_controls"][0]
        push_btn = export_ui["hub_controls"][1]
        zip_btn = export_ui["zip_controls"][0]

        def lock_ui():
            """Locks all buttons (including Zip/Push) during processing"""
            return [gr.update(interactive=False) for _ in action_buttons] + \
                   [gr.update(interactive=False), gr.update(interactive=False)]

        def unlock_ui():
            """Unlocks general action buttons only. Zip/Push are handled by update_hub_interactive"""
            return [gr.update(interactive=True) for _ in action_buttons]

        # --- Event Wiring ---
        # 1. Initialization
        # init_session declares no explicit inputs here; its output order must
        # match the 8-tuple it returns.
        demo.load(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                train_ui["outputs"][0],  # log output
                repo_input,
                push_btn,
                zip_btn,  # Update Zip state based on initial engine state
                username_state
            ]
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]]
        ).then(unlock_ui, outputs=action_buttons)

        # 2. Data Tab
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]]
        )
        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]]
        )

        # 3. Training & Eval Tab
        # 3a. Training
        # Pre-step: hide Run, show Stop, and lock everything that should not
        # be used mid-run. Push is locked together with Zip so a previously
        # tuned model cannot be exported while a new run is in progress; the
        # final update_hub_interactive step below restores both.
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # Hide Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(interactive=False),  # Lock Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True)        # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn]
        )
        train_run_event = train_run_event.then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=train_ui["outputs"],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False)
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        ).then(
            # Final check determines if Zip/Push should unlock
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

        # 3b. Evaluation
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # Lock Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(visible=False),      # Hide self (optional, or lock)
                gr.update(visible=True)        # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        )
        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[train_ui["outputs"][0]]  # Output only to log, not plot
        ).then(
            fn=lambda: (
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=True),
                gr.update(visible=False)
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        )

        # Stop calls engine.trigger_stop() through the controller.
        # NOTE(review): both *_run_event names were rebound above to the LAST
        # step of their chains, so `cancels` targets those final steps rather
        # than the training/evaluation step itself - confirm this is intended.
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False
        )

        reload_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[train_ui["outputs"][0]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

        # 4. Export Tab
        zip_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.zip_model,
            inputs=[engine_state],
            outputs=[export_ui["zip_controls"][1]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )
        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]]
        )
        # upload_model also takes an oauth_token parameter that is not listed
        # in `inputs` here.
        push_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[export_ui["hub_controls"][3]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )
    return demo