Sharris commited on
Commit
cb7980f
Β·
verified Β·
1 Parent(s): 4bfacc0

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ from tensorflow.keras.applications.resnet50 import preprocess_input
6
+ import gradio as gr
7
+
8
+ # Load model from Hugging Face Hub
9
+ model = None
10
+ HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age-detection-resnet50-model")
11
+
12
+ print(f"Attempting to load model from: {HF_MODEL_ID}")
13
+
14
+ try:
15
+ from huggingface_hub import hf_hub_download
16
+ print("Downloading best_model.h5...")
17
+ model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="best_model.h5")
18
+ print(f"Model downloaded to: {model_path}")
19
+
20
+ print("Loading model with TensorFlow...")
21
+ model = tf.keras.models.load_model(model_path, compile=False)
22
+ print(f"βœ… Successfully loaded model from {HF_MODEL_ID}")
23
+
24
+ except Exception as e:
25
+ print(f"❌ Failed to download best_model.h5: {e}")
26
+
27
+ # Fallback: try to download entire repo and look for model files
28
+ try:
29
+ print("Trying fallback: downloading entire repository...")
30
+ from huggingface_hub import snapshot_download
31
+ repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
32
+ print(f"Repository downloaded to: {repo_dir}")
33
+
34
+ # Look for model files in the downloaded repo
35
+ possible_files = ["best_model.h5", "final_model.h5", "model.h5"]
36
+ for filename in possible_files:
37
+ model_file = os.path.join(repo_dir, filename)
38
+ if os.path.exists(model_file):
39
+ print(f"Found model file: {model_file}")
40
+ try:
41
+ model = tf.keras.models.load_model(model_file, compile=False)
42
+ print(f"βœ… Successfully loaded model from {model_file}")
43
+ break
44
+ except Exception as load_error:
45
+ print(f"Failed to load {model_file}: {load_error}")
46
+ continue
47
+
48
+ if model is None:
49
+ # List all files in the repo for debugging
50
+ import os
51
+ print("Files in downloaded repository:")
52
+ for root, dirs, files in os.walk(repo_dir):
53
+ for file in files:
54
+ print(f" {os.path.join(root, file)}")
55
+
56
+ except Exception as e2:
57
+ print(f"❌ Fallback download also failed: {e2}")
58
+
59
+ if model is None:
60
+ raise RuntimeError(
61
+ f"❌ Could not load model from {HF_MODEL_ID}. Please ensure the repository contains a valid model file (best_model.h5, final_model.h5, or model.h5)."
62
+ )
63
+
64
+ INPUT_SIZE = (256, 256)
65
+
66
+
67
+ def predict_age(image: Image.Image):
68
+ if image.mode != 'RGB':
69
+ image = image.convert('RGB')
70
+ image = image.resize(INPUT_SIZE)
71
+ arr = np.array(image).astype(np.float32)
72
+ arr = preprocess_input(arr)
73
+ arr = np.expand_dims(arr, 0)
74
+
75
+ pred = model.predict(arr)[0]
76
+ # Ensure scalar
77
+ if hasattr(pred, '__len__'):
78
+ pred = float(np.asarray(pred).squeeze())
79
+ else:
80
+ pred = float(pred)
81
+
82
+ return {
83
+ "predicted_age": round(pred, 2),
84
+ "raw_output": float(pred)
85
+ }
86
+
87
+
88
+ demo = gr.Interface(
89
+ fn=predict_age,
90
+ inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
91
+ outputs=[
92
+ gr.Number(label='Predicted age (years)'),
93
+ gr.Number(label='Raw model output')
94
+ ],
95
+ examples=[],
96
+ title='UTKFace Age Estimator',
97
+ description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download a SavedModel from the Hub.'
98
+ )
99
+
100
+ if __name__ == '__main__':
101
+ demo.launch(server_name='0.0.0.0', server_port=int(os.environ.get('PORT', 7860)))