Maheen001 commited on
Commit
bf47268
·
verified ·
1 Parent(s): 1e13436

Create utils/llm_utils.py

Browse files
Files changed (1) hide show
  1. utils/llm_utils.py +159 -0
utils/llm_utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+ import asyncio
4
+
5
+
6
+ _llm_config = {
7
+ 'provider': None,
8
+ 'model': None
9
+ }
10
+
11
+
12
def setup_llm_fallback():
    """Pick the first LLM provider whose API key is present in the environment.

    Writes the chosen provider name and its default model into the
    module-level ``_llm_config`` dict.  Preference order: OpenAI, Groq,
    Hyperbolic, Hugging Face Inference API.

    Raises:
        ValueError: if none of the recognized API-key environment
            variables is set.
    """
    # (env var, provider name, default model) in preference order.
    preference = (
        ('OPENAI_API_KEY', 'openai', 'gpt-4o-mini'),
        ('GROQ_API_KEY', 'groq', 'llama-3.3-70b-versatile'),
        ('HYPERBOLIC_API_KEY', 'hyperbolic', 'meta-llama/Llama-3.3-70B-Instruct'),
        ('HF_TOKEN', 'huggingface', 'mistralai/Mixtral-8x7B-Instruct-v0.1'),
    )
    for env_key, provider_name, default_model in preference:
        if os.getenv(env_key):
            _llm_config['provider'] = provider_name
            _llm_config['model'] = default_model
            return
    raise ValueError(
        "No LLM API keys configured. Please set at least one of: "
        "OPENAI_API_KEY, GROQ_API_KEY, HYPERBOLIC_API_KEY, HF_TOKEN"
    )
39
+
40
+
41
async def get_llm_response(
    prompt: str,
    temperature: float = 0.7,
    max_tokens: int = 2000
) -> str:
    """
    Get LLM response using the provider fallback chain.

    On the first call, ``setup_llm_fallback()`` selects a provider from the
    environment.  If that provider fails, the next provider in the chain
    whose API key is configured is tried (with its own default model),
    and so on; the original exception is re-raised when the chain is
    exhausted.

    Args:
        prompt: Input prompt
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        LLM response text

    Raises:
        ValueError: if the configured provider name is not recognized.
        Exception: whatever the last attempted provider raised.
    """
    provider = _llm_config.get('provider')
    model = _llm_config.get('model')

    if not provider:
        setup_llm_fallback()
        provider = _llm_config.get('provider')
        model = _llm_config.get('model')

    # Ordered fallback chain: (provider, env key, caller, default model).
    # Kept in the same preference order as setup_llm_fallback().
    chain = [
        ('openai', 'OPENAI_API_KEY', _call_openai, 'gpt-4o-mini'),
        ('groq', 'GROQ_API_KEY', _call_groq, 'llama-3.3-70b-versatile'),
        ('hyperbolic', 'HYPERBOLIC_API_KEY', _call_hyperbolic, 'meta-llama/Llama-3.3-70B-Instruct'),
        ('huggingface', 'HF_TOKEN', _call_huggingface, 'mistralai/Mixtral-8x7B-Instruct-v0.1'),
    ]
    positions = {name: i for i, (name, _, _, _) in enumerate(chain)}
    if provider not in positions:
        # Previously an unknown provider fell through every branch and
        # implicitly returned None despite the -> str annotation.
        raise ValueError(f"Unknown LLM provider: {provider!r}")

    try:
        return await chain[positions[provider]][2](prompt, model, temperature, max_tokens)
    except Exception as e:
        print(f"Error with {provider}: {e}")
        # Try the next configured provider in the chain.  The model must be
        # switched too: the old code reused the failing provider's model
        # name (e.g. sent 'gpt-4o-mini' to Groq) and only ever fell back
        # from openai to groq, never further down the chain.
        for name, env_key, _, default_model in chain[positions[provider] + 1:]:
            if os.getenv(env_key):
                _llm_config['provider'] = name
                _llm_config['model'] = default_model
                return await get_llm_response(prompt, temperature, max_tokens)
        raise
81
+
82
+
83
async def _call_openai(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Send a single-turn chat completion request to the OpenAI API.

    Args:
        prompt: User message content.
        model: OpenAI model identifier.
        temperature: Sampling temperature.
        max_tokens: Generation length cap.

    Returns:
        The assistant message text from the first choice.
    """
    from openai import AsyncOpenAI

    openai_client = AsyncOpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    messages = [{'role': 'user', 'content': prompt}]

    completion = await openai_client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return completion.choices[0].message.content
97
+
98
+
99
async def _call_groq(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Send a single-turn chat completion request to the Groq API.

    Args:
        prompt: User message content.
        model: Groq model identifier.
        temperature: Sampling temperature.
        max_tokens: Generation length cap.

    Returns:
        The assistant message text from the first choice.
    """
    from groq import AsyncGroq

    groq_client = AsyncGroq(api_key=os.getenv('GROQ_API_KEY'))
    messages = [{'role': 'user', 'content': prompt}]

    completion = await groq_client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return completion.choices[0].message.content
113
+
114
+
115
async def _call_hyperbolic(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Call the Hyperbolic chat-completions endpoint (OpenAI-compatible schema).

    Args:
        prompt: User message content.
        model: Hyperbolic model identifier.
        temperature: Sampling temperature.
        max_tokens: Generation length cap.

    Returns:
        The assistant message text from the first choice.

    Raises:
        aiohttp.ClientResponseError: if the HTTP response status is 4xx/5xx.
    """
    import aiohttp

    url = "https://api.hyperbolic.xyz/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.getenv('HYPERBOLIC_API_KEY')}"
    }

    data = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            # Fail loudly on HTTP errors; previously an error payload was
            # parsed anyway and surfaced as an opaque KeyError on 'choices'.
            response.raise_for_status()
            result = await response.json()
            return result['choices'][0]['message']['content']
136
+
137
+
138
async def _call_huggingface(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Call the Hugging Face Inference API for text generation.

    Args:
        prompt: Input text for the model.
        model: Hub model id (e.g. 'mistralai/Mixtral-8x7B-Instruct-v0.1').
        temperature: Sampling temperature.
        max_tokens: Maps to the API's ``max_new_tokens`` parameter.

    Returns:
        The generated text (without the echoed prompt).

    Raises:
        RuntimeError: if the API returns an error payload (e.g. model
            still loading, invalid token).
    """
    import aiohttp

    url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

    data = {
        "inputs": prompt,
        "parameters": {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "return_full_text": False
        }
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            result = await response.json()
            # The Inference API signals failures as {"error": ...} (model
            # loading, bad token, rate limit).  Previously this fell into the
            # str(result) branch and was silently returned as if it were
            # generated text; raise instead so the caller's fallback chain
            # can react.
            if isinstance(result, dict) and 'error' in result:
                raise RuntimeError(f"Hugging Face API error: {result['error']}")
            if isinstance(result, list) and len(result) > 0:
                return result[0].get('generated_text', '')
            return str(result)