File size: 4,380 Bytes
55ed626
4e8e262
06e799e
 
 
 
 
4e8e262
06e799e
4e8e262
 
06e799e
 
4e8e262
06e799e
 
 
4e8e262
 
06e799e
4e8e262
 
 
 
 
06e799e
4e8e262
 
 
 
55ed626
 
 
 
 
06e799e
4e8e262
06e799e
 
 
 
4e8e262
06e799e
4e8e262
06e799e
4e8e262
06e799e
 
 
4e8e262
06e799e
4e8e262
 
06e799e
 
 
 
 
4e8e262
06e799e
 
4e8e262
 
 
 
 
06e799e
4e8e262
06e799e
 
4e8e262
 
06e799e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e8e262
06e799e
 
4e8e262
06e799e
4e8e262
 
06e799e
 
 
4e8e262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e799e
 
4e8e262
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from flask import Flask, jsonify, request, send_file, render_template
from flask_cors import CORS
import numpy as np
from keras.models import load_model
from PIL import Image
import io

# Location of the pretrained generator weights (resolved against the
# current working directory).
MODEL_PATH = "./models/face-gen-gan/generator_model_100.h5"

app = Flask(__name__)
CORS(app)  # allow cross-origin requests on every route

# Filled in by load_gan_model(); both stay None until the model is loaded.
model = None
latent_dim = None

def load_gan_model():
    """Load the generator into the module-level ``model`` global.

    Also records the latent vector size (``model.input_shape[1]``) in
    ``latent_dim``. Does nothing if ``model`` is already set.
    """
    global model, latent_dim
    if model is not None:
        return
    print(f"Loading face generation GAN model from {MODEL_PATH}...")
    model = load_model(MODEL_PATH)
    latent_dim = model.input_shape[1]
    print(f"Model loaded successfully! Latent dimension: {latent_dim}")

# Load the model eagerly at import time so the first request does not pay the
# loading cost; any exception raised while loading propagates from here.
load_gan_model()

@app.route("/")
def index():
    """Render the ``index.html`` template that provides the browser UI."""
    page = render_template('index.html')
    return page

@app.route("/api")
def root():
    """Describe the API: its name, run status, model id and latent size."""
    info = {
        "message": "Face Generator API",
        "status": "running",
        "model": "face-gen-gan",
        "latent_dim": latent_dim,
    }
    return jsonify(info)

@app.route("/health")
def health():
    """Health check reporting whether the generator is loaded and its latent size."""
    loaded = model is not None
    return jsonify(
        status="healthy",
        model_loaded=loaded,
        latent_dim=latent_dim,
    )

@app.route("/generate", methods=["POST"])
def generate_faces():
    """
    Generate face images using the GAN model.

    Optional JSON body fields:
        n_samples: number of faces to generate; clamped to 1-16 (default 1).
        seed: integer seed for reproducible output (default: unseeded).

    Returns a PNG image: a single face when n_samples is 1, otherwise a
    square grid of faces whose unused cells are left white.

    Error responses (JSON ``{"error": ...}``):
        400 if n_samples or seed is not a valid integer (seed must also lie
            in the range NumPy's RandomState accepts, 0 to 2**32 - 1).
        500 if the model is not loaded or generation fails.
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it falls back to the defaults; a JSON value that is not an object
    # (e.g. a list) is treated the same way.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    # Parse parameters up front so malformed input is reported as a client
    # error (400) rather than an internal one (500).
    try:
        n_samples = max(1, min(int(data.get("n_samples", 1)), 16))  # Limit to 1-16
        seed = data.get("seed")
        # A per-request RandomState keeps seeded output reproducible without
        # reseeding NumPy's process-wide RNG, which all requests share.
        # RandomState(s).randn(...) draws the same values as
        # np.random.seed(s); np.random.randn(...), so existing seeds still
        # produce the same faces. With None it is seeded from OS entropy.
        rng = np.random.RandomState(None if seed is None else int(seed))
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({"error": f"Invalid request parameter: {e}"}), 400

    try:
        latent_points = rng.randn(n_samples, latent_dim)
        generated_images = model.predict(latent_points, verbose=0)

        # Scale from [-1, 1] to [0, 255]; clip so values marginally outside
        # that range cannot wrap around when cast to uint8.
        generated_images = np.clip(
            (generated_images + 1) / 2.0 * 255, 0, 255
        ).astype(np.uint8)

        if n_samples == 1:
            canvas = generated_images[0]
        else:
            # Square grid just large enough to hold n_samples faces.
            grid_size = int(np.ceil(np.sqrt(n_samples)))
            img_height, img_width = generated_images.shape[1:3]
            # White canvas; trailing (channel) dims are copied from the
            # images instead of assuming three channels.
            canvas = np.full(
                (grid_size * img_height, grid_size * img_width)
                + generated_images.shape[3:],
                255,
                dtype=np.uint8,
            )
            for i, face in enumerate(generated_images):
                row, col = divmod(i, grid_size)
                canvas[row * img_height:(row + 1) * img_height,
                       col * img_width:(col + 1) * img_width] = face

        img = Image.fromarray(canvas)

        # Encode as PNG in memory and stream it back.
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')

    except Exception as e:
        # Request boundary: keep the traceback in the server log.
        app.logger.exception("Face generation failed")
        return jsonify({"error": str(e)}), 500

@app.route("/generate-single")
def generate_single_face():
    """
    Quick endpoint to generate a single face.

    Query parameters:
        seed: optional integer seed for reproducible output.

    Returns a PNG image. Responds with 400 (JSON error) if seed is not an
    integer in 0 to 2**32 - 1, and 500 if the model is not loaded or
    generation fails.
    """
    seed = request.args.get('seed', None)

    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    try:
        # Per-request RNG: reproducible for a given seed without reseeding
        # NumPy's process-wide generator shared by all requests. It draws
        # the same values as np.random.seed(seed); np.random.randn(...).
        # With None it is seeded from OS entropy.
        rng = np.random.RandomState(None if seed is None else int(seed))
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({"error": f"Invalid request parameter: {e}"}), 400

    try:
        latent_points = rng.randn(1, latent_dim)
        generated_images = model.predict(latent_points, verbose=0)

        # Scale from [-1, 1] to [0, 255]; clip so values marginally outside
        # that range cannot wrap around when cast to uint8.
        generated_images = np.clip(
            (generated_images + 1) / 2.0 * 255, 0, 255
        ).astype(np.uint8)

        img = Image.fromarray(generated_images[0])

        # Encode as PNG in memory and stream it back.
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')

    except Exception as e:
        # Request boundary: keep the traceback in the server log.
        app.logger.exception("Single face generation failed")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Flask development server: listen on every interface, port 8002, debug off.
    app.run(debug=False, host="0.0.0.0", port=8002)