|
|
import os |
|
|
import tempfile |
|
|
import logging |
|
|
from typing import Tuple, Dict |
|
|
|
|
|
import gradio as gr |
|
|
from fastapi import FastAPI, UploadFile, File, Form, Header, HTTPException, Depends |
|
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
|
from fastapi.testclient import TestClient |
|
|
import io |
|
|
from spaces import GPU |
|
|
from huggingface_hub import snapshot_download |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
# firebase-admin is optional: if importing it fails for any reason, the three
# names below are set to None and Firebase ID-token auth is simply unavailable
# (_require_auth checks these names before attempting Firebase verification).
try:
    import firebase_admin
    from firebase_admin import credentials, auth as fb_auth
except Exception:
    # Broad catch keeps the app importable even when the SDK is broken, not just
    # absent. NOTE(review): narrow to ImportError if that is the only intended case.
    firebase_admin = None
    credentials = None
    fb_auth = None

# Handle of the initialised Firebase app; set by _init_firebase_if_possible(),
# None while Firebase is not (or could not be) initialised.
FIREBASE_APP = None
|
|
|
|
|
|
|
|
def _init_firebase_if_possible() -> None:
    """Initialise the Firebase Admin SDK once, if it is installed and configured.

    Credentials are looked up in this order:

    * ``FIREBASE_CREDENTIALS_JSON`` -- either a path to a service-account file or
      the service-account JSON document itself;
    * ``firebase_service_account.json`` in the current working directory.

    On success the app handle is stored in the module-level ``FIREBASE_APP``.
    A missing SDK, missing credentials and initialisation errors are logged and
    leave ``FIREBASE_APP`` as ``None``; this function never raises.
    """
    global FIREBASE_APP
    if FIREBASE_APP is not None:
        return
    if firebase_admin is None:
        logger.info("firebase-admin not installed; skipping Firebase init")
        return

    # Reuse a default app that already exists, e.g. one created by a concurrent
    # call or by other code in the process. initialize_app() would raise
    # ValueError for a duplicate default app, and Firebase auth would then stay
    # disabled even though the SDK is fully initialised.
    try:
        FIREBASE_APP = firebase_admin.get_app()
        logger.info("Reusing existing Firebase app")
        return
    except ValueError:
        pass  # No default app yet -- initialise one below.

    sa_env = os.getenv("FIREBASE_CREDENTIALS_JSON", "").strip()
    sa_path = "firebase_service_account.json"
    try:
        cred_obj = None
        if sa_env:
            # The variable may hold either a file path or raw JSON content.
            if os.path.exists(sa_env):
                cred_obj = credentials.Certificate(sa_env)
            else:
                cred_obj = credentials.Certificate(json.loads(sa_env))
        elif os.path.exists(sa_path):
            cred_obj = credentials.Certificate(sa_path)

        if cred_obj is not None:
            FIREBASE_APP = firebase_admin.initialize_app(cred_obj)
            logger.info("Firebase initialized successfully")
        else:
            logger.info("No Firebase credentials provided; skipping Firebase init")
    except Exception as e:
        # Best effort: a bad credential must not take the API down; static-token
        # auth keeps working.
        logger.warning("Firebase init failed: %s", e)
        FIREBASE_APP = None
|
|
|
|
|
|
|
|
|
|
|
# Environment defaults, applied only when the variables are not already set.
# An empty CUDA_VISIBLE_DEVICES hides all GPUs from CUDA. These are assigned
# before importing SimpleRunner, presumably so torch / CUDA extension builds
# triggered by that import see them -- TODO confirm.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0")

from runners.simple_runner import SimpleRunner

logging.basicConfig(level=logging.INFO)
# Module logger. Functions defined earlier in the file also use it; the name is
# resolved when they are called, after this assignment has run.
logger = logging.getLogger("sfe-app")

