uparekh01151 committed on
Commit
c1c187b
·
1 Parent(s): 4f99f28

Replace manual HTTP requests with InferenceClient for better reliability and error handling

Browse files
Files changed (1) hide show
  1. src/models_registry.py +24 -34
src/models_registry.py CHANGED
@@ -5,10 +5,10 @@ Optimized for remote inference without local model loading.
5
 
6
  import yaml
7
  import os
8
- import requests
9
  from typing import List, Dict, Any, Optional
10
  from dataclasses import dataclass
11
  import sys
 
12
 
13
  # Add src to path for imports
14
  sys.path.append('src')
@@ -70,48 +70,38 @@ class ModelsRegistry:
70
 
71
 
72
  class HuggingFaceInference:
73
- """Interface for Hugging Face Inference API."""
74
 
75
  def __init__(self, api_token: Optional[str] = None):
76
  self.api_token = api_token or os.getenv("HF_TOKEN")
77
- self.base_url = "https://api-inference.huggingface.co/models"
 
78
 
79
  def generate(self, model_id: str, prompt: str, params: Dict[str, Any]) -> str:
80
  """Generate text using Hugging Face Inference API."""
81
- headers = {}
82
- if self.api_token:
83
- headers["Authorization"] = f"Bearer {self.api_token}"
84
-
85
- payload = {
86
- "inputs": prompt,
87
- "parameters": params
88
- }
89
-
90
  try:
91
- response = requests.post(
92
- f"{self.base_url}/{model_id}",
93
- headers=headers,
94
- json=payload,
95
- timeout=60
 
 
 
96
  )
97
-
98
- if response.status_code != 200:
99
- raise Exception(f"Hugging Face API error: {response.status_code} - {response.text}")
100
-
101
- result = response.json()
102
-
103
- # Handle different response formats
104
- if isinstance(result, list) and len(result) > 0:
105
- return result[0].get('generated_text', '')
106
- elif isinstance(result, dict):
107
- return result.get('generated_text', '')
108
- else:
109
- return str(result)
110
 
111
- except requests.exceptions.Timeout:
112
- raise Exception("Request timeout - model may be loading. Please try again in a moment.")
113
- except requests.exceptions.RequestException as e:
114
- raise Exception(f"Network error: {str(e)}")
 
 
 
 
 
 
115
 
116
 
117
  class ModelInterface:
 
5
 
6
  import yaml
7
  import os
 
8
  from typing import List, Dict, Any, Optional
9
  from dataclasses import dataclass
10
  import sys
11
+ from huggingface_hub import InferenceClient
12
 
13
  # Add src to path for imports
14
  sys.path.append('src')
 
70
 
71
 
72
class HuggingFaceInference:
    """Interface to the Hugging Face Inference API via ``InferenceClient``."""

    def __init__(self, api_token: Optional[str] = None):
        # Fall back to the HF_TOKEN environment variable when no token is given.
        self.api_token = api_token or os.getenv("HF_TOKEN")
        # InferenceClient handles authentication automatically.
        self.client = InferenceClient(token=self.api_token)

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any]) -> str:
        """Generate text for ``prompt`` using ``model_id`` on the Inference API.

        Args:
            model_id: Hub model repository id to run inference against.
            prompt: Input text for the model.
            params: Generation settings; recognised keys are
                ``max_new_tokens`` (default 128), ``temperature`` (default 0.1)
                and ``top_p`` (default 0.9).

        Returns:
            The generated continuation only (the prompt is excluded because
            ``return_full_text=False``).

        Raises:
            Exception: With a human-readable message describing the API
                failure; the original client error is chained as the cause.
        """
        try:
            return self.client.text_generation(
                prompt=prompt,
                model=model_id,
                max_new_tokens=params.get('max_new_tokens', 128),
                temperature=params.get('temperature', 0.1),
                top_p=params.get('top_p', 0.9),
                return_full_text=False,  # only return the generated part
            )
        except Exception as e:
            # Map common HTTP status codes (embedded in the client's error
            # text) to clearer messages. Chain the original exception with
            # ``from e`` so the real cause survives for debugging — the
            # original code discarded it.
            msg = str(e)
            if "404" in msg:
                raise Exception(f"Model not found: {model_id}") from e
            if "401" in msg:
                raise Exception("Authentication failed - check HF_TOKEN") from e
            if "503" in msg:
                raise Exception(
                    f"Model {model_id} is loading, please try again in a moment"
                ) from e
            raise Exception(f"Hugging Face API error: {msg}") from e
105
 
106
 
107
  class ModelInterface: