tscr-369 committed on
Commit c4d0b69 · verified · 1 Parent(s): 7a02106

Update main.py

Files changed (1)
  1. main.py +32 -131
main.py CHANGED
@@ -9,21 +9,18 @@ from typing import Optional, Dict, Any, List
 import json
 import re
 from contextlib import asynccontextmanager
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoProcessor
 from huggingface_hub import InferenceClient
 import base64
 
 # Set up cache directories BEFORE importing any HuggingFace modules
 cache_base = "/app/.cache"
 os.environ['HF_HOME'] = cache_base
-os.environ['TRANSFORMERS_CACHE'] = f"{cache_base}/transformers"
 os.environ['HF_DATASETS_CACHE'] = f"{cache_base}/datasets"
 os.environ['HF_HUB_CACHE'] = f"{cache_base}/hub"
 
 # Ensure cache directories exist
 cache_dirs = [
     os.environ['HF_HOME'],
-    os.environ['TRANSFORMERS_CACHE'],
     os.environ['HF_DATASETS_CACHE'],
     os.environ['HF_HUB_CACHE']
 ]
@@ -31,15 +28,13 @@ cache_dirs = [
 for cache_dir in cache_dirs:
     os.makedirs(cache_dir, exist_ok=True)
 
-# Global variables for model and pipeline
-model = None
-audio_pipeline = None
+# Global variables
 client = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
-    global model, audio_pipeline, client
+    global client
     try:
         print("🔄 Starting NatureLM Audio Decoder API...")
         print(f"📁 Using cache directory: {os.environ.get('HF_HOME', '/app/.cache')}")
@@ -47,84 +42,7 @@ async def lifespan(app: FastAPI):
         # Initialize HuggingFace client for inference API
         client = InferenceClient()
         print("✅ HuggingFace client initialized successfully")
-
-        # Load NatureLM-audio model locally for better performance
-        try:
-            print("🔄 Loading NatureLM-audio model...")
-            model_name = "EarthSpeciesProject/NatureLM-audio"
-
-            # For NatureLM-audio, we need to use a different approach since it's a custom model
-            # Let's try using the processor and model directly
-            try:
-                # Load processor first
-                processor = AutoProcessor.from_pretrained(
-                    model_name,
-                    trust_remote_code=True,
-                    cache_dir=os.environ['TRANSFORMERS_CACHE']
-                )
-
-                # Load model with specific configuration for NatureLM-audio
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
-                    trust_remote_code=True,
-                    cache_dir=os.environ['TRANSFORMERS_CACHE'],
-                    low_cpu_mem_usage=True
-                )
-
-                print("✅ NatureLM-audio model loaded successfully")
-
-                # Create a custom pipeline for NatureLM-audio
-                def naturelm_audio_pipeline(audio_input, **kwargs):
-                    """Custom pipeline for NatureLM-audio processing"""
-                    try:
-                        # Process audio with the model
-                        if isinstance(audio_input, bytes):
-                            # Convert bytes to the format expected by the model
-                            # This is a simplified approach - in practice, you'd need to match the model's expected input format
-                            inputs = processor(
-                                audio_input,
-                                return_tensors="pt",
-                                sampling_rate=16000,
-                                **kwargs
-                            )
-                        else:
-                            inputs = processor(audio_input, return_tensors="pt", **kwargs)
-
-                        # Generate response
-                        with torch.no_grad():
-                            outputs = model.generate(
-                                **inputs,
-                                max_length=512,
-                                do_sample=True,
-                                temperature=0.7,
-                                pad_token_id=processor.tokenizer.eos_token_id
-                            )
-
-                        # Decode the response
-                        response = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
-                        return {"text": response}
-
-                    except Exception as e:
-                        print(f"Error in NatureLM pipeline: {e}")
-                        return {"text": "Error processing audio with NatureLM-audio model"}
-
-                audio_pipeline = naturelm_audio_pipeline
-
-            except Exception as model_error:
-                print(f"⚠️ Could not load NatureLM-audio model locally: {model_error}")
-                print("🔄 Falling back to HuggingFace Inference API")
-                model = None
-                audio_pipeline = None
-
-        except Exception as model_error:
-            print(f"⚠️ Could not load model locally: {model_error}")
-            print("🔄 Falling back to HuggingFace Inference API")
-            model = None
-            audio_pipeline = None
-
-        print("✅ API ready for NatureLM-audio analysis")
+        print("✅ API ready for NatureLM-audio analysis via HuggingFace Inference API")
 
     except Exception as e:
         print(f"❌ Error during startup: {e}")
@@ -461,15 +379,14 @@ async def health_check():
     return {
         "status": "healthy",
         "service": "NatureLM Audio Decoder API",
-        "model_loaded": model is not None,
-        "pipeline_ready": audio_pipeline is not None,
-        "client_ready": client is not None
+        "client_ready": client is not None,
+        "model": "NatureLM-audio via HuggingFace Inference API"
     }
 
 @app.post("/analyze", response_model=AnalysisResponse)
 async def analyze_audio(file: UploadFile = File(...)):
     """
-    Analyze audio file using NatureLM-audio model with enhanced confidence scoring and detailed captioning
+    Analyze audio file using NatureLM-audio model via HuggingFace Inference API
     """
     try:
         # Save uploaded file temporarily
@@ -515,50 +432,34 @@ async def analyze_audio(file: UploadFile = File(...)):
             complexity=audio_chars.get('audio_quality_indicators', {}).get('complexity_score', 0)
         )
 
-        # Use NatureLM-audio model for analysis
+        # Use HuggingFace Inference API for NatureLM-audio
         try:
-            if audio_pipeline is not None:
-                # Use local model if available
-                print("🔄 Using local NatureLM-audio model...")
-
-                # Read audio file
-                with open(temp_path, "rb") as audio_file:
-                    audio_bytes = audio_file.read()
-
-                # Process with local pipeline
-                result = audio_pipeline(audio_bytes)
-
-                combined_response = result.get('text', '') if isinstance(result, dict) else str(result)
-                detection_method = "Local NatureLM-audio Model"
-
+            print("🔄 Using HuggingFace Inference API for NatureLM-audio...")
+
+            # Read audio file as bytes
+            with open(temp_path, "rb") as audio_file:
+                audio_bytes = audio_file.read()
+
+            # Encode audio as base64 for API
+            audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
+
+            # Call NatureLM-audio model via HuggingFace API
+            response = client.post(
+                "EarthSpeciesProject/NatureLM-audio",
+                inputs={
+                    "audio": audio_b64,
+                    "text": prompt
+                }
+            )
+
+            # Parse response
+            if isinstance(response, list) and len(response) > 0:
+                combined_response = response[0]
             else:
-                # Use HuggingFace inference API
-                print("🔄 Using HuggingFace Inference API...")
-
-                # Read audio file as bytes
-                with open(temp_path, "rb") as audio_file:
-                    audio_bytes = audio_file.read()
-
-                # Encode audio as base64 for API
-                audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
-
-                # Call NatureLM-audio model via HuggingFace API
-                response = client.post(
-                    "EarthSpeciesProject/NatureLM-audio",
-                    inputs={
-                        "audio": audio_b64,
-                        "text": prompt
-                    }
-                )
-
-                # Parse response
-                if isinstance(response, list) and len(response) > 0:
-                    combined_response = response[0]
-                else:
-                    combined_response = str(response)
-
-                detection_method = "HuggingFace Inference API"
-
+                combined_response = str(response)
+
+            detection_method = "HuggingFace Inference API"
+
         except Exception as api_error:
             print(f"API call failed: {api_error}")
             # Fallback to a comprehensive mock response for testing
 
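A note on the API call shape: the committed handler invokes client.post() with the model id as a positional argument and an inputs= keyword. For comparison, below is a hedged sketch of the same request written against the keyword-only json=/model= arguments that huggingface_hub's InferenceClient.post() has historically exposed (newer releases deprecate .post() in favor of task-specific helpers). If the installed version rejects the committed form, the surrounding try/except falls through to the mock response; whether the hosted Inference API actually serves this custom audio model is likewise not guaranteed. The filename and prompt text are placeholders.

from huggingface_hub import InferenceClient
import base64

client = InferenceClient()

# Encode the audio the same way the committed handler does
with open("example_recording.wav", "rb") as audio_file:  # hypothetical file
    audio_b64 = base64.b64encode(audio_file.read()).decode("utf-8")

prompt = "Describe the animal sounds in this recording."  # hypothetical prompt

# Assumed call shape: keyword-only json=/model= arguments of InferenceClient.post()
raw = client.post(
    json={"audio": audio_b64, "text": prompt},
    model="EarthSpeciesProject/NatureLM-audio",
)
print(raw)  # raw response bytes returned by the Inference API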