import sys import os import gradio as gr import json import glob from tinytroupe.factory import TinyPersonFactory from tinytroupe.utils.semantics import select_best_persona, select_relevant_personas_utility from tinytroupe.simulation_manager import SimulationManager, SimulationConfig from tinytroupe.agent.social_types import Content from huggingface_hub import hf_hub_download, upload_file HF_TOKEN = os.getenv("HF_TOKEN") # Ensure this is set in Space secrets REPO_ID = "AUXteam/tiny_factory" PERSONA_BASE_FILE = "persona_base.json" simulation_manager = SimulationManager() def load_persona_base(): if not HF_TOKEN: print("HF_TOKEN not found, persistence disabled.") return [] try: path = hf_hub_download(repo_id=REPO_ID, filename=PERSONA_BASE_FILE, repo_type="space", token=HF_TOKEN) with open(path, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: print(f"Error loading persona base: {e}") return [] def save_persona_base(personas): if not HF_TOKEN: print("HF_TOKEN not found, skipping upload.") return with open(PERSONA_BASE_FILE, 'w', encoding='utf-8') as f: json.dump(personas, f, indent=4) try: upload_file( path_or_fileobj=PERSONA_BASE_FILE, path_in_repo=PERSONA_BASE_FILE, repo_id=REPO_ID, repo_type="space", token=HF_TOKEN ) except Exception as e: print(f"Error saving persona base to Hub: {e}") # --- CHANGE 1: The function now accepts an optional API key. --- def generate_personas(business_description, customer_profile, num_personas, blablador_api_key=None): """ Generates a list of TinyPerson instances based on the provided inputs. It prioritizes the API key passed as an argument, but falls back to the environment variable if none is provided (for UI use). """ # --- CHANGE 2: Logic to determine which key to use. --- # Use the key from the API call if provided, otherwise get it from the Space secrets. api_key_to_use = blablador_api_key or os.getenv("BLABLADOR_API_KEY") if not api_key_to_use: return {"error": "BLABLADOR_API_KEY not found. Please provide it in your API call or set it as a secret in the Space settings."} # Store the original state of the environment variable, if it exists original_key = os.getenv("BLABLADOR_API_KEY") try: # --- CHANGE 3: Securely set the correct environment variable for this request. --- # The underlying tinytroupe library will look for this variable. os.environ["BLABLADOR_API_KEY"] = api_key_to_use num_personas = int(num_personas) factory = TinyPersonFactory( context=business_description, sampling_space_description=customer_profile, total_population_size=num_personas ) people = factory.generate_people(number_of_people=num_personas, parallelize=False) personas_data = [person._persona for person in people] # --- NEW: Update the Tresor --- current_base = load_persona_base() current_base.extend(personas_data) save_persona_base(current_base) # ------------------------------ return personas_data except Exception as e: return {"error": str(e)} finally: # --- CHANGE 4: A robust cleanup using a 'finally' block. --- # This ensures the environment is always restored to its original state, # whether the function succeeds or fails. if original_key is None: # If the variable didn't exist originally, remove it. if "BLABLADOR_API_KEY" in os.environ: del os.environ["BLABLADOR_API_KEY"] else: # If it existed, restore its original value. os.environ["BLABLADOR_API_KEY"] = original_key def find_best_persona(criteria): """ Loads the persona base and finds the best matching persona based on criteria. """ personas = load_persona_base() if not personas: return {"error": "Persona base is empty. Generate some personas first!"} try: # select_best_persona uses LLM to find the best index idx = select_best_persona(criteria=criteria, personas=personas) try: idx = int(idx) except (ValueError, TypeError): return {"error": f"LLM returned an invalid index: {idx}"} if idx >= 0 and idx < len(personas): return personas[idx] else: return {"error": f"No matching persona found for criteria: {criteria}"} except Exception as e: return {"error": f"Error during persona matching: {str(e)}"} def load_example_personas(): """ Loads example personas from the tinytroupe library. """ example_personas = [] # Path to the agents folder in tinytroupe/examples agents_path = os.path.join("tinytroupe", "examples", "agents", "*.agent.json") for file_path in glob.glob(agents_path): try: with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) if "persona" in data: example_personas.append(data["persona"]) except Exception as e: print(f"Error loading example persona from {file_path}: {e}") return example_personas def identify_personas(context): """ Identifies appropriate personas from the Tresor and example agents based on context. """ try: # 1. Load Tresor personas (persisted JSON) tresor_personas = load_persona_base() # 2. Load Example personas from tinytroupe library example_personas = load_example_personas() all_available = tresor_personas + example_personas if not all_available: return {"error": "No personas available in Tresor or examples."} # 3. Use LLM to filter/select which ones match the 'context' # Returns a list of indices indices = select_relevant_personas_utility(context, all_available) selected = [] if isinstance(indices, list): for i in indices: try: idx = int(i) if 0 <= idx < len(all_available): selected.append(all_available[idx]) except (ValueError, TypeError): continue return selected except Exception as e: return {"error": str(e)} def generate_social_network_api(name, persona_count, network_type, focus_group_name=None): """ Gradio API endpoint for generating a social network. """ try: config = SimulationConfig(name=name, persona_count=int(persona_count), network_type=network_type) simulation = simulation_manager.create_simulation(config, focus_group_name=focus_group_name) return { "simulation_id": simulation.id, "name": simulation.config.name, "persona_count": len(simulation.personas), "network_metrics": simulation.network.get_metrics() } except Exception as e: return {"error": str(e)} def predict_engagement_api(simulation_id, content_text, format="text"): """ Gradio API endpoint for predicting engagement. """ try: content = Content(text=content_text, format=format) result = simulation_manager.run_simulation(simulation_id, content) return { "total_reach": result.total_reach, "expected_likes": result.expected_likes, "expected_comments": result.expected_comments, "expected_shares": result.expected_shares, "execution_time": result.execution_time, "avg_sentiment": result.avg_sentiment, "feedback_summary": result.feedback_summary } except Exception as e: return {"error": str(e)} def start_simulation_async_api(simulation_id, content_text, format="text"): """ Starts a simulation in the background. """ try: content = Content(text=content_text, format=format) simulation_manager.run_simulation(simulation_id, content, background=True) return {"status": "started", "simulation_id": simulation_id} except Exception as e: return {"error": str(e)} def get_simulation_status_api(simulation_id): """ Checks the status and progress of a simulation. """ try: sim = simulation_manager.get_simulation(simulation_id) if not sim: return {"error": "Simulation not found"} status_data = { "status": sim.status, "progress": sim.progress } if sim.status == "completed" and sim.last_result: status_data["result"] = { "total_reach": sim.last_result.total_reach, "expected_likes": sim.last_result.expected_likes, "avg_sentiment": sim.last_result.avg_sentiment } return status_data except Exception as e: return {"error": str(e)} def send_chat_message_api(simulation_id, sender, message): """ Sends a message to the simulation chat. """ try: return simulation_manager.send_chat_message(simulation_id, sender, message) except Exception as e: return {"error": str(e)} def get_chat_history_api(simulation_id): """ Gets the chat history for a simulation. """ try: return simulation_manager.get_chat_history(simulation_id) except Exception as e: return {"error": str(e)} def generate_variants_api(content_text, num_variants): """ Gradio API endpoint for generating content variants. """ try: variants = simulation_manager.variant_generator.generate_variants(content_text, num_variants=int(num_variants)) return [{"text": v.text, "strategy": v.strategy} for v in variants] except Exception as e: return {"error": str(e)} def list_simulations_api(): """ Gradio API endpoint for listing simulations. """ try: return simulation_manager.list_simulations() except Exception as e: return {"error": str(e)} def list_personas_api(simulation_id): """ Gradio API endpoint for listing personas in a simulation. """ try: return simulation_manager.list_personas(simulation_id) except Exception as e: return {"error": str(e)} def get_persona_api(simulation_id, persona_name): """ Gradio API endpoint for getting persona details. """ try: return simulation_manager.get_persona(simulation_id, persona_name) except Exception as e: return {"error": str(e)} def delete_simulation_api(simulation_id): """ Gradio API endpoint for deleting a simulation. """ try: success = simulation_manager.delete_simulation(simulation_id) return {"success": success} except Exception as e: return {"error": str(e)} def export_simulation_api(simulation_id): """ Gradio API endpoint for exporting a simulation. """ try: return simulation_manager.export_simulation(simulation_id) except Exception as e: return {"error": str(e)} def get_network_graph_api(simulation_id): """ Gradio API endpoint for getting network graph data. """ try: sim = simulation_manager.get_simulation(simulation_id) if not sim: return {"error": "Simulation not found"} nodes = [] for p in sim.personas: nodes.append({ "id": p.name, "label": p.name, "role": p._persona.get("occupation"), "location": p._persona.get("residence") }) edges = [] for edge in sim.network.edges: edges.append({ "source": edge.connection_id.split('_')[0], "target": edge.connection_id.split('_')[1], "strength": edge.strength }) return {"nodes": nodes, "edges": edges} except Exception as e: return {"error": str(e)} def list_focus_groups_api(): """ Gradio API endpoint for listing focus groups. """ try: return simulation_manager.list_focus_groups() except Exception as e: return {"error": str(e)} def save_focus_group_api(name, simulation_id): """ Gradio API endpoint for saving a focus group from a simulation. """ try: sim = simulation_manager.get_simulation(simulation_id) if not sim: return {"error": "Simulation not found"} simulation_manager.save_focus_group(name, sim.personas) return {"status": "success", "name": name} except Exception as e: return {"error": str(e)} with gr.Blocks() as demo: gr.Markdown("