from flask import Flask, jsonify, request, send_file, render_template
from flask_cors import CORS
import numpy as np
from keras.models import load_model
from PIL import Image
import io
# Flask application object; the route decorators in this file register their
# handlers on it.
app = Flask(__name__)
# Enable CORS for all routes
CORS(app)
# Global variables for model
# Relative path of the generator file that load_gan_model() passes to load_model().
MODEL_PATH = "./models/face-gen-gan/generator_model_100.h5"
# Loaded generator; stays None until load_gan_model() has run successfully.
model = None
# Latent vector length, read from model.input_shape[1] when the model is loaded.
latent_dim = None
def load_gan_model():
    """Populate the module-level ``model`` and ``latent_dim`` on first use.

    Does nothing when ``model`` is already set. Otherwise loads the file at
    ``MODEL_PATH`` and records the second entry of the loaded model's
    ``input_shape`` as the latent dimension.
    """
    global model, latent_dim

    # Already loaded: nothing to do.
    if model is not None:
        return

    print(f"Loading face generation GAN model from {MODEL_PATH}...")
    generator = load_model(MODEL_PATH)
    model = generator
    latent_dim = generator.input_shape[1]
    print(f"Model loaded successfully! Latent dimension: {latent_dim}")
# Load model on startup: this call runs when the module is imported, so the
# `model` and `latent_dim` globals are populated before any route handler runs.
# load_gan_model() has no error handling, so a load failure propagates from here.
load_gan_model()
@app.route("/")
def index():
    """Render the ``index.html`` template as the web interface."""
    page = render_template('index.html')
    return page
@app.route("/api")
def root():
    """Return static API metadata plus the current ``latent_dim`` as JSON."""
    info = {
        "message": "Face Generator API",
        "status": "running",
        "model": "face-gen-gan",
    }
    # Read the module-level value at request time.
    info["latent_dim"] = latent_dim
    return jsonify(info)
@app.route("/health")
def health():
    """Report a fixed "healthy" status, whether ``model`` is set, and ``latent_dim``."""
    model_loaded = model is not None
    return jsonify(
        status="healthy",
        model_loaded=model_loaded,
        latent_dim=latent_dim,
    )
def _build_image_grid(images):
    """Tile a batch of equally sized RGB images onto a white square canvas.

    ``images`` is indexed as (sample, height, width, channel). The canvas has
    ``ceil(sqrt(n))`` cells per side and is filled row by row; cells without
    an image stay white.
    """
    count, cell_h, cell_w = images.shape[:3]
    side = int(np.ceil(np.sqrt(count)))
    canvas = np.full((side * cell_h, side * cell_w, 3), 255, dtype=np.uint8)
    for idx, tile in enumerate(images):
        row, col = divmod(idx, side)
        canvas[row * cell_h:(row + 1) * cell_h,
               col * cell_w:(col + 1) * cell_w] = tile
    return canvas


@app.route("/generate", methods=["POST"])
def generate_faces():
    """
    Generate face images using the GAN model.

    Optional JSON body fields:
        n_samples: number of faces; converted with ``int`` and clamped to
            1-16 (default 1).
        seed: integer seed for reproducible latent points (default: unseeded).

    Returns a PNG image (single face, or a grid when ``n_samples`` > 1).
    Responds with a JSON ``{"error": ...}`` body and status 400 for invalid
    parameters, or 500 when the model is not loaded or generation fails.
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing, non-JSON or malformed body falls back to the
    # defaults instead of raising inside the handler.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Parameter problems are the client's fault: report them as 400, not 500.
    try:
        n_samples = max(1, min(int(data.get("n_samples", 1)), 16))  # Limit to 1-16
        seed = data.get("seed", None)
        # A request-local generator avoids reseeding the process-wide NumPy
        # RNG, which would leak one request's seed into other requests.
        rng = np.random.RandomState(None if seed is None else int(seed))
    except (TypeError, ValueError, OverflowError) as exc:
        return jsonify({"error": f"Invalid request parameter: {exc}"}), 400

    try:
        # Generate random latent points and run the generator on them
        latent_points = rng.randn(n_samples, latent_dim)
        generated_images = model.predict(latent_points, verbose=0)

        # Scale from [-1, 1] to [0, 255] (assumes the generator emits values in
        # [-1, 1]); clip first so stray out-of-range values cannot wrap around
        # when cast to uint8.
        pixels = np.clip((generated_images + 1) / 2.0 * 255, 0, 255).astype(np.uint8)

        if n_samples == 1:
            img = Image.fromarray(pixels[0])
        else:
            img = Image.fromarray(_build_image_grid(pixels))

        # Encode as PNG in memory and rewind so send_file reads from the start
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')
    except Exception as e:
        # Endpoint boundary: record the traceback server-side, keep the JSON error shape.
        app.logger.exception("Face generation failed")
        return jsonify({"error": str(e)}), 500
@app.route("/generate-single")
def generate_single_face():
    """
    Quick endpoint to generate a single face.

    Accepts an optional ``seed`` query parameter (integer) for reproducible
    output. Returns a PNG image. Responds with a JSON ``{"error": ...}`` body
    and status 400 for an invalid seed, or 500 when the model is not loaded
    or generation fails.
    """
    seed = request.args.get('seed', None)
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # A bad seed is a client error: report it as 400 rather than 500.
    try:
        # A request-local generator avoids reseeding the process-wide NumPy
        # RNG, which would leak this request's seed into other requests.
        rng = np.random.RandomState(None if seed is None else int(seed))
    except (TypeError, ValueError) as exc:
        return jsonify({"error": f"Invalid seed: {exc}"}), 400

    try:
        # Generate one random latent point and run the generator on it
        latent_points = rng.randn(1, latent_dim)
        generated_images = model.predict(latent_points, verbose=0)

        # Scale from [-1, 1] to [0, 255] (assumes the generator emits values in
        # [-1, 1]); clip first so stray out-of-range values cannot wrap around
        # when cast to uint8.
        pixels = np.clip((generated_images + 1) / 2.0 * 255, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels[0])

        # Encode as PNG in memory and rewind so send_file reads from the start
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')
    except Exception as e:
        # Endpoint boundary: record the traceback server-side, keep the JSON error shape.
        app.logger.exception("Single face generation failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Script entry point: serve on all interfaces (0.0.0.0), port 8002, with
    # Flask's debug mode disabled.
    app.run(host="0.0.0.0", port=8002, debug=False)