tscr-369 committed
Commit c717ec4 · verified · 1 Parent(s): c60a7c6

Update main.py

Files changed (1):
  main.py +87 -35
main.py CHANGED
@@ -8,28 +8,48 @@ from pydantic import BaseModel
 from typing import Optional, Dict, Any
 import json
 import re
+from contextlib import asynccontextmanager
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from NatureLM.models import NatureLM
 from NatureLM.infer import Pipeline
 
-app = FastAPI(title="NatureLM Audio Analysis API")
+# Set up cache directories BEFORE importing any HuggingFace modules
+cache_base = "/app/.cache"
+os.environ['HF_HOME'] = cache_base
+os.environ['TRANSFORMERS_CACHE'] = f"{cache_base}/transformers"
+os.environ['HF_DATASETS_CACHE'] = f"{cache_base}/datasets"
+os.environ['HF_HUB_CACHE'] = f"{cache_base}/hub"
 
-# CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# Ensure cache directories exist with proper permissions
+cache_dirs = [
+    cache_base,
+    os.environ['TRANSFORMERS_CACHE'],
+    os.environ['HF_DATASETS_CACHE'],
+    os.environ['HF_HUB_CACHE']
+]
 
-# Initialize NatureLM model
+for cache_dir in cache_dirs:
+    os.makedirs(cache_dir, exist_ok=True)
+    # Ensure write permissions
+    os.chmod(cache_dir, 0o755)
+
+# Initialize global variables
 model = None
 pipeline = None
 
 def load_model():
+    """Load the NatureLM model with proper error handling"""
     global model, pipeline
     try:
+        print("🔄 Loading NatureLM model...")
+        print(f"📁 Using cache directory: {os.environ.get('HF_HOME', cache_base)}")
+
+        # Verify cache directories are writable
+        for cache_dir in cache_dirs:
+            if not os.access(cache_dir, os.W_OK):
+                raise PermissionError(f"Cache directory {cache_dir} is not writable")
+            print(f"✅ Cache directory {cache_dir} is writable")
+
         # Load NatureLM-audio model from HuggingFace
         model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
         model = model.eval()
@@ -41,12 +61,41 @@ def load_model():
         print("✅ NatureLM model loaded successfully")
     except Exception as e:
         print(f"❌ Error loading model: {e}")
+        print(f"🔍 Cache directory status:")
+        for cache_dir in cache_dirs:
+            exists = os.path.exists(cache_dir)
+            writable = os.access(cache_dir, os.W_OK) if exists else False
+            print(f" {cache_dir}: {'✅' if exists and writable else '❌'} (exists: {exists}, writable: {writable})")
         raise e
 
-# Load model on startup
-@app.on_event("startup")
-async def startup_event():
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Lifespan context manager for FastAPI"""
+    # Startup
+    print("🚀 Starting up Animal Whisper AI Decoder...")
     load_model()
+    print("✅ Application startup complete")
+
+    yield
+
+    # Shutdown
+    print("🛑 Shutting down Animal Whisper AI Decoder...")
+
+app = FastAPI(
+    title="NatureLM Audio Analysis API",
+    description="AI-powered animal sound analysis using NatureLM",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+# CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 class AnalysisResponse(BaseModel):
     species: str
@@ -249,47 +298,50 @@ async def analyze_audio(file: UploadFile = File(...)):
         # Combine results
         combined_response = " ".join(results)
 
-        # Extract information from response
+        # Extract information
         confidence_scores = extract_confidence_from_response(combined_response)
         species_info = extract_species_info(combined_response)
 
-        # Calculate overall confidence based on response quality
-        overall_confidence = max(
-            confidence_scores["model_confidence"],
-            confidence_scores["llama_confidence"],
-            75.0 if species_info["common_name"] else 50.0  # Higher confidence if species identified
-        )
-
-        # Clean up temp file
+        # Clean up temporary file
         os.remove(temp_path)
 
         return AnalysisResponse(
             species=species_info["common_name"] or "Unknown species",
             interpretation=combined_response,
-            confidence=overall_confidence,
-            signal_type=species_info["signal_type"] or "Vocalization",
+            confidence=confidence_scores["model_confidence"] / 100.0,
+            signal_type=species_info["signal_type"] or "Unknown",
            common_name=species_info["common_name"] or "Unknown",
             scientific_name=species_info["scientific_name"] or "Unknown",
-            habitat=species_info["habitat"] or "Unknown habitat",
-            behavior=species_info["behavior"] or "Unknown behavior",
+            habitat=species_info["habitat"] or "Unknown",
+            behavior=species_info["behavior"] or "Unknown",
             audio_characteristics=audio_chars,
-            model_confidence=confidence_scores["model_confidence"],
-            llama_confidence=confidence_scores["llama_confidence"],
-            additional_insights=combined_response,
-            cluster_group="NatureLM Analysis"
+            model_confidence=confidence_scores["model_confidence"] / 100.0,
+            llama_confidence=confidence_scores["llama_confidence"] / 100.0,
+            additional_insights="Analysis completed successfully",
+            cluster_group="NatureLM"
         )
 
     except Exception as e:
-        # Clean up temp file if it exists
-        if os.path.exists(temp_path):
+        # Clean up temporary file if it exists
+        if 'temp_path' in locals() and os.path.exists(temp_path):
             os.remove(temp_path)
-
         raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
 
 @app.get("/health")
 async def health_check():
-    return {"status": "healthy", "model_loaded": model is not None}
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None,
+        "pipeline_ready": pipeline is not None,
+        "cache_directories": {
+            "hf_home": os.environ.get('HF_HOME'),
+            "transformers_cache": os.environ.get('TRANSFORMERS_CACHE'),
+            "datasets_cache": os.environ.get('HF_DATASETS_CACHE'),
+            "hub_cache": os.environ.get('HF_HUB_CACHE')
+        }
+    }
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
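
A quick way to exercise the new lifespan startup and the expanded /health payload is FastAPI's TestClient, which runs the lifespan hooks when used as a context manager. The following is a minimal smoke-test sketch, not part of the commit; it assumes the file above is importable as `main` and that the NatureLM checkpoint can actually be downloaded into /app/.cache on first run:

    # smoke_test.py - illustrative sketch only.
    # Entering the TestClient context triggers the lifespan startup,
    # which calls load_model() and fetches the NatureLM checkpoint.
    from fastapi.testclient import TestClient

    from main import app

    with TestClient(app) as client:
        resp = client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert "model_loaded" in body
        # Expect model_loaded/pipeline_ready to be True once startup succeeds
        print(body)

Against a running deployment, the equivalent check is simply `curl http://localhost:8000/health`.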