tscr-369 committed on
Commit 20189cf · verified · 1 Parent(s): c715944

Update main.py

Files changed (1)
main.py +60 -109
main.py CHANGED
@@ -8,9 +8,9 @@ from pydantic import BaseModel
 from typing import Optional, Dict, Any
 import json
 import re
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import soundfile as sf
-import io
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from NatureLM.models import NatureLM
+from NatureLM.infer import Pipeline
 
 app = FastAPI(title="NatureLM Audio Analysis API")
 
@@ -23,34 +23,24 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Initialize model and tokenizer
+# Initialize NatureLM model
 model = None
-tokenizer = None
-audio_pipeline = None
+pipeline = None
 
 def load_model():
-    global model, tokenizer, audio_pipeline
+    global model, pipeline
     try:
-        # Load the Llama model that NatureLM uses
-        model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            trust_remote_code=True
-        )
-
-        # Load audio analysis pipeline
-        audio_pipeline = pipeline(
-            "audio-classification",
-            model="microsoft/wavlm-base",
-            return_all_scores=True
-        )
-
-        print("✅ Models loaded successfully")
+        # Load NatureLM-audio model from HuggingFace
+        model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
+        model = model.eval()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        # Initialize pipeline
+        pipeline = Pipeline(model=model)
+        print("✅ NatureLM model loaded successfully")
     except Exception as e:
-        print(f"❌ Error loading models: {e}")
+        print(f"❌ Error loading model: {e}")
         raise e
 
 # Load model on startup
@@ -114,7 +104,10 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
     common_patterns = [
         r"common name[:\s]*([A-Za-z\s]+)",
         r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+\(common\)",
-        r"species[:\s]*([A-Za-z\s]+)"
+        r"species[:\s]*([A-Za-z\s]+)",
+        r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+treefrog",
+        r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+bird",
+        r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+mammal"
     ]
 
     for pattern in common_patterns:
@@ -144,7 +137,9 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
         r"signal type[:\s]*([A-Za-z\s]+)",
         r"call type[:\s]*([A-Za-z\s]+)",
         r"vocalization[:\s]*([A-Za-z\s]+)",
-        r"sound type[:\s]*([A-Za-z\s]+)"
+        r"sound type[:\s]*([A-Za-z\s]+)",
+        r"([A-Za-z\s]+)\s+call",
+        r"([A-Za-z\s]+)\s+song"
     ]
 
     for pattern in signal_patterns:
@@ -157,7 +152,8 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
     habitat_patterns = [
         r"habitat[:\s]*([A-Za-z\s,]+)",
         r"environment[:\s]*([A-Za-z\s,]+)",
-        r"found in[:\s]*([A-Za-z\s,]+)"
+        r"found in[:\s]*([A-Za-z\s,]+)",
+        r"lives in[:\s]*([A-Za-z\s,]+)"
    ]
 
     for pattern in habitat_patterns:
@@ -170,7 +166,8 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
     behavior_patterns = [
         r"behavior[:\s]*([A-Za-z\s,]+)",
         r"purpose[:\s]*([A-Za-z\s,]+)",
-        r"function[:\s]*([A-Za-z\s,]+)"
+        r"function[:\s]*([A-Za-z\s,]+)",
+        r"used for[:\s]*([A-Za-z\s,]+)"
     ]
 
     for pattern in behavior_patterns:
@@ -181,12 +178,11 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
 
     return info
 
-def analyze_audio_characteristics(audio_data: bytes) -> Dict[str, Any]:
+def analyze_audio_characteristics(audio_path: str) -> Dict[str, Any]:
     """Analyze audio characteristics using librosa"""
     try:
-        # Load audio from bytes
-        audio_bytes = io.BytesIO(audio_data)
-        y, sr = librosa.load(audio_bytes, sr=None)
+        # Load audio file
+        y, sr = librosa.load(audio_path, sr=None)
 
         # Calculate audio features
         duration = librosa.get_duration(y=y, sr=sr)
@@ -227,100 +223,51 @@ def analyze_audio_characteristics(audio_data: bytes) -> Dict[str, Any]:
         print(f"Error analyzing audio characteristics: {e}")
         return {}
 
-def classify_audio_signal(audio_data: bytes) -> Dict[str, Any]:
-    """Classify audio using WavLM model"""
-    try:
-        # Convert audio to the format expected by WavLM
-        audio_bytes = io.BytesIO(audio_data)
-        y, sr = librosa.load(audio_bytes, sr=16000)  # WavLM expects 16kHz
-
-        # Reshape for the pipeline
-        audio_input = {"array": y, "sampling_rate": sr}
-
-        # Get classification results
-        results = audio_pipeline(audio_input)
-
-        # Extract the most likely class and confidence
-        if results and len(results) > 0:
-            top_result = results[0]
-            return {
-                "signal_type": top_result.get("label", "Unknown"),
-                "confidence": top_result.get("score", 0.0) * 100
-            }
-
-        return {"signal_type": "Unknown", "confidence": 0.0}
-    except Exception as e:
-        print(f"Error in audio classification: {e}")
-        return {"signal_type": "Unknown", "confidence": 0.0}
-
 @app.post("/analyze", response_model=AnalysisResponse)
 async def analyze_audio(file: UploadFile = File(...)):
     try:
-        # Read file content
-        content = await file.read()
+        # Save uploaded file temporarily
+        temp_path = f"/tmp/{file.filename}"
+        with open(temp_path, "wb") as buffer:
+            content = await file.read()
+            buffer.write(content)
 
         # Analyze audio characteristics
-        audio_chars = analyze_audio_characteristics(content)
+        audio_chars = analyze_audio_characteristics(temp_path)
 
-        # Classify audio signal
-        signal_classification = classify_audio_signal(content)
-
-        # Create enhanced prompt for Llama model
-        enhanced_prompt = f"""
-        You are an expert in animal vocalization analysis. Analyze this audio recording and provide detailed information.
-
-        Audio file: {file.filename}
-        Duration: {audio_chars.get('duration_seconds', 'Unknown')} seconds
-        Sample rate: {audio_chars.get('sample_rate', 'Unknown')} Hz
-        Tempo: {audio_chars.get('tempo_bpm', 'Unknown')} BPM
-        Signal classification: {signal_classification.get('signal_type', 'Unknown')}
-
-        Please provide a comprehensive analysis including:
-        1. Species identification (common name and scientific name if possible)
-        2. Signal type and purpose (mating call, alarm, territorial, etc.)
-        3. Habitat and behavior context
-        4. Confidence level in your assessment (0-100%)
-
-        Format your response with clear sections for each aspect.
-        """
+        # Create multiple queries for comprehensive analysis
+        queries = [
+            "What is the common name for the focal species in the audio? Answer:",
+            "What type of vocalization or call is this? Answer:",
+            "Describe the habitat and behavior context of this species. Answer:",
+            "Provide a detailed analysis of this animal sound including species identification, call type, and behavioral context. Answer:"
+        ]
 
-        # Get Llama model prediction
-        with torch.no_grad():
-            inputs = tokenizer(enhanced_prompt, return_tensors="pt", max_length=512, truncation=True)
-
-            if torch.cuda.is_available():
-                inputs = {k: v.cuda() for k, v in inputs.items()}
-
-            outputs = model.generate(
-                **inputs,
-                max_length=1024,
-                temperature=0.7,
-                do_sample=True,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id
-            )
-
-            response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            # Remove the input prompt from the response
-            response_text = response_text.replace(enhanced_prompt, "").strip()
+        # Run NatureLM analysis
+        results = pipeline([temp_path], queries, window_length_seconds=10.0, hop_length_seconds=10.0)
+
+        # Combine results
+        combined_response = " ".join(results)
 
         # Extract information from response
-        confidence_scores = extract_confidence_from_response(response_text)
-        species_info = extract_species_info(response_text)
+        confidence_scores = extract_confidence_from_response(combined_response)
+        species_info = extract_species_info(combined_response)
 
-        # Calculate overall confidence
+        # Calculate overall confidence based on response quality
         overall_confidence = max(
             confidence_scores["model_confidence"],
             confidence_scores["llama_confidence"],
-            signal_classification.get("confidence", 0.0),
-            50.0  # Default fallback
+            75.0 if species_info["common_name"] else 50.0  # Higher confidence if species identified
        )
 
+        # Clean up temp file
+        os.remove(temp_path)
+
         return AnalysisResponse(
             species=species_info["common_name"] or "Unknown species",
-            interpretation=response_text,
+            interpretation=combined_response,
             confidence=overall_confidence,
-            signal_type=species_info["signal_type"] or signal_classification.get("signal_type", "Vocalization"),
+            signal_type=species_info["signal_type"] or "Vocalization",
             common_name=species_info["common_name"] or "Unknown",
             scientific_name=species_info["scientific_name"] or "Unknown",
             habitat=species_info["habitat"] or "Unknown habitat",
@@ -328,11 +275,15 @@ async def analyze_audio(file: UploadFile = File(...)):
             audio_characteristics=audio_chars,
             model_confidence=confidence_scores["model_confidence"],
             llama_confidence=confidence_scores["llama_confidence"],
-            additional_insights=response_text,
+            additional_insights=combined_response,
             cluster_group="NatureLM Analysis"
         )
 
     except Exception as e:
+        # Clean up temp file if it exists
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
         raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
 
 @app.get("/health")
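
For orientation, the inference path this commit adopts can also be exercised outside FastAPI. A minimal standalone sketch, assuming the NatureLM package exposes NatureLM.from_pretrained and Pipeline exactly as used in the diff above (the audio path and query string are illustrative placeholders):

import torch
from NatureLM.models import NatureLM
from NatureLM.infer import Pipeline

# Load and prepare the model, mirroring load_model() above
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio").eval()
if torch.cuda.is_available():
    model = model.cuda()

pipeline = Pipeline(model=model)

# One query per question; "example.wav" is a placeholder path
queries = ["What is the common name for the focal species in the audio? Answer:"]
# Same windowing as the /analyze endpoint: 10 s windows, no overlap
results = pipeline(["example.wav"], queries, window_length_seconds=10.0, hop_length_seconds=10.0)
print(" ".join(results))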
 
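With the server running, the updated /analyze endpoint accepts a multipart file upload. A hypothetical client call for reference (the base URL, port, and filename are assumptions; the response keys follow the AnalysisResponse fields used above):

import requests

# Upload an audio file to the analysis endpoint; localhost:8000 is an assumption
with open("frog_call.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/analyze",
        files={"file": ("frog_call.wav", f, "audio/wav")},
    )
resp.raise_for_status()
data = resp.json()
print(data["species"], data["signal_type"], data["confidence"])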