# Process-wide editor instance, created lazily by get_runner().
RUNNER: SimpleRunner | None = None
|
|
|
|
|
|
|
|
def _log_pretrained_models_dir() -> None:
    """Log every file under ``pretrained_models/`` with its size (diagnostics only)."""
    if not os.path.exists("pretrained_models"):
        logger.error("pretrained_models directory does not exist!")
        return
    logger.info("Files in pretrained_models directory:")
    try:
        for root, _dirs, files in os.walk("pretrained_models"):
            for name in files:
                full_path = os.path.join(root, name)
                logger.info("  %s (size: %d bytes)", full_path, os.path.getsize(full_path))
    except Exception as e:
        logger.error("Error listing files: %s", e)


def ensure_weights():
    """Make sure the pretrained weights exist locally, downloading them if needed.

    If every required file under ``pretrained_models/`` is already present,
    nothing happens. Otherwise a snapshot of the Hugging Face repo
    ``LogicGoInfotechSpaces/Smile_Changer_pre_model`` is downloaded into the
    current directory, using ``HF_TOKEN`` or ``HUGGINGFACE_TOKEN`` when set.
    After the download, the contents of ``pretrained_models/`` and the status
    of each required file are logged.

    Failures are logged, not raised. A failed download returns early, and a file
    that is still missing afterwards only produces a warning; the caller will
    fail later when it loads that file.
    """
    need = [
        "pretrained_models/sfe_editor_light.pt",
        "pretrained_models/stylegan2-ffhq-config-f.pt",
        "pretrained_models/e4e_ffhq_encode.pt",
        "pretrained_models/stylegan2-ffhq-config-f.pkl",
        "pretrained_models/shape_predictor_68_face_landmarks.dat",
        "pretrained_models/fs3.npy",
        "pretrained_models/delta_mapper.pt",
        "pretrained_models/iresnet50-7f187506.pth",
        "pretrained_models/model_ir_se50.pth",
        "pretrained_models/CurricularFace_Backbone.pth",
        "pretrained_models/face_parsing.farl.lapa.main_ema_136500_jit191.pt",
        "pretrained_models/mobilenet0.25_Final.pth",
        "pretrained_models/moco_v2_800ep_pretrain.pt",
        "pretrained_models/79999_iter.pth",
    ]

    # Skip the download only when *every* required file is present. Checking for
    # *any* file would skip it as soon as a single weight existed, leaving the
    # runner to fail on the missing ones.
    missing = [p for p in need if not os.path.exists(p)]
    if not missing:
        logger.info("All weights already exist, skipping download")
        return

    repo_id = "LogicGoInfotechSpaces/Smile_Changer_pre_model"
    logger.info("Missing weights %s; downloading snapshot from %s", missing, repo_id)
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=".",
            allow_patterns=["**/*"],
            token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN"),
        )
        logger.info("Download completed successfully")
    except Exception as e:
        logger.error("Download failed: %s", e)
        return

    # snapshot_download returns only after the files are written, so they can
    # be inspected immediately; no sleep is needed here.
    _log_pretrained_models_dir()

    for file_path in need:
        if not os.path.exists(file_path):
            logger.warning("File %s still not found after download", file_path)
        else:
            logger.info("File %s found successfully", file_path)
