import os
import torch
import gradio as gr
import tempfile
import traceback  # for detailed error logging on startup and tool failures

from diffusers import StableDiffusionPipeline

# LangChain imports
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.llms import HuggingFaceHub
from langchain.agents import AgentExecutor, create_react_agent
from langchain.schema import HumanMessage, AIMessage
# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
HF_TOKEN = os.environ.get("HF_TOKEN")  # the same token is reused for the HuggingFaceHub LLM below

# Model ID for image generation (the smaller model loads reliably on a T4)
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd"

# Fall back to CPU (with float32) when no GPU is available, so the Space
# still starts instead of crashing on `pipe.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) on {device}...")
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # float16 saves VRAM on a T4
        use_safetensors=False,  # tiny-sd does not ship safetensors weights
        token=HF_TOKEN,  # authenticated downloads can be faster
    )
    pipe.to(device)
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on {device}.")
except Exception:
    print("❌ Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None  # signal failure to the tool below
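
# Optional startup sanity check (a sketch, left commented out so it does not
# slow startup): one direct pipeline call confirms the weights and device
# placement before the agent ever runs. The prompt string here is arbitrary.
# if pipe is not None:
#     pipe("a red apple on a wooden table", num_inference_steps=20).images[0].save("startup_check.png")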
# --- 2. Define Custom Image Generation Tool for LangChain ---
# The @tool decorator turns the function into a LangChain tool; its docstring
# becomes the tool description the agent sees.
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description of the image to generate.
    """
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check the Space logs from startup."
    print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
    try:
        with torch.no_grad():
            pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]
        # LangChain tools return strings, so save the PIL image to a temporary
        # file and return its path; the Gradio handler decides how to display it.
        temp_dir = tempfile.mkdtemp()
        image_path = os.path.join(temp_dir, "generated_image.png")
        pil_image.save(image_path)
        print(f"Image saved to temporary path: {image_path}")
        # Special prefix tells the Gradio handler that this string is an image path
        return f"__IMAGE_PATH__:{image_path}"
    except Exception as e:
        print("Error in image_generator tool execution:")
        traceback.print_exc()
        return f"Error generating image: {str(e)}"
# --- 3. Define other Tools for LangChain ---
search = DuckDuckGoSearchRun()
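
# Note: DuckDuckGoSearchRun registers under the tool name "duckduckgo_search",
# and a ReAct agent must use that exact name in its Action lines.
# print(search.name)  # expected: "duckduckgo_search"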
# --- 4. Define the LangChain Agent ---
# Ensure the image pipeline loaded successfully before proceeding
if pipe is None:
    raise RuntimeError("Cannot start agent: the image generation pipeline failed to load. Check logs.")

# Instantiate the LLM for the agent
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN,  # HuggingFaceHub expects the token explicitly
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
)
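
# Note: the HuggingFaceHub wrapper is deprecated in recent LangChain releases.
# A drop-in alternative (a sketch, assuming the langchain-huggingface package
# is installed) would be:
# from langchain_huggingface import HuggingFaceEndpoint
# llm = HuggingFaceEndpoint(
#     repo_id="HuggingFaceH4/zephyr-7b-beta",
#     huggingfacehub_api_token=HF_TOKEN,
#     temperature=0.5,
#     max_new_tokens=512,
# )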
# Create the tools list
tools = [image_generator, search]
# Define the agent prompt.
# create_react_agent requires the {tools}, {tool_names}, and {agent_scratchpad}
# variables, and it renders the scratchpad as a *string*, so agent_scratchpad
# belongs inside a message template rather than a MessagesPlaceholder.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", """You are a powerful AI assistant that can generate images and search the web.

You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original question

Guidelines:
1. Think step by step: decide whether the request needs a web search, an image, or both.
2. When you need factual information or context, use the web search tool first.
3. When you need to generate an image, give the image_generator tool a very detailed, descriptive text string; if you lack detail, ask for more or search first.
4. If you generated an image, return the tool's output string verbatim as your Final Answer so the UI can display it."""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}\n\n{agent_scratchpad}"),
    ]
)
# Create the agent
agent = create_react_agent(llm, tools, prompt_template)

# Create the agent executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
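
# One-off smoke test (a sketch, commented out): invoking the executor outside
# Gradio confirms the prompt variables are wired correctly. agent_scratchpad
# is managed by the executor and must NOT be passed in.
# result = agent_executor.invoke({"input": "What is the tallest mountain on Earth?", "chat_history": []})
# print(result["output"])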
# --- 5. Gradio UI Integration ---
# Run the agent and route its output to the chat UI
def run_agent_in_gradio(message, history):
    # Convert Gradio history pairs into LangChain chat_history messages
    chat_history = []
    for human_msg, ai_msg in history:
        chat_history.append(HumanMessage(content=str(human_msg)))
        chat_history.append(AIMessage(content=str(ai_msg)))
    try:
        # Do not pass agent_scratchpad here; AgentExecutor builds it itself.
        response = agent_executor.invoke(
            {"input": message, "chat_history": chat_history}
        )
        agent_output = response["output"]
        # Detect the image-path marker anywhere in the output, since the LLM
        # may wrap the tool result in extra text
        if "__IMAGE_PATH__:" in agent_output:
            image_path = agent_output.split("__IMAGE_PATH__:", 1)[1].strip()
            # Return a Gradio Image component for the chatbot to render
            return gr.Image(value=image_path, label="Generated Image")
        else:
            # Return plain text
            return agent_output
    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc()
        return f"❌ Agent encountered an error: {str(e)}"
# Gradio ChatInterface setup
demo = gr.ChatInterface(
    fn=run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent can generate images from prompts, optionally searching the web for context first.",
)

if __name__ == "__main__":
    demo.launch()