"""Gradio UI for the FunctionGemma Tuning Lab: controller handlers, tab layout and event wiring."""
import gradio as gr
from typing import Optional, Tuple, Generator, List, Any

from config import AppConfig
from engine import FunctionGemmaEngine


# --- Controller / Logic Layer ---

class UIController:
    """
    Handles the business logic and interaction with the Engine.

    Stateless methods that operate on the passed Engine state. Handlers that
    touch the engine tolerate it being missing (e.g. an event fired before
    ``init_session`` finished) and report that instead of raising.
    """

    # Status text returned/yielded when a handler runs without a session engine.
    _NOT_INITIALIZED = "⚠️ Engine not initialized."

    @staticmethod
    def _apply_model_name(engine: FunctionGemmaEngine, model_name: Optional[str]) -> None:
        """Copy the stripped model name from the UI into ``engine.config.MODEL_NAME``.

        A cleared dropdown (``None``) or blank text keeps the engine's current
        model name instead of crashing on ``None.strip()`` or storing ``""``.
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for a new browser session.

        Args:
            profile: Hugging Face OAuth profile supplied by Gradio when the user
                is signed in; ``None`` otherwise.

        Returns:
            In order: the engine, its tool JSON, its model name, a status line,
            the Repo/Push/Zip interactivity updates, and the username (or None).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None

        # Calculate initial interactivity state
        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)

        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float, test_size: float,
                     shuffle: bool, model_name: str) -> Generator:
        """Stream the engine's training pipeline to the (Logs, Training Metrics) outputs."""
        if not engine:
            yield UIController._NOT_INITIALIZED, None
            return
        UIController._apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool,
                       model_name: str) -> Generator:
        """Stream the engine's evaluation output to the Logs textbox only."""
        if not engine:
            yield UIController._NOT_INITIALIZED
            return
        UIController._apply_model_name(engine, model_name)
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Apply the selected model name and reload the model via ``engine.refresh_model()``."""
        if not engine:
            return UIController._NOT_INITIALIZED
        UIController._apply_model_name(engine, model_name)
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Forward the edited tool-schema JSON to the engine; return its status text."""
        if not engine:
            return UIController._NOT_INITIALIZED
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Forward an uploaded CSV to ``engine.load_csv``; return its status text."""
        if not engine:
            return UIController._NOT_INITIALIZED
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop the running job (no-op without an engine)."""
        if engine:
            engine.trigger_stop()

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Return a File update showing the model archive, or hiding it if none was produced."""
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str,
                     oauth_token: Optional[gr.OAuthToken]) -> str:
        """Push the tuned model to ``<user>/<repo_name>`` on the Hugging Face Hub.

        ``oauth_token`` is supplied by Gradio when the user is signed in. The
        repository name is stripped so the uploaded name matches the preview
        shown by ``update_repo_preview``.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController._NOT_INITIALIZED

        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Render the ``username/repo`` target shown under the Hub upload button."""
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # Blank or whitespace-only names show a placeholder rather than "user/".
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine],
                               username: Optional[str] = None) -> Tuple[Any, Any, Any]:
        """Compute interactivity for (repo name input, Push button, Zip button).

        The repo input needs a login, Push needs a login and a tuned model, and
        Zip needs only a tuned model (read from ``engine.has_model_tuned``).
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)

        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )


# --- View / Layout Layer ---

def _render_header():
    """Render the page title, intro text and the Hugging Face sign-in row."""
    with gr.Column():
        gr.Markdown("# 🤖 FunctionGemma Tuning Lab: Fine-Tuning")
        gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>"
                    "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details.")
        gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress).")
        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # Empty spacer column keeps the login button narrow.
            with gr.Column(scale=3):
                gr.Markdown("")