|
|
|
|
|
|
|
|
def get_runner() -> SimpleRunner:
    """Return the process-wide SimpleRunner, building it on first use.

    The first call runs ensure_weights() and then constructs the runner from the
    light editor checkpoint. The instance is cached in the module-level
    ``RUNNER``, and later calls return it unchanged.
    """
    global RUNNER
    if RUNNER is not None:
        return RUNNER

    checkpoint = "pretrained_models/sfe_editor_light.pt"
    logger.info("Getting runner - calling ensure_weights()")
    ensure_weights()
    logger.info("Initializing SimpleRunner with %s", checkpoint)
    RUNNER = SimpleRunner(editor_ckpt_pth=checkpoint)
    logger.info("SimpleRunner initialized successfully")
    return RUNNER
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# UI label -> (editing name passed to SimpleRunner.edit as ``editing_name``,
# (min, max) strength range that run_edit clips the requested power to).
# Entry order is the order shown in the UI dropdown.
ATTRIBUTE_MAP: Dict[str, Tuple[str, Tuple[float, float]]] = dict(
    [
        # Expression / demographics
        ("Smile", ("fs_smiling", (-10.0, 10.0))),
        ("Age", ("age", (-10.0, 10.0))),
        ("Female features", ("gender", (-10.0, 7.0))),
        # Facial hair
        ("Beard", ("trimmed_beard", (-30.0, 30.0))),
        ("Mustache/Goatee", ("goatee", (-7.0, 7.0))),
        # Accessories / cosmetics
        ("Glasses", ("fs_glasses", (-20.0, 30.0))),
        ("Makeup", ("fs_makeup", (-10.0, 15.0))),
        # Hair style (small, positive-only ranges)
        ("Curly hair", ("curly_hair", (0.0, 0.12))),
        ("Afro", ("afro", (0.0, 0.14))),
        # Text-driven StyleCLIP global edits; run_edit lets a custom text
        # edit override these editing names.
        ("Orange hair (text)", ("styleclip_global_a face_a face with orange hair_0.18", (0.0, 0.2))),
        ("Blonde hair (text)", ("styleclip_global_a face_a face with blonde hair_0.18", (0.0, 0.2))),
    ]
)
|
|
|
|
|
|
|
|
def recommended_range(attr_name: str) -> Tuple[float, float]:
    """Return the ``(min, max)`` strength range for a UI attribute label.

    Raises KeyError when ``attr_name`` is not a key of ``ATTRIBUTE_MAP``.
    """
    return ATTRIBUTE_MAP[attr_name][1]
|
|
|
|
|
|
|
|
def run_edit(
    image: Image.Image,
    attribute: str,
    strength: float,
    align_face: bool,
    use_bg_mask: bool,
    custom_text_edit: str,
) -> Image.Image:
    """Run a single attribute edit and return the edited image.

    Args:
        image: Input face image; converted to RGB and saved as JPEG for the runner.
        attribute: Key of ``ATTRIBUTE_MAP`` selecting the edit direction.
        strength: Requested edit power; clipped into the attribute's range.
        align_face: Forwarded as ``align`` to ``runner.edit``.
        use_bg_mask: Forwarded as ``use_mask`` to ``runner.edit``.
        custom_text_edit: For attributes whose label ends with ``"(text)"``, a
            non-blank value (stripped) replaces the editing name; ignored otherwise.

    Returns:
        The edited image as an RGB ``PIL.Image``.

    Raises:
        gr.Error: If no image was given or the attribute is unknown. Both checks
            happen before the runner is loaded.
    """
    # Validate the cheap inputs first. get_runner() may download weights and load
    # the model, which is wasted work if the request is doomed anyway (e.g. the
    # UI "Run edit" button clicked with no image uploaded).
    if image is None:
        raise gr.Error("Please upload an input image.")
    if attribute not in ATTRIBUTE_MAP:
        raise gr.Error(
            f"Unknown attribute {attribute!r}. Valid attributes: {', '.join(ATTRIBUTE_MAP)}"
        )

    edit_name, (lo, hi) = ATTRIBUTE_MAP[attribute]
    # Text-driven attributes accept a user-supplied StyleCLIP edit name.
    if attribute.endswith("(text)") and custom_text_edit and custom_text_edit.strip():
        edit_name = custom_text_edit.strip()

    clipped_strength = max(lo, min(hi, strength))
    if clipped_strength != strength:
        logger.info("Clipped strength from %s to %s for %s", strength, clipped_strength, attribute)

    runner = get_runner()

    # The runner works on file paths, so round-trip through a private temp dir.
    with tempfile.TemporaryDirectory() as tmpdir:
        inp_path = os.path.join(tmpdir, "input.jpg")
        out_path = os.path.join(tmpdir, "edited.jpg")
        image.convert("RGB").save(inp_path)

        logger.info("Editing %s with power %s", edit_name, clipped_strength)
        runner.edit(
            orig_img_pth=inp_path,
            editing_name=edit_name,
            edited_power=clipped_strength,
            save_pth=out_path,
            align=align_face,
            use_mask=use_bg_mask,
        )

        # convert() loads the pixels into memory, and the context manager closes
        # the file before the temp dir is deleted.
        with Image.open(out_path) as edited:
            return edited.convert("RGB")
