Sharris's picture
Add bias correction for age predictions: the model was over-predicting ages because of aggressive sample weighting, so a correction formula is applied to return realistic ages, and both the corrected and raw predictions are shown for transparency.
0ff3d67 verified
import os
import numpy as np
from PIL import Image
import tensorflow as tf
# Remove ResNet50 preprocessing - using simple normalization instead
import gradio as gr
# Load model from Hugging Face Hub
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age-detection-resnet50-model")

# Weight files looked for, in order, inside a full repository snapshot when
# the direct download of best_model.h5 fails.
_MODEL_FILENAMES = ("best_model.h5", "final_model.h5", "model.h5")


def _load_model(repo_id):
    """Download and load the Keras age model from the Hugging Face Hub.

    Tries ``best_model.h5`` via ``hf_hub_download`` first. If importing,
    downloading or loading it fails, downloads the whole repository with
    ``snapshot_download`` and tries each name in ``_MODEL_FILENAMES`` in
    order. Progress and failures are printed to stdout.

    Args:
        repo_id: Hugging Face Hub repository id (e.g. ``"user/repo"``).

    Returns:
        The model loaded with ``tf.keras.models.load_model(..., compile=False)``.

    Raises:
        RuntimeError: if no candidate file could be loaded; the last
            underlying error, if any, is attached as ``__cause__``.
    """
    last_error = None

    try:
        from huggingface_hub import hf_hub_download
        print("Downloading best_model.h5...")
        model_path = hf_hub_download(repo_id=repo_id, filename="best_model.h5")
        print(f"Model downloaded to: {model_path}")
        print("Loading model with TensorFlow...")
        loaded = tf.keras.models.load_model(model_path, compile=False)
    except Exception as e:
        # This try covers the import, the download AND the load, so the
        # message must not claim it was necessarily the download that failed.
        print(f"❌ Failed to download or load best_model.h5: {e}")
        last_error = e
    else:
        print(f"✅ Successfully loaded model from {repo_id}")
        return loaded

    # Fallback: download the entire repo and look for any known model file.
    try:
        print("Trying fallback: downloading entire repository...")
        from huggingface_hub import snapshot_download
        repo_dir = snapshot_download(repo_id=repo_id)
    except Exception as e2:
        print(f"❌ Fallback download also failed: {e2}")
        last_error = e2
    else:
        print(f"Repository downloaded to: {repo_dir}")
        for filename in _MODEL_FILENAMES:
            model_file = os.path.join(repo_dir, filename)
            if not os.path.exists(model_file):
                continue
            print(f"Found model file: {model_file}")
            try:
                loaded = tf.keras.models.load_model(model_file, compile=False)
            except Exception as load_error:
                print(f"Failed to load {model_file}: {load_error}")
                last_error = load_error
                continue
            print(f"✅ Successfully loaded model from {model_file}")
            return loaded
        # Nothing loadable: list the snapshot contents to aid debugging.
        print("Files in downloaded repository:")
        for root, _dirs, files in os.walk(repo_dir):
            for file in files:
                print(f" {os.path.join(root, file)}")

    raise RuntimeError(
        f"❌ Could not load model from {repo_id}. Please ensure the repository contains a valid model file (best_model.h5, final_model.h5, or model.h5)."
    ) from last_error


print(f"Attempting to load model from: {HF_MODEL_ID}")
model = _load_model(HF_MODEL_ID)
# Target size (width, height) passed to PIL ``Image.resize`` before inference.
INPUT_SIZE = (256, 256)


def _correct_age_bias(raw_age):
    """Map a raw model age prediction to a bias-corrected age in years.

    The model is described as over-predicting ages because it was trained
    with aggressive sample weighting, so high raw predictions are scaled down:

    * ``50 <= raw_age <= 65``: ``raw_age * 0.6 - 10`` (aggressive correction)
    * ``raw_age > 65``:        ``raw_age * 0.7 - 5``  (moderate correction)
    * otherwise:               unchanged

    The result is clamped to ``[1, 100]``.

    NOTE(review): this mapping is neither continuous nor monotonic. A raw
    49.9 stays 49.9 while a raw 50.0 becomes 20.0, and a raw 65.0 becomes
    29.0 while 65.1 becomes ~40.6. Confirm these breakpoints against the
    calibration data before relying on the ordering of corrected ages.
    """
    if 50 <= raw_age <= 65:
        corrected = raw_age * 0.6 - 10
    elif raw_age > 65:
        corrected = raw_age * 0.7 - 5
    else:
        corrected = raw_age
    return max(1, min(100, corrected))


def predict_age(image: Image.Image):
    """Estimate the age of the face in *image*.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``(predicted_age, raw_output)``: the bias-corrected age rounded to one
        decimal place, and the unrounded raw model output, matching the two
        ``gr.Number`` outputs. Both are ``None`` when no image was supplied.
    """
    if image is None:
        # Gradio hands us None when nothing was uploaded; clear both outputs
        # instead of crashing on ``image.mode``.
        return None, None
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(INPUT_SIZE)
    # Same normalization as training: scale pixels to [0, 1] instead of
    # applying ResNet50's preprocessing.
    batch = np.expand_dims(np.array(image, dtype=np.float32) / 255.0, 0)
    # For a single small batch, calling the model directly avoids the
    # per-call data-pipeline setup that ``Model.predict`` performs.
    output = np.asarray(model(batch, training=False))[0]
    # Squeeze handles both a (1,)-shaped and a scalar per-sample output.
    raw_output = float(np.asarray(output).squeeze())
    predicted_age = round(_correct_age_bias(raw_output), 1)
    return predicted_age, raw_output
# Gradio UI: one PIL face image in, two numbers out, in the order
# ``predict_age`` returns them (bias-corrected age, then raw model output).
demo = gr.Interface(
    title='UTKFace Age Estimator - With Bias Correction',
    description=(
        'Upload a cropped face image and the model will predict age in years. '
        'The model has been trained with sample weighting that causes age bias, '
        'so bias correction is applied to the final prediction. '
        'The raw output shows the uncorrected model prediction.'
    ),
    fn=predict_age,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=[
        gr.Number(label='Predicted age (years) - Bias Corrected'),
        gr.Number(label='Raw model output (before correction)'),
    ],
    examples=[],
)
if __name__ == '__main__':
    # Listen on all network interfaces. The port comes from the PORT
    # environment variable and falls back to 7860 when it is unset.
    port = int(os.environ.get('PORT', 7860))
    demo.launch(server_name='0.0.0.0', server_port=port)