uparekh01151 committed
Commit b16182c · 1 Parent(s): 6436cd9

Switch to Llama-3.1-8B-Instruct with Nebius provider using chat.completions.create method

Files changed (2):
  1. config/models.yaml +5 -5
  2. src/models_registry.py +18 -2
config/models.yaml CHANGED

@@ -1,10 +1,10 @@
 models:
-  # Qwen2.5-7B-Instruct with HF Inference Provider
-  - name: "Qwen2.5-7B-Instruct"
-    provider: "hf-inference"
-    model_id: "Qwen/Qwen2.5-7B-Instruct"
+  # Llama-3.1-8B-Instruct with Nebius Provider
+  - name: "Llama-3.1-8B-Instruct"
+    provider: "nebius"
+    model_id: "meta-llama/Llama-3.1-8B-Instruct"
     params:
       max_new_tokens: 256
       temperature: 0.1
       top_p: 0.9
-    description: "Qwen2.5-7B-Instruct - Instruction-following model for text generation"
+    description: "Llama-3.1-8B-Instruct - Meta's instruction-following model via Nebius"
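
For context, the registry presumably parses this file into per-model configs. A minimal sketch of that loading step, assuming PyYAML; the ModelConfig dataclass below is hypothetical and may not match the real class in src/models_registry.py:

    from dataclasses import dataclass, field

    import yaml


    @dataclass
    class ModelConfig:
        # Hypothetical shape; mirrors the keys used in config/models.yaml.
        name: str
        provider: str
        model_id: str
        params: dict = field(default_factory=dict)
        description: str = ""


    def load_model_configs(path: str = "config/models.yaml") -> list[ModelConfig]:
        # Parse the YAML file and build one config per entry under "models".
        with open(path) as f:
            raw = yaml.safe_load(f)
        return [ModelConfig(**entry) for entry in raw["models"]]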
src/models_registry.py CHANGED

@@ -86,7 +86,23 @@ class HuggingFaceInference:
         )
 
         # Use different methods based on provider capabilities
-        if provider == "together":
+        if provider == "nebius":
+            # Nebius provider uses chat.completions.create
+            completion = client.chat.completions.create(
+                model=model_id,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ],
+                max_tokens=params.get('max_new_tokens', 128),
+                temperature=params.get('temperature', 0.1),
+                top_p=params.get('top_p', 0.9)
+            )
+            # Extract the content from the response
+            return completion.choices[0].message.content
+        elif provider == "together":
             # Together provider uses chat_completion for conversational models
             result = client.chat_completion(
                 messages=[{"role": "user", "content": prompt}],
@@ -184,7 +200,7 @@ class ModelInterface:
             return self._generate_mock_sql(model_config, prompt)
 
         try:
-            if model_config.provider in ["huggingface", "hf-inference", "together"]:
+            if model_config.provider in ["huggingface", "hf-inference", "together", "nebius"]:
                 print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
                 return self.hf_interface.generate(
                     model_config.model_id,
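
To see the new branch end to end: the diff assumes a `client` already routed to Nebius. A minimal sketch of that setup, assuming huggingface_hub's InferenceClient with provider routing (the HF_TOKEN environment variable and the prompt are assumptions for illustration; the real code may construct the client elsewhere):

    import os

    from huggingface_hub import InferenceClient

    # Route requests through the Nebius inference provider; reading the
    # token from HF_TOKEN is an assumption for this sketch.
    client = InferenceClient(provider="nebius", api_key=os.environ["HF_TOKEN"])

    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "List all tables in SQL."}],
        max_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )
    print(completion.choices[0].message.content)

Note that client.chat.completions.create is the OpenAI-compatible alias InferenceClient exposes for chat_completion, which is why the new Nebius branch reads like an OpenAI call while the Together branch still uses the older chat_completion method.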