def _render_dataset_tab(engine_state):
    """Render tab 1 (tool-schema editor and CSV import) and return its controls by name."""
    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### 🛠️ Tool Schema & Data Import")
        gr.Markdown("**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:")
        gr.Markdown("* `sum(int a, int b)`\n* `query(string q)`")
        gr.Markdown("Ensure that all tools within this specific schema definition share a consistent parameter format.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                tools_editor = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                update_tools_btn = gr.Button("💾 Update Tool Schema")
                tools_status = gr.Markdown("")

            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                import_file = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                import_status = gr.Markdown("")

    # Return controls needed for wiring
    return {
        "tools_editor": tools_editor,
        "update_tools_btn": update_tools_btn,
        "tools_status": tools_status,
        "import_file": import_file,
        "import_status": import_status,
    }


def _render_training_tab(engine_state):
    """Render tab 2 (hyperparameters, run/stop/eval/reload buttons, logs and plot).

    Returns a dict whose "params" order matches ``UIController.run_training``
    and whose "eval_params" order matches ``UIController.run_evaluation``.
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")
        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                default_models = AppConfig().AVAILABLE_MODELS
                param_model = gr.Dropdown(
                    choices=default_models,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                param_epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                param_lr = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                param_test_size = gr.Slider(0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)")
                param_shuffle = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")

        with gr.Row():
            run_eval_btn = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            # Hidden until a job starts; swapped in place of the Run/Eval buttons.
            stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)

        with gr.Row():
            output_display = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            loss_plot = gr.Plot(label="Training Metrics")

    return {
        "params": [param_epochs, param_lr, param_test_size, param_shuffle, param_model],
        "eval_params": [param_test_size, param_shuffle, param_model],
        "buttons": [run_training_btn, stop_training_btn, clear_reload_btn, run_eval_btn],
        "outputs": [output_display, loss_plot],
        "model_input": param_model,  # specifically needed for initialization
    }


def _render_export_tab(engine_state, username_state):
    """Render tab 3 (ZIP download and Hub upload); all controls start non-interactive."""
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                download_file = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")

            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                repo_name_input = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
                upload_status = gr.Markdown("")

    return {
        "zip_controls": [zip_btn, download_file],
        "hub_controls": [repo_name_input, push_to_hub_btn, repo_id_preview, upload_status],
    }


# --- Main Build Function ---

def build_interface() -> gr.Blocks:
    """Build the Blocks app: lay out the three tabs and wire every event.

    Each session holds its own engine in ``engine_state``. While any job
    (training, evaluation, reload, zip, upload) runs, every button that could
    start another job, including Zip and Push, is locked, so jobs cannot overlap.
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        _render_header()

        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # Helpers for UI State
        # 'action_buttons' ONLY contains buttons that should always be enabled after a process.
        # Zip and Push buttons are excluded here because their state depends on has_model_tuned.
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]
        repo_input = export_ui["hub_controls"][0]
        push_btn = export_ui["hub_controls"][1]
        zip_btn = export_ui["zip_controls"][0]

        # Output order must match lock_ui(): action buttons first, then Push, then Zip.
        lock_targets = action_buttons + [push_btn, zip_btn]

        # Shared final step: re-derive Repo/Push/Zip interactivity from login + has_model_tuned.
        refresh_hub = dict(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn],
        )

        def lock_ui():
            """Locks all buttons (including Zip/Push) during processing"""
            return [gr.update(interactive=False) for _ in action_buttons] + [
                gr.update(interactive=False),
                gr.update(interactive=False),
            ]

        def unlock_ui():
            """Unlocks general action buttons only. Zip/Push are handled by update_hub_interactive"""
            return [gr.update(interactive=True) for _ in action_buttons]

        # --- Event Wiring ---

        # 1. Initialization
        demo.load(lock_ui, outputs=lock_targets).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                train_ui["outputs"][0],  # log output
                repo_input,
                push_btn,
                zip_btn,  # Update Zip state based on initial engine state
                username_state,
            ],
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]],
        ).then(unlock_ui, outputs=action_buttons)

        # 2. Data Tab
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]],
        )

        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]],
        )

        # 3. Training & Eval Tab
        # 3a. Training
        # Push is locked too: otherwise clicking it mid-training would run
        # lock_ui/unlock_ui and re-enable Reload/Eval while training is running.
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # Hide Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(interactive=False),  # Lock Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )

        train_run_event = train_run_event.then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=train_ui["outputs"],
        ).then(
            fn=lambda: (
                gr.update(visible=True),       # Show Run
                gr.update(interactive=True),   # Unlock Reload
                gr.update(interactive=True),   # Unlock Eval
                gr.update(visible=False),      # Hide Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        ).then(**refresh_hub)  # Final check determines if Zip/Push should unlock

        # 3b. Evaluation
        # Zip/Push are locked for the same reason as during training.
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # Lock Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(visible=False),      # Hide self (optional, or lock)
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )

        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[train_ui["outputs"][0]],  # Output only to log, not plot
        ).then(
            fn=lambda: (
                gr.update(interactive=True),   # Unlock Run
                gr.update(interactive=True),   # Unlock Reload
                gr.update(visible=True),       # Show Eval
                gr.update(visible=False),      # Hide Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        )
        # Restore Zip/Push from the engine state once evaluation is done
        # (chained separately so eval_run_event still names the same step).
        eval_run_event.then(**refresh_hub)

        # NOTE(review): train_run_event / eval_run_event hold the result of the
        # final .then() of each chain, not the generator step itself, so stopping
        # a running job appears to rely on engine.trigger_stop() — confirm intended.
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False,
        )

        reload_btn.click(lock_ui, outputs=lock_targets).then(
            fn=UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[train_ui["outputs"][0]],
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

        # 4. Export Tab
        zip_btn.click(lock_ui, outputs=lock_targets).then(
            fn=UIController.zip_model,
            inputs=[engine_state],
            outputs=[export_ui["zip_controls"][1]],
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]],
        )

        # upload_model's OAuth token is supplied by Gradio, so only engine and repo name are inputs.
        push_btn.click(lock_ui, outputs=lock_targets).then(
            fn=UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[export_ui["hub_controls"][3]],
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

    return demo