Spaces:
Running
Running
| import gradio as gr | |
| from typing import Optional, Tuple, Generator, List, Any | |
| from config import AppConfig | |
| from engine import FunctionGemmaEngine | |
# --- Controller / Logic Layer ---
class UIController:
    """
    Handles the business logic and interaction with the Engine.

    Every method is a ``@staticmethod``: the controller keeps no state of its
    own and only operates on the per-session ``FunctionGemmaEngine`` passed in
    by the caller. Handlers check for a missing engine (e.g. an event fired
    before ``init_session`` has populated the session state) and report it in
    the UI instead of raising ``AttributeError``.
    """

    # Shown whenever a handler is invoked without an engine.
    NOT_INITIALIZED_MSG = "⚠️ Engine not initialized."
    # Shown when the "Base Model" value is empty, whitespace-only or None.
    MISSING_MODEL_MSG = "⚠️ Please select or enter a base model."

    @staticmethod
    def _normalize_model_name(model_name: Optional[str]) -> str:
        """Return the model ID with surrounding whitespace removed ("" for None)."""
        return (model_name or "").strip()

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for a new session and compute the initial UI state.

        Args:
            profile: Hugging Face OAuth profile of the signed-in user, or ``None``
                when nobody is signed in.

        Returns:
            ``(engine, tools_json, model_name, status_message,
            repo_input_update, push_button_update, zip_button_update, username)``.
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None
        # Calculate initial interactivity state (depends on login and tuned-model state).
        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)
        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Fine-tune the selected base model, streaming progress.

        Sets ``engine.config.MODEL_NAME`` from ``model_name`` and forwards every
        item yielded by ``engine.run_training_pipeline``. Error cases yield a
        single ``(message, None)`` pair (presumably (log text, plot) like the
        pipeline's own items — TODO confirm).
        """
        if not engine:
            yield UIController.NOT_INITIALIZED_MSG, None
            return
        name = UIController._normalize_model_name(model_name)
        if not name:
            # Writing "" into the config would only fail later with an obscure load error.
            yield UIController.MISSING_MODEL_MSG, None
            return
        engine.config.MODEL_NAME = name
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Evaluate the selected model, forwarding the engine's yielded log updates.

        Error cases yield a single message string.
        """
        if not engine:
            yield UIController.NOT_INITIALIZED_MSG
            return
        name = UIController._normalize_model_name(model_name)
        if not name:
            yield UIController.MISSING_MODEL_MSG
            return
        engine.config.MODEL_NAME = name
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Switch to ``model_name`` and return the status text of ``engine.refresh_model()``."""
        if not engine:
            return UIController.NOT_INITIALIZED_MSG
        name = UIController._normalize_model_name(model_name)
        if not name:
            return UIController.MISSING_MODEL_MSG
        engine.config.MODEL_NAME = name
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Pass the edited tool schema JSON to the engine and return its status text."""
        if not engine:
            return UIController.NOT_INITIALIZED_MSG
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Load an uploaded CSV into the engine and return its status text."""
        if not engine:
            return UIController.NOT_INITIALIZED_MSG
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop the running task (no-op without an engine)."""
        if engine:
            engine.trigger_stop()

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Return an update that shows the archive from ``engine.get_zip_path()``.

        The file component is hidden when there is no engine or no path is returned.
        """
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str, oauth_token: Optional[gr.OAuthToken]) -> str:
        """Upload the tuned model to ``repo_name`` on the Hub with the user's OAuth token.

        Args:
            engine: Session engine.
            repo_name: Repository name typed by the user; surrounding whitespace
                is ignored.
            oauth_token: Token of the signed-in user, or ``None`` when signed out
                (presumably injected by Gradio from the login session).

        Returns:
            A status/error message for the UI.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController.NOT_INITIALIZED_MSG
        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Render the ``username/repo`` target path, using ``...`` for a blank name."""
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # A whitespace-only name would otherwise render as "user/".
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine], username: Optional[str] = None):
        """Compute interactivity for the repo-name box, Push button and Zip button.

        Returns:
            ``(repo_update, push_update, zip_update)``:
            - the repo name is editable whenever a user is signed in;
            - Push requires a signed-in user and a tuned model;
            - Zip only requires a tuned model.
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)
        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )
# --- View / Layout Layer ---
def _render_header():
    """Draw the page title, the introduction text and the Hugging Face sign-in row."""
    title_md = "# 🤖 FunctionGemma Tuning Lab: Fine-Tuning"
    intro_md = (
        "Fine-tune FunctionGemma to understand your custom functions.<br>"
        "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details."
    )
    login_md = (
        "(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>"
        "⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress)."
    )

    with gr.Column():
        for block_text in (title_md, intro_md, login_md):
            gr.Markdown(block_text)
        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # Empty, wider column beside the button — presumably a spacer.
            with gr.Column(scale=3):
                gr.Markdown("")
def _render_dataset_tab(engine_state):
    """Build tab 1: the tool-schema editor and the optional CSV dataset upload.

    Args:
        engine_state: Session state component; not used while building this tab.

    Returns:
        Dict of the components needed for event wiring, keyed by
        ``tools_editor``, ``update_tools_btn``, ``tools_status``,
        ``import_file`` and ``import_status``.
    """
    limitation_notes = (
        "**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:",
        "* `sum(int a, int b)`\n* `query(string q)`",
        "Ensure that all tools within this specific schema definition share a consistent parameter format.",
    )

    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### 🛠️ Tool Schema & Data Import")
        for note in limitation_notes:
            gr.Markdown(note)
        with gr.Row():
            # Left: JSON schema of the tools to learn.
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                schema_code = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                save_schema_btn = gr.Button("💾 Update Tool Schema")
                schema_status_md = gr.Markdown("")
            # Right: optional CSV that replaces the default dataset.
            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                csv_upload = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                csv_status_md = gr.Markdown("")

    # Return controls needed for wiring
    return dict(
        tools_editor=schema_code,
        update_tools_btn=save_schema_btn,
        tools_status=schema_status_md,
        import_file=csv_upload,
        import_status=csv_status_md,
    )
