tscr-369 committed
Commit 83a2db2 · verified · 1 Parent(s): 0229d5e

Update main.py

Files changed (1)
  1. main.py +298 -25
main.py CHANGED
@@ -1,17 +1,19 @@
  import os
- from fastapi import FastAPI, File, UploadFile
+ import torch
+ import librosa
+ import numpy as np
+ from fastapi import FastAPI, File, UploadFile, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
- from huggingface_hub import login
- from NatureLM.models import NatureLM
- from NatureLM.infer import Pipeline
- import tempfile
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any
+ import json
+ import re
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from naturelm_audio import NatureLMAudio

- # Authenticate with HuggingFace to access gated models
- login(token=os.environ.get("HF_TOKEN"))
+ app = FastAPI(title="NatureLM Audio Analysis API")

- app = FastAPI()
-
- # Allow CORS for all origins (for frontend integration)
+ # CORS middleware
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -20,20 +22,291 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # Load the model once at startup
- model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio").eval()
- pipeline = Pipeline(model=model)
+ # Initialize NatureLM model
+ model = None
+ tokenizer = None
+
+ def load_model():
+     global model, tokenizer
+     try:
+         # Load NatureLM-audio model
+         model = NatureLMAudio.from_pretrained("NatureLM/NatureLM-audio")
+         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+         # Set model to evaluation mode
+         model.eval()
+         if torch.cuda.is_available():
+             model = model.cuda()
+         print("✅ NatureLM model loaded successfully")
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         raise e
+
+ # Load model on startup
+ @app.on_event("startup")
+ async def startup_event():
+     load_model()
+
+ class AnalysisResponse(BaseModel):
+     species: str
+     interpretation: str
+     confidence: float
+     signal_type: str
+     common_name: str
+     scientific_name: str
+     habitat: str
+     behavior: str
+     audio_characteristics: Dict[str, Any]
+     model_confidence: float
+     llama_confidence: float
+     additional_insights: str
+     cluster_group: str
+
+ def extract_confidence_from_response(response_text: str) -> Dict[str, float]:
+     """Extract confidence scores from NatureLM response"""
+     confidence_scores = {
+         "model_confidence": 0.0,
+         "llama_confidence": 0.0
+     }
+
+     # Look for confidence patterns in the response
+     confidence_patterns = [
+         r"confidence[:\s]*(\d+(?:\.\d+)?)",
+         r"certainty[:\s]*(\d+(?:\.\d+)?)",
+         r"(\d+(?:\.\d+)?)%?\s*confidence",
+         r"confidence\s*level[:\s]*(\d+(?:\.\d+)?)"
+     ]
+
+     for pattern in confidence_patterns:
+         matches = re.findall(pattern, response_text.lower())
+         if matches:
+             try:
+                 confidence_scores["model_confidence"] = float(matches[0])
+                 confidence_scores["llama_confidence"] = float(matches[0])
+                 break
+             except ValueError:
+                 continue
+
+     return confidence_scores

- @app.post("/analyze")
+ def extract_species_info(response_text: str) -> Dict[str, str]:
+     """Extract detailed species information from NatureLM response"""
+     info = {
+         "common_name": "",
+         "scientific_name": "",
+         "habitat": "",
+         "behavior": "",
+         "signal_type": ""
+     }
+
+     # Extract common name
+     common_patterns = [
+         r"common name[:\s]*([A-Za-z\s]+)",
+         r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+\(common\)",
+         r"species[:\s]*([A-Za-z\s]+)"
+     ]
+
+     for pattern in common_patterns:
+         match = re.search(pattern, response_text, re.IGNORECASE)
+         if match:
+             info["common_name"] = match.group(1).strip()
+             break
+
+     # Extract scientific name
+     sci_patterns = [
+         r"scientific name[:\s]*([A-Z][a-z]+\s+[a-z]+)",
+         r"([A-Z][a-z]+\s+[a-z]+)\s+\(scientific\)",
+         r"genus[:\s]*([A-Z][a-z]+)\s+species[:\s]*([a-z]+)"
+     ]
+
+     for pattern in sci_patterns:
+         match = re.search(pattern, response_text, re.IGNORECASE)
+         if match:
+             if len(match.groups()) == 2:
+                 info["scientific_name"] = f"{match.group(1)} {match.group(2)}"
+             else:
+                 info["scientific_name"] = match.group(1).strip()
+             break
+
+     # Extract signal type
+     signal_patterns = [
+         r"signal type[:\s]*([A-Za-z\s]+)",
+         r"call type[:\s]*([A-Za-z\s]+)",
+         r"vocalization[:\s]*([A-Za-z\s]+)",
+         r"sound type[:\s]*([A-Za-z\s]+)"
+     ]
+
+     for pattern in signal_patterns:
+         match = re.search(pattern, response_text, re.IGNORECASE)
+         if match:
+             info["signal_type"] = match.group(1).strip()
+             break
+
+     # Extract habitat
+     habitat_patterns = [
+         r"habitat[:\s]*([A-Za-z\s,]+)",
+         r"environment[:\s]*([A-Za-z\s,]+)",
+         r"found in[:\s]*([A-Za-z\s,]+)"
+     ]
+
+     for pattern in habitat_patterns:
+         match = re.search(pattern, response_text, re.IGNORECASE)
+         if match:
+             info["habitat"] = match.group(1).strip()
+             break
+
+     # Extract behavior
+     behavior_patterns = [
+         r"behavior[:\s]*([A-Za-z\s,]+)",
+         r"purpose[:\s]*([A-Za-z\s,]+)",
+         r"function[:\s]*([A-Za-z\s,]+)"
+     ]
+
+     for pattern in behavior_patterns:
+         match = re.search(pattern, response_text, re.IGNORECASE)
+         if match:
+             info["behavior"] = match.group(1).strip()
+             break
+
+     return info
+
+ def analyze_audio_characteristics(audio_path: str) -> Dict[str, Any]:
+     """Analyze audio characteristics using librosa"""
+     try:
+         # Load audio file
+         y, sr = librosa.load(audio_path, sr=None)
+
+         # Calculate audio features
+         duration = librosa.get_duration(y=y, sr=sr)
+
+         # Spectral features
+         spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
+         spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
+
+         # MFCC features
+         mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
+
+         # Pitch features
+         pitches, magnitudes = librosa.piptrack(y=y, sr=sr)
+
+         # Rhythm features
+         tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
+
+         # Energy features
+         rms = librosa.feature.rms(y=y)[0]
+
+         characteristics = {
+             "duration_seconds": float(duration),
+             "sample_rate": int(sr),
+             "tempo_bpm": float(tempo),
+             "mean_spectral_centroid": float(np.mean(spectral_centroids)),
+             "mean_spectral_rolloff": float(np.mean(spectral_rolloff)),
+             "mean_rms_energy": float(np.mean(rms)),
+             "mfcc_mean": [float(x) for x in np.mean(mfccs, axis=1)],
+             "pitch_range": {
+                 "min": float(np.min(pitches[magnitudes > 0.1]) if np.any(magnitudes > 0.1) else 0),
+                 "max": float(np.max(pitches[magnitudes > 0.1]) if np.any(magnitudes > 0.1) else 0),
+                 "mean": float(np.mean(pitches[magnitudes > 0.1]) if np.any(magnitudes > 0.1) else 0)
+             }
+         }
+
+         return characteristics
+     except Exception as e:
+         print(f"Error analyzing audio characteristics: {e}")
+         return {}
+
+ @app.post("/analyze", response_model=AnalysisResponse)
  async def analyze_audio(file: UploadFile = File(...)):
-     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
-         tmp.write(await file.read())
-         tmp_path = tmp.name
-     results = pipeline([tmp_path], ["What is the common name for the focal species in the audio? Answer:"])
-     return {
-         "species": results[0],  # Adjust parsing as needed
-         "interpretation": "TODO: parse from model output",
-         "confidence": 90,  # TODO: parse or estimate
-         "clusterGroup": "TODO: parse from model output",
-         "additionalInfo": "TODO: parse from model output"
-     }
+     try:
+         # Save uploaded file temporarily
+         temp_path = f"/tmp/{file.filename}"
+         with open(temp_path, "wb") as buffer:
+             content = await file.read()
+             buffer.write(content)
+
+         # Analyze audio characteristics
+         audio_chars = analyze_audio_characteristics(temp_path)
+
+         # Create enhanced prompt for NatureLM
+         enhanced_prompt = f"""
+         Analyze this animal audio recording and provide detailed information including:
+
+         1. Species identification (common name and scientific name)
+         2. Signal type and purpose
+         3. Habitat and behavior context
+         4. Audio characteristics analysis
+         5. Confidence level in your assessment
+
+         Please provide a comprehensive analysis with specific details about:
+         - Common name of the species
+         - Scientific name (genus and species)
+         - Type of vocalization (call, song, alarm, etc.)
+         - Habitat where this species is typically found
+         - Behavioral context of this sound
+         - Confidence level (0-100%)
+
+         Audio file: {file.filename}
+         Duration: {audio_chars.get('duration_seconds', 'Unknown')} seconds
+         Sample rate: {audio_chars.get('sample_rate', 'Unknown')} Hz
+         """
+
+         # Get NatureLM prediction
+         with torch.no_grad():
+             inputs = tokenizer(enhanced_prompt, return_tensors="pt")
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+
+             outputs = model.generate(
+                 **inputs,
+                 max_length=512,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id
+             )
+
+         response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract information from response
+         confidence_scores = extract_confidence_from_response(response_text)
+         species_info = extract_species_info(response_text)
+
+         # Calculate overall confidence
+         overall_confidence = max(
+             confidence_scores["model_confidence"],
+             confidence_scores["llama_confidence"],
+             50.0  # Default fallback
+         )
+
+         # Clean up temp file
+         os.remove(temp_path)
+
+         return AnalysisResponse(
+             species=species_info["common_name"] or "Unknown species",
+             interpretation=response_text,
+             confidence=overall_confidence,
+             signal_type=species_info["signal_type"] or "Vocalization",
+             common_name=species_info["common_name"] or "Unknown",
+             scientific_name=species_info["scientific_name"] or "Unknown",
+             habitat=species_info["habitat"] or "Unknown habitat",
+             behavior=species_info["behavior"] or "Unknown behavior",
+             audio_characteristics=audio_chars,
+             model_confidence=confidence_scores["model_confidence"],
+             llama_confidence=confidence_scores["llama_confidence"],
+             additional_insights=response_text,
+             cluster_group="NatureLM Analysis"
+         )
+
+     except Exception as e:
+         # Clean up temp file if it exists
+         if os.path.exists(temp_path):
+             os.remove(temp_path)
+
+         raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "model_loaded": model is not None}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
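A quick way to exercise the new endpoints, as a minimal client-side sketch (assumptions: the server is running locally on port 8000 as in the __main__ block, and "recording.mp3" is a placeholder for a real audio file; the multipart field is named "file" to match the UploadFile parameter of analyze_audio):

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Confirm the model finished loading before sending audio.
print(requests.get(f"{BASE_URL}/health").json())

# POST the audio as multipart form data under the "file" field.
with open("recording.mp3", "rb") as f:  # placeholder filename
    resp = requests.post(f"{BASE_URL}/analyze", files={"file": f})
resp.raise_for_status()

# The JSON body follows the AnalysisResponse schema defined in the diff.
result = resp.json()
print(result["species"], result["confidence"], result["signal_type"])

Note that the field name in `files={"file": f}` must match the endpoint's parameter name, since FastAPI binds multipart fields by name.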