Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -8,28 +8,48 @@ from pydantic import BaseModel
|
|
| 8 |
from typing import Optional, Dict, Any
|
| 9 |
import json
|
| 10 |
import re
|
|
|
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
from NatureLM.models import NatureLM
|
| 13 |
from NatureLM.infer import Pipeline
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
)
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
model = None
|
| 28 |
pipeline = None
|
| 29 |
|
| 30 |
def load_model():
|
|
|
|
| 31 |
global model, pipeline
|
| 32 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Load NatureLM-audio model from HuggingFace
|
| 34 |
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
|
| 35 |
model = model.eval()
|
|
@@ -41,12 +61,41 @@ def load_model():
|
|
| 41 |
print("β
NatureLM model loaded successfully")
|
| 42 |
except Exception as e:
|
| 43 |
print(f"β Error loading model: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
raise e
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
class AnalysisResponse(BaseModel):
|
| 52 |
species: str
|
|
@@ -249,47 +298,50 @@ async def analyze_audio(file: UploadFile = File(...)):
|
|
| 249 |
# Combine results
|
| 250 |
combined_response = " ".join(results)
|
| 251 |
|
| 252 |
-
# Extract information
|
| 253 |
confidence_scores = extract_confidence_from_response(combined_response)
|
| 254 |
species_info = extract_species_info(combined_response)
|
| 255 |
|
| 256 |
-
#
|
| 257 |
-
overall_confidence = max(
|
| 258 |
-
confidence_scores["model_confidence"],
|
| 259 |
-
confidence_scores["llama_confidence"],
|
| 260 |
-
75.0 if species_info["common_name"] else 50.0 # Higher confidence if species identified
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
# Clean up temp file
|
| 264 |
os.remove(temp_path)
|
| 265 |
|
| 266 |
return AnalysisResponse(
|
| 267 |
species=species_info["common_name"] or "Unknown species",
|
| 268 |
interpretation=combined_response,
|
| 269 |
-
confidence=
|
| 270 |
-
signal_type=species_info["signal_type"] or "
|
| 271 |
common_name=species_info["common_name"] or "Unknown",
|
| 272 |
scientific_name=species_info["scientific_name"] or "Unknown",
|
| 273 |
-
habitat=species_info["habitat"] or "Unknown
|
| 274 |
-
behavior=species_info["behavior"] or "Unknown
|
| 275 |
audio_characteristics=audio_chars,
|
| 276 |
-
model_confidence=confidence_scores["model_confidence"],
|
| 277 |
-
llama_confidence=confidence_scores["llama_confidence"],
|
| 278 |
-
additional_insights=
|
| 279 |
-
cluster_group="NatureLM
|
| 280 |
)
|
| 281 |
|
| 282 |
except Exception as e:
|
| 283 |
-
# Clean up
|
| 284 |
-
if os.path.exists(temp_path):
|
| 285 |
os.remove(temp_path)
|
| 286 |
-
|
| 287 |
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 288 |
|
| 289 |
@app.get("/health")
|
| 290 |
async def health_check():
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
if __name__ == "__main__":
|
| 294 |
import uvicorn
|
| 295 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
| 8 |
from typing import Optional, Dict, Any
|
| 9 |
import json
|
| 10 |
import re
|
| 11 |
+
from contextlib import asynccontextmanager
|
| 12 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 13 |
from NatureLM.models import NatureLM
|
| 14 |
from NatureLM.infer import Pipeline
|
| 15 |
|
| 16 |
+
# Set up cache directories BEFORE importing any HuggingFace modules
|
| 17 |
+
cache_base = "/app/.cache"
|
| 18 |
+
os.environ['HF_HOME'] = cache_base
|
| 19 |
+
os.environ['TRANSFORMERS_CACHE'] = f"{cache_base}/transformers"
|
| 20 |
+
os.environ['HF_DATASETS_CACHE'] = f"{cache_base}/datasets"
|
| 21 |
+
os.environ['HF_HUB_CACHE'] = f"{cache_base}/hub"
|
| 22 |
|
| 23 |
+
# Ensure cache directories exist with proper permissions
|
| 24 |
+
cache_dirs = [
|
| 25 |
+
cache_base,
|
| 26 |
+
os.environ['TRANSFORMERS_CACHE'],
|
| 27 |
+
os.environ['HF_DATASETS_CACHE'],
|
| 28 |
+
os.environ['HF_HUB_CACHE']
|
| 29 |
+
]
|
|
|
|
| 30 |
|
| 31 |
+
for cache_dir in cache_dirs:
|
| 32 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 33 |
+
# Ensure write permissions
|
| 34 |
+
os.chmod(cache_dir, 0o755)
|
| 35 |
+
|
| 36 |
+
# Initialize global variables
|
| 37 |
model = None
|
| 38 |
pipeline = None
|
| 39 |
|
| 40 |
def load_model():
|
| 41 |
+
"""Load the NatureLM model with proper error handling"""
|
| 42 |
global model, pipeline
|
| 43 |
try:
|
| 44 |
+
print("π Loading NatureLM model...")
|
| 45 |
+
print(f"π Using cache directory: {os.environ.get('HF_HOME', cache_base)}")
|
| 46 |
+
|
| 47 |
+
# Verify cache directories are writable
|
| 48 |
+
for cache_dir in cache_dirs:
|
| 49 |
+
if not os.access(cache_dir, os.W_OK):
|
| 50 |
+
raise PermissionError(f"Cache directory {cache_dir} is not writable")
|
| 51 |
+
print(f"β
Cache directory {cache_dir} is writable")
|
| 52 |
+
|
| 53 |
# Load NatureLM-audio model from HuggingFace
|
| 54 |
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
|
| 55 |
model = model.eval()
|
|
|
|
| 61 |
print("β
NatureLM model loaded successfully")
|
| 62 |
except Exception as e:
|
| 63 |
print(f"β Error loading model: {e}")
|
| 64 |
+
print(f"π Cache directory status:")
|
| 65 |
+
for cache_dir in cache_dirs:
|
| 66 |
+
exists = os.path.exists(cache_dir)
|
| 67 |
+
writable = os.access(cache_dir, os.W_OK) if exists else False
|
| 68 |
+
print(f" {cache_dir}: {'β
' if exists and writable else 'β'} (exists: {exists}, writable: {writable})")
|
| 69 |
raise e
|
| 70 |
|
| 71 |
+
@asynccontextmanager
|
| 72 |
+
async def lifespan(app: FastAPI):
|
| 73 |
+
"""Lifespan context manager for FastAPI"""
|
| 74 |
+
# Startup
|
| 75 |
+
print("π Starting up Animal Whisper AI Decoder...")
|
| 76 |
load_model()
|
| 77 |
+
print("β
Application startup complete")
|
| 78 |
+
|
| 79 |
+
yield
|
| 80 |
+
|
| 81 |
+
# Shutdown
|
| 82 |
+
print("π Shutting down Animal Whisper AI Decoder...")
|
| 83 |
+
|
| 84 |
+
app = FastAPI(
|
| 85 |
+
title="NatureLM Audio Analysis API",
|
| 86 |
+
description="AI-powered animal sound analysis using NatureLM",
|
| 87 |
+
version="1.0.0",
|
| 88 |
+
lifespan=lifespan
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# CORS middleware
|
| 92 |
+
app.add_middleware(
|
| 93 |
+
CORSMiddleware,
|
| 94 |
+
allow_origins=["*"],
|
| 95 |
+
allow_credentials=True,
|
| 96 |
+
allow_methods=["*"],
|
| 97 |
+
allow_headers=["*"],
|
| 98 |
+
)
|
| 99 |
|
| 100 |
class AnalysisResponse(BaseModel):
|
| 101 |
species: str
|
|
|
|
| 298 |
# Combine results
|
| 299 |
combined_response = " ".join(results)
|
| 300 |
|
| 301 |
+
# Extract information
|
| 302 |
confidence_scores = extract_confidence_from_response(combined_response)
|
| 303 |
species_info = extract_species_info(combined_response)
|
| 304 |
|
| 305 |
+
# Clean up temporary file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
os.remove(temp_path)
|
| 307 |
|
| 308 |
return AnalysisResponse(
|
| 309 |
species=species_info["common_name"] or "Unknown species",
|
| 310 |
interpretation=combined_response,
|
| 311 |
+
confidence=confidence_scores["model_confidence"] / 100.0,
|
| 312 |
+
signal_type=species_info["signal_type"] or "Unknown",
|
| 313 |
common_name=species_info["common_name"] or "Unknown",
|
| 314 |
scientific_name=species_info["scientific_name"] or "Unknown",
|
| 315 |
+
habitat=species_info["habitat"] or "Unknown",
|
| 316 |
+
behavior=species_info["behavior"] or "Unknown",
|
| 317 |
audio_characteristics=audio_chars,
|
| 318 |
+
model_confidence=confidence_scores["model_confidence"] / 100.0,
|
| 319 |
+
llama_confidence=confidence_scores["llama_confidence"] / 100.0,
|
| 320 |
+
additional_insights="Analysis completed successfully",
|
| 321 |
+
cluster_group="NatureLM"
|
| 322 |
)
|
| 323 |
|
| 324 |
except Exception as e:
|
| 325 |
+
# Clean up temporary file if it exists
|
| 326 |
+
if 'temp_path' in locals() and os.path.exists(temp_path):
|
| 327 |
os.remove(temp_path)
|
|
|
|
| 328 |
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 329 |
|
| 330 |
@app.get("/health")
|
| 331 |
async def health_check():
|
| 332 |
+
"""Health check endpoint"""
|
| 333 |
+
return {
|
| 334 |
+
"status": "healthy",
|
| 335 |
+
"model_loaded": model is not None,
|
| 336 |
+
"pipeline_ready": pipeline is not None,
|
| 337 |
+
"cache_directories": {
|
| 338 |
+
"hf_home": os.environ.get('HF_HOME'),
|
| 339 |
+
"transformers_cache": os.environ.get('TRANSFORMERS_CACHE'),
|
| 340 |
+
"datasets_cache": os.environ.get('HF_DATASETS_CACHE'),
|
| 341 |
+
"hub_cache": os.environ.get('HF_HUB_CACHE')
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
|
| 345 |
if __name__ == "__main__":
|
| 346 |
import uvicorn
|
| 347 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|