def _render_training_tab(engine_state):
    """Build tab 2: hyperparameters, run/stop/reload/eval buttons, logs and metrics plot.

    Args:
        engine_state: Session state component; not used while building this tab.

    Returns:
        Dict with:
        - ``params``: inputs for training, in the order epochs, lr, test split,
          shuffle, model;
        - ``eval_params``: inputs for evaluation (test split, shuffle, model);
        - ``buttons``: ``[run, stop, reload, eval]``, unpacked positionally by callers;
        - ``outputs``: ``[log textbox, metrics plot]``;
        - ``model_input``: the base-model dropdown.
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")
        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                preset_models = AppConfig().AVAILABLE_MODELS
                model_dropdown = gr.Dropdown(
                    choices=preset_models,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                lr_number = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                split_slider = gr.Slider(
                    0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)"
                )
                shuffle_checkbox = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")
        with gr.Row():
            eval_button = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            stop_button = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            train_button = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            reload_button = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
        with gr.Row():
            log_box = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            metrics_plot = gr.Plot(label="Training Metrics")

    training_inputs = [epochs_slider, lr_number, split_slider, shuffle_checkbox, model_dropdown]
    evaluation_inputs = training_inputs[2:]  # test split, shuffle, model
    return {
        "params": training_inputs,
        "eval_params": evaluation_inputs,
        "buttons": [train_button, stop_button, reload_button, eval_button],
        "outputs": [log_box, metrics_plot],
        "model_input": model_dropdown,  # specifically needed for initialization
    }
def _render_export_tab(engine_state, username_state):
    """Build tab 3: download the model as a ZIP, or push it to the Hugging Face Hub.

    The Zip button, repo-name box and Push button are created non-interactive.

    Args:
        engine_state: Session state component; not used while building this tab.
        username_state: Username state component; not used while building this tab.

    Returns:
        Dict with ``zip_controls`` = ``[zip button, download file]`` and
        ``hub_controls`` = ``[repo name box, push button, repo preview, upload status]``.
    """
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")
        with gr.Row():
            # Option A: local ZIP download.
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                prepare_zip_button = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                archive_file = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")
            # Option B: push to the user's Hub account.
            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                repo_textbox = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                hub_button = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                preview_md = gr.Markdown("Target Repository: (Waiting for input...)")
                status_md = gr.Markdown("")

    zip_controls = [prepare_zip_button, archive_file]
    hub_controls = [repo_textbox, hub_button, preview_md, status_md]
    return {"zip_controls": zip_controls, "hub_controls": hub_controls}
# --- Main Build Function ---
def build_interface() -> gr.Blocks:
    """Assemble the FunctionGemma Tuning Lab UI and wire all of its events.

    Builds the header and the three tabs, then connects each control to the
    matching ``UIController`` handler. Long-running actions lock the relevant
    buttons first and restore them afterwards. The Repo/Push/Zip controls are
    always re-derived through ``UIController.update_hub_interactive``, because
    their availability depends on login state and on whether a tuned model
    exists.

    Returns:
        The ``gr.Blocks`` application (not launched).
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        engine_state = gr.State()
        username_state = gr.State()
        _render_header()
        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # Helpers for UI State
        # 'action_buttons' ONLY contains buttons that should always be enabled after a process.
        # Zip and Push are excluded because their state depends on has_model_tuned.
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]
        repo_input, push_btn, repo_preview, upload_status = export_ui["hub_controls"]
        zip_btn, download_file = export_ui["zip_controls"]
        log_output = train_ui["outputs"][0]
        lockable_buttons = action_buttons + [push_btn, zip_btn]
        hub_outputs = [repo_input, push_btn, zip_btn]

        def lock_ui():
            """Locks all buttons (including Zip/Push) during processing."""
            return [gr.update(interactive=False) for _ in lockable_buttons]

        def unlock_ui():
            """Unlocks general action buttons only. Zip/Push are handled by update_hub_interactive."""
            return [gr.update(interactive=True) for _ in action_buttons]

        def refresh_hub(event):
            """Chain a re-check of Repo/Push/Zip interactivity after ``event``."""
            return event.then(
                fn=UIController.update_hub_interactive,
                inputs=[engine_state, username_state],
                outputs=hub_outputs,
            )

        # --- Event Wiring ---
        # 1. Initialization
        demo.load(lock_ui, outputs=lockable_buttons).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                log_output,
                repo_input,
                push_btn,
                zip_btn,  # Update Zip state based on initial engine state
                username_state,
            ],
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview],
        ).then(unlock_ui, outputs=action_buttons)

        # 2. Data Tab
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]],
        )
        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]],
        )

        # 3. Training & Eval Tab
        # 3a. Training
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # Hide Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(interactive=False),  # Lock Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push: no upload while the model is still training
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )
        train_run_event = refresh_hub(
            train_run_event.then(
                fn=UIController.run_training,
                inputs=[engine_state, *train_ui["params"]],
                outputs=train_ui["outputs"],
            ).then(
                fn=lambda: (
                    gr.update(visible=True),
                    gr.update(interactive=True),
                    gr.update(interactive=True),
                    gr.update(visible=False),
                ),
                outputs=[run_btn, reload_btn, eval_btn, stop_btn],
            )
        )  # Final check determines if Zip/Push should unlock

        # 3b. Evaluation
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # Lock Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(visible=False),      # Hide Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )
        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[log_output],  # Output only to log, not plot
        ).then(
            fn=lambda: (
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=True),
                gr.update(visible=False),
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        )
        # Zip/Push were locked above; restore them from the current engine state.
        refresh_hub(eval_run_event)

        # NOTE(review): train_run_event / eval_run_event refer to later steps of their
        # chains, not the run_training / run_evaluation steps. So `cancels` may not
        # interrupt the running generator, and stopping relies on engine.trigger_stop().
        # Confirm this is intended.
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False,
        )

        refresh_hub(
            reload_btn.click(lock_ui, outputs=lockable_buttons).then(
                fn=UIController.handle_reset,
                inputs=[engine_state, train_ui["model_input"]],
                outputs=[log_output],
            ).then(unlock_ui, outputs=action_buttons)
        )

        # 4. Export Tab
        refresh_hub(
            zip_btn.click(lock_ui, outputs=lockable_buttons).then(
                fn=UIController.zip_model,
                inputs=[engine_state],
                outputs=[download_file],
            ).then(unlock_ui, outputs=action_buttons)
        )
        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview],
        )
        refresh_hub(
            push_btn.click(lock_ui, outputs=lockable_buttons).then(
                fn=UIController.upload_model,
                inputs=[engine_state, repo_input],
                outputs=[upload_status],
            ).then(unlock_ui, outputs=action_buttons)
        )
    return demo