|
|
|
|
|
|
|
|
def build_ui() -> gr.Blocks:
    """Build the Gradio interface for interactive attribute editing.

    Layout: an input column next to an output image. The input column holds the
    image, attribute dropdown, strength slider, alignment and background-mask
    checkboxes, a custom StyleCLIP text box and a run button. Clicking
    "Run edit" calls run_edit. Changing the attribute rescales the slider to that
    attribute's recommended range, keeping the current strength when possible.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    default_attribute = "Smile"
    # Start the slider on the default attribute's range so it agrees with what
    # the change handler produces; run_edit clips to this range anyway.
    default_lo, default_hi = recommended_range(default_attribute)

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("""
**StyleFeatureEditor – Facial Attribute Editing**

Upload a face and apply edits like smile, age, beard, hair style/color, glasses, and makeup.

**Tips:**

- **Beard/Goatee**: Use **negative values** to ADD facial hair, positive values to remove
- **Smile**: Positive values add smile, negative values remove smile
- **Age**: Positive values make older, negative values make younger
- **Glasses**: Positive values add glasses, negative values remove glasses
""")

        with gr.Row():
            with gr.Column():
                inp = gr.Image(type="pil", label="Input face", sources=["upload", "clipboard"])
                attr = gr.Dropdown(
                    choices=list(ATTRIBUTE_MAP.keys()),
                    value=default_attribute,
                    label="Attribute",
                )
                strength = gr.Slider(default_lo, default_hi, value=5, step=0.01, label="Strength (p)")
                align_face = gr.Checkbox(value=True, label="Align face before editing")
                use_bg_mask = gr.Checkbox(value=False, label="Use background mask (reduce artifacts)")
                custom_text = gr.Textbox(
                    value="",
                    label="Custom text edit (StyleCLIP Global Mapper)",
                    placeholder="styleclip_global_a face_a face with black hair_0.18",
                )
                run_btn = gr.Button("Run edit")

            with gr.Column():
                out = gr.Image(type="pil", label="Edited output")

        def _on_attr_change(name: str, current: float | None):
            """Rescale the slider to ``name``'s range, clamping the current value into it."""
            lo, hi = recommended_range(name)
            # `current` is the slider's live value, passed in as an event input.
            # Reading `strength.value` would return the slider's *initial* value,
            # silently resetting the user's choice on every attribute change.
            current = 0 if current is None else current
            return gr.Slider(minimum=lo, maximum=hi, value=max(lo, min(hi, current)))

        attr.change(_on_attr_change, inputs=[attr, strength], outputs=strength)

        run_btn.click(
            fn=run_edit,
            inputs=[inp, attr, strength, align_face, use_bg_mask, custom_text],
            outputs=out,
        )

    return demo
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI, built once at import time and later mounted onto the API at /app.
demo = build_ui()

