Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -8,9 +8,9 @@ from pydantic import BaseModel
|
|
| 8 |
from typing import Optional, Dict, Any
|
| 9 |
import json
|
| 10 |
import re
|
| 11 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
-
|
| 13 |
-
import
|
| 14 |
|
| 15 |
app = FastAPI(title="NatureLM Audio Analysis API")
|
| 16 |
|
|
@@ -23,34 +23,24 @@ app.add_middleware(
|
|
| 23 |
allow_headers=["*"],
|
| 24 |
)
|
| 25 |
|
| 26 |
-
# Initialize model
|
| 27 |
model = None
|
| 28 |
-
|
| 29 |
-
audio_pipeline = None
|
| 30 |
|
| 31 |
def load_model():
|
| 32 |
-
global model,
|
| 33 |
try:
|
| 34 |
-
# Load
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
torch_dtype=torch.float16,
|
| 40 |
-
device_map="auto",
|
| 41 |
-
trust_remote_code=True
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
# Load audio analysis pipeline
|
| 45 |
-
audio_pipeline = pipeline(
|
| 46 |
-
"audio-classification",
|
| 47 |
-
model="microsoft/wavlm-base",
|
| 48 |
-
return_all_scores=True
|
| 49 |
-
)
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
except Exception as e:
|
| 53 |
-
print(f"β Error loading
|
| 54 |
raise e
|
| 55 |
|
| 56 |
# Load model on startup
|
|
@@ -114,7 +104,10 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
|
|
| 114 |
common_patterns = [
|
| 115 |
r"common name[:\s]*([A-Za-z\s]+)",
|
| 116 |
r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+\(common\)",
|
| 117 |
-
r"species[:\s]*([A-Za-z\s]+)"
|
|
|
|
|
|
|
|
|
|
| 118 |
]
|
| 119 |
|
| 120 |
for pattern in common_patterns:
|
|
@@ -144,7 +137,9 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
|
|
| 144 |
r"signal type[:\s]*([A-Za-z\s]+)",
|
| 145 |
r"call type[:\s]*([A-Za-z\s]+)",
|
| 146 |
r"vocalization[:\s]*([A-Za-z\s]+)",
|
| 147 |
-
r"sound type[:\s]*([A-Za-z\s]+)"
|
|
|
|
|
|
|
| 148 |
]
|
| 149 |
|
| 150 |
for pattern in signal_patterns:
|
|
@@ -157,7 +152,8 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
|
|
| 157 |
habitat_patterns = [
|
| 158 |
r"habitat[:\s]*([A-Za-z\s,]+)",
|
| 159 |
r"environment[:\s]*([A-Za-z\s,]+)",
|
| 160 |
-
r"found in[:\s]*([A-Za-z\s,]+)"
|
|
|
|
| 161 |
]
|
| 162 |
|
| 163 |
for pattern in habitat_patterns:
|
|
@@ -170,7 +166,8 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
|
|
| 170 |
behavior_patterns = [
|
| 171 |
r"behavior[:\s]*([A-Za-z\s,]+)",
|
| 172 |
r"purpose[:\s]*([A-Za-z\s,]+)",
|
| 173 |
-
r"function[:\s]*([A-Za-z\s,]+)"
|
|
|
|
| 174 |
]
|
| 175 |
|
| 176 |
for pattern in behavior_patterns:
|
|
@@ -181,12 +178,11 @@ def extract_species_info(response_text: str) -> Dict[str, str]:
|
|
| 181 |
|
| 182 |
return info
|
| 183 |
|
| 184 |
-
def analyze_audio_characteristics(
|
| 185 |
"""Analyze audio characteristics using librosa"""
|
| 186 |
try:
|
| 187 |
-
# Load audio
|
| 188 |
-
|
| 189 |
-
y, sr = librosa.load(audio_bytes, sr=None)
|
| 190 |
|
| 191 |
# Calculate audio features
|
| 192 |
duration = librosa.get_duration(y=y, sr=sr)
|
|
@@ -227,100 +223,51 @@ def analyze_audio_characteristics(audio_data: bytes) -> Dict[str, Any]:
|
|
| 227 |
print(f"Error analyzing audio characteristics: {e}")
|
| 228 |
return {}
|
| 229 |
|
| 230 |
-
def classify_audio_signal(audio_data: bytes) -> Dict[str, Any]:
|
| 231 |
-
"""Classify audio using WavLM model"""
|
| 232 |
-
try:
|
| 233 |
-
# Convert audio to the format expected by WavLM
|
| 234 |
-
audio_bytes = io.BytesIO(audio_data)
|
| 235 |
-
y, sr = librosa.load(audio_bytes, sr=16000) # WavLM expects 16kHz
|
| 236 |
-
|
| 237 |
-
# Reshape for the pipeline
|
| 238 |
-
audio_input = {"array": y, "sampling_rate": sr}
|
| 239 |
-
|
| 240 |
-
# Get classification results
|
| 241 |
-
results = audio_pipeline(audio_input)
|
| 242 |
-
|
| 243 |
-
# Extract the most likely class and confidence
|
| 244 |
-
if results and len(results) > 0:
|
| 245 |
-
top_result = results[0]
|
| 246 |
-
return {
|
| 247 |
-
"signal_type": top_result.get("label", "Unknown"),
|
| 248 |
-
"confidence": top_result.get("score", 0.0) * 100
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
return {"signal_type": "Unknown", "confidence": 0.0}
|
| 252 |
-
except Exception as e:
|
| 253 |
-
print(f"Error in audio classification: {e}")
|
| 254 |
-
return {"signal_type": "Unknown", "confidence": 0.0}
|
| 255 |
-
|
| 256 |
@app.post("/analyze", response_model=AnalysisResponse)
|
| 257 |
async def analyze_audio(file: UploadFile = File(...)):
|
| 258 |
try:
|
| 259 |
-
#
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
# Analyze audio characteristics
|
| 263 |
-
audio_chars = analyze_audio_characteristics(
|
| 264 |
|
| 265 |
-
#
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
#
|
| 269 |
-
|
| 270 |
-
You are an expert in animal vocalization analysis. Analyze this audio recording and provide detailed information.
|
| 271 |
-
|
| 272 |
-
Audio file: {file.filename}
|
| 273 |
-
Duration: {audio_chars.get('duration_seconds', 'Unknown')} seconds
|
| 274 |
-
Sample rate: {audio_chars.get('sample_rate', 'Unknown')} Hz
|
| 275 |
-
Tempo: {audio_chars.get('tempo_bpm', 'Unknown')} BPM
|
| 276 |
-
Signal classification: {signal_classification.get('signal_type', 'Unknown')}
|
| 277 |
-
|
| 278 |
-
Please provide a comprehensive analysis including:
|
| 279 |
-
1. Species identification (common name and scientific name if possible)
|
| 280 |
-
2. Signal type and purpose (mating call, alarm, territorial, etc.)
|
| 281 |
-
3. Habitat and behavior context
|
| 282 |
-
4. Confidence level in your assessment (0-100%)
|
| 283 |
-
|
| 284 |
-
Format your response with clear sections for each aspect.
|
| 285 |
-
"""
|
| 286 |
|
| 287 |
-
#
|
| 288 |
-
|
| 289 |
-
inputs = tokenizer(enhanced_prompt, return_tensors="pt", max_length=512, truncation=True)
|
| 290 |
-
|
| 291 |
-
if torch.cuda.is_available():
|
| 292 |
-
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 293 |
-
|
| 294 |
-
outputs = model.generate(
|
| 295 |
-
**inputs,
|
| 296 |
-
max_length=1024,
|
| 297 |
-
temperature=0.7,
|
| 298 |
-
do_sample=True,
|
| 299 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 300 |
-
eos_token_id=tokenizer.eos_token_id
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 304 |
-
# Remove the input prompt from the response
|
| 305 |
-
response_text = response_text.replace(enhanced_prompt, "").strip()
|
| 306 |
|
| 307 |
# Extract information from response
|
| 308 |
-
confidence_scores = extract_confidence_from_response(
|
| 309 |
-
species_info = extract_species_info(
|
| 310 |
|
| 311 |
-
# Calculate overall confidence
|
| 312 |
overall_confidence = max(
|
| 313 |
confidence_scores["model_confidence"],
|
| 314 |
confidence_scores["llama_confidence"],
|
| 315 |
-
|
| 316 |
-
50.0 # Default fallback
|
| 317 |
)
|
| 318 |
|
|
|
|
|
|
|
|
|
|
| 319 |
return AnalysisResponse(
|
| 320 |
species=species_info["common_name"] or "Unknown species",
|
| 321 |
-
interpretation=
|
| 322 |
confidence=overall_confidence,
|
| 323 |
-
signal_type=species_info["signal_type"] or
|
| 324 |
common_name=species_info["common_name"] or "Unknown",
|
| 325 |
scientific_name=species_info["scientific_name"] or "Unknown",
|
| 326 |
habitat=species_info["habitat"] or "Unknown habitat",
|
|
@@ -328,11 +275,15 @@ async def analyze_audio(file: UploadFile = File(...)):
|
|
| 328 |
audio_characteristics=audio_chars,
|
| 329 |
model_confidence=confidence_scores["model_confidence"],
|
| 330 |
llama_confidence=confidence_scores["llama_confidence"],
|
| 331 |
-
additional_insights=
|
| 332 |
cluster_group="NatureLM Analysis"
|
| 333 |
)
|
| 334 |
|
| 335 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 337 |
|
| 338 |
@app.get("/health")
|
|
|
|
| 8 |
from typing import Optional, Dict, Any
|
| 9 |
import json
|
| 10 |
import re
|
| 11 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
+
from NatureLM.models import NatureLM
|
| 13 |
+
from NatureLM.infer import Pipeline
|
| 14 |
|
| 15 |
app = FastAPI(title="NatureLM Audio Analysis API")
|
| 16 |
|
|
|
|
| 23 |
allow_headers=["*"],
|
| 24 |
)
|
| 25 |
|
| 26 |
+
# Initialize NatureLM model
|
| 27 |
model = None
|
| 28 |
+
pipeline = None
|
|
|
|
| 29 |
|
| 30 |
def load_model():
|
| 31 |
+
global model, pipeline
|
| 32 |
try:
|
| 33 |
+
# Load NatureLM-audio model from HuggingFace
|
| 34 |
+
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
|
| 35 |
+
model = model.eval()
|
| 36 |
+
if torch.cuda.is_available():
|
| 37 |
+
model = model.cuda()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
# Initialize pipeline
|
| 40 |
+
pipeline = Pipeline(model=model)
|
| 41 |
+
print("β
NatureLM model loaded successfully")
|
| 42 |
except Exception as e:
|
| 43 |
+
print(f"β Error loading model: {e}")
|
| 44 |
raise e
|
| 45 |
|
| 46 |
# Load model on startup
|
|
|
|
| 104 |
common_patterns = [
|
| 105 |
r"common name[:\s]*([A-Za-z\s]+)",
|
| 106 |
r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+\(common\)",
|
| 107 |
+
r"species[:\s]*([A-Za-z\s]+)",
|
| 108 |
+
r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+treefrog",
|
| 109 |
+
r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+bird",
|
| 110 |
+
r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+mammal"
|
| 111 |
]
|
| 112 |
|
| 113 |
for pattern in common_patterns:
|
|
|
|
| 137 |
r"signal type[:\s]*([A-Za-z\s]+)",
|
| 138 |
r"call type[:\s]*([A-Za-z\s]+)",
|
| 139 |
r"vocalization[:\s]*([A-Za-z\s]+)",
|
| 140 |
+
r"sound type[:\s]*([A-Za-z\s]+)",
|
| 141 |
+
r"([A-Za-z\s]+)\s+call",
|
| 142 |
+
r"([A-Za-z\s]+)\s+song"
|
| 143 |
]
|
| 144 |
|
| 145 |
for pattern in signal_patterns:
|
|
|
|
| 152 |
habitat_patterns = [
|
| 153 |
r"habitat[:\s]*([A-Za-z\s,]+)",
|
| 154 |
r"environment[:\s]*([A-Za-z\s,]+)",
|
| 155 |
+
r"found in[:\s]*([A-Za-z\s,]+)",
|
| 156 |
+
r"lives in[:\s]*([A-Za-z\s,]+)"
|
| 157 |
]
|
| 158 |
|
| 159 |
for pattern in habitat_patterns:
|
|
|
|
| 166 |
behavior_patterns = [
|
| 167 |
r"behavior[:\s]*([A-Za-z\s,]+)",
|
| 168 |
r"purpose[:\s]*([A-Za-z\s,]+)",
|
| 169 |
+
r"function[:\s]*([A-Za-z\s,]+)",
|
| 170 |
+
r"used for[:\s]*([A-Za-z\s,]+)"
|
| 171 |
]
|
| 172 |
|
| 173 |
for pattern in behavior_patterns:
|
|
|
|
| 178 |
|
| 179 |
return info
|
| 180 |
|
| 181 |
+
def analyze_audio_characteristics(audio_path: str) -> Dict[str, Any]:
|
| 182 |
"""Analyze audio characteristics using librosa"""
|
| 183 |
try:
|
| 184 |
+
# Load audio file
|
| 185 |
+
y, sr = librosa.load(audio_path, sr=None)
|
|
|
|
| 186 |
|
| 187 |
# Calculate audio features
|
| 188 |
duration = librosa.get_duration(y=y, sr=sr)
|
|
|
|
| 223 |
print(f"Error analyzing audio characteristics: {e}")
|
| 224 |
return {}
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
@app.post("/analyze", response_model=AnalysisResponse)
|
| 227 |
async def analyze_audio(file: UploadFile = File(...)):
|
| 228 |
try:
|
| 229 |
+
# Save uploaded file temporarily
|
| 230 |
+
temp_path = f"/tmp/{file.filename}"
|
| 231 |
+
with open(temp_path, "wb") as buffer:
|
| 232 |
+
content = await file.read()
|
| 233 |
+
buffer.write(content)
|
| 234 |
|
| 235 |
# Analyze audio characteristics
|
| 236 |
+
audio_chars = analyze_audio_characteristics(temp_path)
|
| 237 |
|
| 238 |
+
# Create multiple queries for comprehensive analysis
|
| 239 |
+
queries = [
|
| 240 |
+
"What is the common name for the focal species in the audio? Answer:",
|
| 241 |
+
"What type of vocalization or call is this? Answer:",
|
| 242 |
+
"Describe the habitat and behavior context of this species. Answer:",
|
| 243 |
+
"Provide a detailed analysis of this animal sound including species identification, call type, and behavioral context. Answer:"
|
| 244 |
+
]
|
| 245 |
|
| 246 |
+
# Run NatureLM analysis
|
| 247 |
+
results = pipeline([temp_path], queries, window_length_seconds=10.0, hop_length_seconds=10.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
+
# Combine results
|
| 250 |
+
combined_response = " ".join(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
# Extract information from response
|
| 253 |
+
confidence_scores = extract_confidence_from_response(combined_response)
|
| 254 |
+
species_info = extract_species_info(combined_response)
|
| 255 |
|
| 256 |
+
# Calculate overall confidence based on response quality
|
| 257 |
overall_confidence = max(
|
| 258 |
confidence_scores["model_confidence"],
|
| 259 |
confidence_scores["llama_confidence"],
|
| 260 |
+
75.0 if species_info["common_name"] else 50.0 # Higher confidence if species identified
|
|
|
|
| 261 |
)
|
| 262 |
|
| 263 |
+
# Clean up temp file
|
| 264 |
+
os.remove(temp_path)
|
| 265 |
+
|
| 266 |
return AnalysisResponse(
|
| 267 |
species=species_info["common_name"] or "Unknown species",
|
| 268 |
+
interpretation=combined_response,
|
| 269 |
confidence=overall_confidence,
|
| 270 |
+
signal_type=species_info["signal_type"] or "Vocalization",
|
| 271 |
common_name=species_info["common_name"] or "Unknown",
|
| 272 |
scientific_name=species_info["scientific_name"] or "Unknown",
|
| 273 |
habitat=species_info["habitat"] or "Unknown habitat",
|
|
|
|
| 275 |
audio_characteristics=audio_chars,
|
| 276 |
model_confidence=confidence_scores["model_confidence"],
|
| 277 |
llama_confidence=confidence_scores["llama_confidence"],
|
| 278 |
+
additional_insights=combined_response,
|
| 279 |
cluster_group="NatureLM Analysis"
|
| 280 |
)
|
| 281 |
|
| 282 |
except Exception as e:
|
| 283 |
+
# Clean up temp file if it exists
|
| 284 |
+
if os.path.exists(temp_path):
|
| 285 |
+
os.remove(temp_path)
|
| 286 |
+
|
| 287 |
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 288 |
|
| 289 |
@app.get("/health")
|