File size: 3,991 Bytes
3b2a4e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Image generation module using Hugging Face Inference API."""

import os
from typing import Optional
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

load_dotenv()


class ImageGenerator:
    """Generate child-friendly images using the Hugging Face Inference API.

    The access token is read from the ``HF_TOKEN`` environment variable when
    an instance is created.
    """

    # Style modifiers appended (in this order) to every prompt.
    _STYLE_MODIFIERS = (
        "child-friendly",
        "colorful",
        "educational illustration",
        "cartoon style",
        "bright and cheerful",
    )

    # English and Spanish phrases that mark a message as an image request.
    # They are tried in order and matched as plain substrings of the
    # lower-cased message.
    _IMAGE_KEYWORDS = (
        "show me", "muéstrame", "muestra",
        "draw", "dibuja", "dibujar",
        "picture of", "imagen de", "foto de",
        "what does", "cómo es", "como es",
        "i want to see", "quiero ver",
        "can you show", "puedes mostrar",
    )

    # Articles removed from the extracted subject. They are compared against
    # whole words only, so letters inside other words are never touched.
    _ARTICLES = frozenset({"a", "an", "the", "un", "una", "el", "la"})

    def __init__(self, model: str = "black-forest-labs/FLUX.1-schnell"):
        """
        Initialize the image generator.

        Args:
            model: Hugging Face model ID for image generation

        Raises:
            ValueError: If ``HF_TOKEN`` is missing or empty in the environment
        """
        self.api_token = os.getenv("HF_TOKEN")
        self.model = model

        if not self.api_token:
            raise ValueError("HF_TOKEN not found in environment variables")

        # Initialize Hugging Face Inference Client
        self.client = InferenceClient(token=self.api_token)

    def generate_image(self, prompt: str) -> Optional[bytes]:
        """
        Generate an image from a text prompt.

        The prompt is first extended with child-friendly style modifiers and
        then sent to the configured model. Any failure is reported on stdout
        and turned into a ``None`` result instead of an exception.

        Args:
            prompt: Text description of the image to generate

        Returns:
            PNG-encoded image bytes if successful, None otherwise (including
            when the prompt is empty or only whitespace)
        """
        from io import BytesIO

        # A blank prompt would send nothing but style modifiers to the API.
        if not prompt or not prompt.strip():
            return None

        # Enhance prompt for child-friendly, educational content
        enhanced_prompt = self._enhance_prompt_for_children(prompt)

        try:
            # Use the text_to_image method from InferenceClient
            image = self.client.text_to_image(
                enhanced_prompt,
                model=self.model,
            )

            # Serialize the returned image object to PNG bytes in memory
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

        except Exception as e:
            # Best-effort boundary: callers receive None rather than an error.
            print(f"Error generating image: {e}")
            return None

    def _enhance_prompt_for_children(self, prompt: str) -> str:
        """
        Enhance the prompt to steer the model toward child-friendly images.

        Args:
            prompt: Original prompt

        Returns:
            The original prompt followed by the comma-separated style modifiers
        """
        return f"{prompt}, {', '.join(self._STYLE_MODIFIERS)}"

    def detect_image_request(self, message: str) -> Optional[str]:
        """
        Detect if a message contains an image request and extract the subject.

        The message is lower-cased and scanned for the known request phrases.
        For the first phrase that leaves a non-empty remainder, every
        occurrence of that phrase is removed, articles are dropped as whole
        words, and the remaining words are joined with single spaces.

        Args:
            message: User's message

        Returns:
            Lower-cased subject to generate an image for, or None if no
            request is detected or nothing is left after extraction
        """
        message_lower = message.lower()

        for keyword in self._IMAGE_KEYWORDS:
            if keyword not in message_lower:
                continue

            # Extract subject (simplified - could be improved with NLP):
            # remove the keyword and keep the remaining text.
            remainder = message_lower.replace(keyword, "")

            # Filter articles word by word; substring replacement would also
            # eat the ends of ordinary words ("panda bear" -> "pandbear").
            words = [w for w in remainder.split() if w not in self._ARTICLES]

            if words:
                return " ".join(words)

        return None