# FastAPI application that hosts the REST endpoints below.
api = FastAPI(title="Smile Changer API")
|
|
|
|
|
|
|
|
def _require_auth(authorization: str | None = Header(default=None)):
    """Accepts either a static Bearer token (API_AUTH_TOKEN) or a Firebase ID token.

    Returns a dict of auth info if authenticated; raises 401 otherwise:

    * ``{"auth": "static"}`` when the bearer token equals the static token;
    * ``{"auth": "firebase", "claims": ..., "uid": ...}`` when Firebase is
      initialised and ``fb_auth.verify_id_token`` accepts the token.

    Raises:
        HTTPException: 401 if the header is missing or malformed, or if neither
            check accepts the token.
    """
    import secrets

    # NOTE(review): when API_AUTH_TOKEN is unset, the static token falls back to
    # a value hard-coded in this file, so anyone with the source can
    # authenticate. The fallback is kept for compatibility with existing clients;
    # set API_AUTH_TOKEN (or remove the default) in production.
    expected = os.getenv("API_AUTH_TOKEN", "logicgo_123")

    # Header format is "<scheme> <credentials>"; the scheme name is
    # case-insensitive (RFC 7235 section 2.1).
    scheme, _sep, token = (authorization or "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    # Constant-time comparison so response timing does not leak the token. The
    # `expected` guard means an empty configured token (API_AUTH_TOKEN="") can
    # never be matched, e.g. by a bare "Bearer " header.
    if expected and secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        return {"auth": "static"}

    # Firebase is initialised lazily, on the first request that needs it.
    _init_firebase_if_possible()
    if firebase_admin is not None and fb_auth is not None and FIREBASE_APP is not None:
        try:
            claims = fb_auth.verify_id_token(token)
            return {"auth": "firebase", "claims": claims, "uid": claims.get("uid")}
        except Exception as e:
            logger.warning("Firebase token verification failed: %s", e)

    raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
|
|
|
|
@api.get("/")
def root_index():
    """Public index describing the service, the UI mount point and the endpoints.

    The endpoint list and auth note mirror the routes registered in this module.
    Keep them in sync when routes are added.
    """
    return {
        "name": "Smile Changer API",
        "status": "ok",
        "ui": "/app",
        "endpoints": {
            "GET /health": "public health",
            "GET /api/health": "public health (alias)",
            "GET /api/ping": "auth check",
            "GET /api/me": "describe the authenticated caller",
            "GET /api/attributes": "list attributes",
            "POST /api/edit": "generic edit",
            "POST /api/edit/{attribute}": "edit by attribute name",
            "POST /api/{attribute-slug}": "per-attribute shortcut, e.g. /api/smile, /api/beard, /api/curly-hair",
            "POST /api/image-edit": "generic edit (defaults: Smile, no face alignment)",
        },
        "auth": (
            "Authorization: Bearer <token> (API_AUTH_TOKEN or a Firebase ID token) "
            "is required on /api/* except /api/health and /api/image-edit"
        ),
    }
|
|
|
|
|
|
|
|
@api.get("/health")
def health_root():
    """Public liveness probe; requires no authentication."""
    return dict(status="ok")
|
|
|
|
|
|
|
|
@api.get("/api/attributes")
def list_attributes(_: None = Depends(_require_auth)):
    """List every editable attribute with its editing name and strength range (authenticated)."""
    payload = {
        label: {"internal": internal_name, "min": lower, "max": upper}
        for label, (internal_name, (lower, upper)) in ATTRIBUTE_MAP.items()
    }
    return JSONResponse(payload)
|
|
|
|
|
|
|
|
@api.get("/api/health")
def health():
    """Public liveness probe under the /api prefix; same payload as /health."""
    return dict(status="ok")
|
|
|
|
|
|
|
|
@api.get("/api/ping")
def ping(_: None = Depends(_require_auth)):
    """Authenticated no-op; a 200 response means the caller's token was accepted."""
    return dict(status="ok", auth=True)
|
|
|
|
|
|
|
|
@api.get("/api/me")
def me(user=Depends(_require_auth)):
    """Describe the authenticated caller.

    Always returns ``{"mode": <auth mode>}`` where the mode is ``"static"`` or
    ``"firebase"``. Firebase callers also get ``uid`` and a subset of claims
    (email, name, picture, user_id, uid); claims that are missing or ``None``
    are omitted.
    """
    mode = user.get("auth")
    info = {"mode": mode}
    if mode == "firebase":
        token_claims = user.get("claims", {})
        info["uid"] = user.get("uid")
        info["claims"] = {
            key: value
            for key, value in (
                (k, token_claims.get(k))
                for k in ("email", "name", "picture", "user_id", "uid")
            )
            if value is not None
        }
    return JSONResponse(info)
|
|
|
|
|
|
|
|
@api.on_event("startup")
def _self_check():
    """On startup, call /api/health in-process and log the outcome; never raises."""
    try:
        response = TestClient(api).get("/api/health")
        content_type = response.headers.get("content-type", " ")
        body = response.json() if content_type.startswith("application/json") else ""
        logger.info("Self-check /api/health -> %s %s", response.status_code, body)
    except Exception as e:
        logger.error("Self-check failed: %s", e)
|
|
|
|
|
|
|
|
@api.post("/api/edit")
async def api_edit(
    file: UploadFile = File(...),
    attribute: str = Form(...),
    strength: float = Form(5.0),
    align_face: bool = Form(True),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
    _: None = Depends(_require_auth),
):
    """Apply one attribute edit to an uploaded image and stream back a PNG.

    The per-attribute routes also call this coroutine directly, passing every
    form field explicitly.

    Raises:
        HTTPException: 400 for an unknown ``attribute`` or an upload that cannot
            be decoded as an image. Previously both surfaced as 500 errors.
    """
    # Reject bad requests before reading the body or touching the model.
    if attribute not in ATTRIBUTE_MAP:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown attribute {attribute!r}. Valid attributes: {', '.join(ATTRIBUTE_MAP)}",
        )

    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError and truncated files are OSError subclasses;
        # some PIL decoders raise SyntaxError/ValueError for corrupt data.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc

    # NOTE(review): run_edit is synchronous and compute-heavy, so it blocks the
    # event loop for the whole edit. Offloading it (e.g. run_in_threadpool) would
    # let other requests proceed, but would also allow concurrent runner.edit
    # calls -- confirm the runner is thread-safe before changing this.
    result = run_edit(
        image=image,
        attribute=attribute,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
|
|
|
|
|
|
|
|
@api.post("/api/edit/{attribute_name:path}")
async def api_edit_by_attribute(
    attribute_name: str,
    file: UploadFile = File(...),
    strength: float = Form(5.0),
    align_face: bool = Form(True),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
    _: None = Depends(_require_auth),
):
    """Same as ``POST /api/edit``, but the attribute label comes from the URL path.

    The ``:path`` converter lets the parameter contain "/", so labels such as
    ``Mustache/Goatee`` are reachable at ``/api/edit/Mustache/Goatee``. A plain
    ``{attribute_name}`` segment can never match them, because the router sees
    the percent-decoded path. Spaces and parentheses in labels must still be
    percent-encoded by the client. Every URL that matched before still matches.
    """
    return await api_edit(
        file=file,
        attribute=attribute_name,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )
|
|
|
|
|
|
|
|
|
|
|
def _register_attribute_endpoint(path: str, attribute_value: str):
    """Register ``POST <path>`` as an authenticated shortcut for one fixed attribute.

    The generated handler accepts the same form fields as ``/api/edit`` except
    ``attribute``, and forwards them to ``api_edit`` with ``attribute_value``.

    Args:
        path: Route path, e.g. ``"/api/smile"``.
        attribute_value: Key of ``ATTRIBUTE_MAP`` the route edits.

    Returns:
        The registered handler. Existing callers may ignore it.

    Raises:
        ValueError: If ``attribute_value`` is not in ``ATTRIBUTE_MAP``. The typo
            fails at import time instead of as a 500 on every request.
    """
    if attribute_value not in ATTRIBUTE_MAP:
        raise ValueError(f"Cannot register {path}: unknown attribute {attribute_value!r}")

    # Every generated handler is named `_endpoint`, so give each route a distinct
    # name and summary. This lets the OpenAPI docs and url_path_for() tell the
    # routes apart.
    route_name = "edit_" + path.strip("/").replace("/", "_").replace("-", "_")

    @api.post(path, name=route_name, summary=f"Edit attribute: {attribute_value}")
    async def _endpoint(
        file: UploadFile = File(...),
        strength: float = Form(5.0),
        align_face: bool = Form(True),
        use_bg_mask: bool = Form(False),
        custom_text_edit: str = Form(""),
        _: None = Depends(_require_auth),
    ):
        return await api_edit(
            file=file,
            attribute=attribute_value,
            strength=strength,
            align_face=align_face,
            use_bg_mask=use_bg_mask,
            custom_text_edit=custom_text_edit,
        )

    return _endpoint
|
|
|
|
|
|
|
|
# One authenticated POST shortcut per attribute: URL slug -> ATTRIBUTE_MAP label.
for _route_path, _route_attribute in (
    ("/api/smile", "Smile"),
    ("/api/age", "Age"),
    ("/api/female-features", "Female features"),
    ("/api/beard", "Beard"),
    ("/api/mustache-goatee", "Mustache/Goatee"),
    ("/api/glasses", "Glasses"),
    ("/api/makeup", "Makeup"),
    ("/api/curly-hair", "Curly hair"),
    ("/api/afro", "Afro"),
    ("/api/orange-hair-text", "Orange hair (text)"),
    ("/api/blonde-hair-text", "Blonde hair (text)"),
):
    _register_attribute_endpoint(_route_path, _route_attribute)
# Don't leave the loop variables behind as module globals.
del _route_path, _route_attribute
|
|
|
|
|
|
|
|
@api.post("/api/image-edit")
async def api_image_edit(
    file: UploadFile = File(...),
    attribute: str = Form("Smile"),
    strength: float = Form(5.0),
    align_face: bool = Form(False),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
):
    """Edit an uploaded image and stream back a PNG.

    Defaults differ from ``/api/edit``: the attribute defaults to ``"Smile"``
    and face alignment is off. Unlike the other edit routes, there is no
    ``Depends(_require_auth)`` here.

    Raises:
        HTTPException: 400 for an unknown ``attribute`` or an upload that cannot
            be decoded as an image. Previously both surfaced as 500 errors.
    """
    # NOTE(review): this route runs a full (expensive) edit with no
    # authentication, unlike every other edit route. It is left as-is because
    # existing clients may depend on it; confirm this is intentional.
    if attribute not in ATTRIBUTE_MAP:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown attribute {attribute!r}. Valid attributes: {', '.join(ATTRIBUTE_MAP)}",
        )

    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc

    result = run_edit(
        image=image,
        attribute=attribute,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
# Mount the Gradio UI onto the FastAPI app under /app; `app` is what the
# __main__ entry point passes to uvicorn.
app = gr.mount_gradio_app(api, demo, path="/app")
|
|
|
|
|
|
|
|
@GPU()
def _warmup_gpu():
    """No-op function decorated with ``spaces.GPU``; returns ``"ok"``.

    Nothing in this file calls it. Presumably it exists so the Hugging Face
    Spaces runtime sees at least one GPU-decorated function at startup --
    TODO confirm before removing.
    """
    return "ok"
|
|
|
|
|
if __name__ == "__main__":
    # Local entry point: serve the combined FastAPI + Gradio app.
    import uvicorn

    # PORT allows overriding the listen port (e.g. by a hosting platform);
    # 7860 stays the default.
    port = int(os.getenv("PORT", "7860"))
    try:
        uvicorn.run(app, host="0.0.0.0", port=port)
    except Exception:
        # Log the full traceback and exit non-zero so supervisors notice the
        # failure. Printing and falling through would end the process with
        # status 0, as if it had succeeded.
        logger.exception("Failed to start uvicorn")
        raise SystemExit(1)
|
|
|
|
|
|
|
|
|