"""LLM client for Hugging Face Inference API"""
import os
from typing import Iterator, Optional
from huggingface_hub import InferenceClient
class InferenceProviderClient:
"""Client for Hugging Face Inference API"""
def __init__(
self,
model: str = "meta-llama/Llama-3.1-8B-Instruct",
api_key: Optional[str] = None,
temperature: float = 0.3,
max_tokens: int = 800
):
"""
Initialize the Inference client
Args:
model: Model identifier (default: Llama-3.1-8B-Instruct)
api_key: HuggingFace API token (defaults to HF_TOKEN env var)
temperature: Sampling temperature (0.0 to 1.0)
max_tokens: Maximum tokens to generate
"""
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
# Get API key from parameter or environment
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise ValueError("HF_TOKEN environment variable must be set or api_key provided")
# Initialize Hugging Face Inference Client
self.client = InferenceClient(token=api_key)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a response from the LLM

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            Generated text response
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Explicit None checks so that an override of 0 (e.g. temperature=0.0)
        # is respected instead of falling back to the default
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens
        )
        return response.choices[0].message.content

    def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Iterator[str]:
        """
        Generate a streaming response from the LLM

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Yields:
            Text chunks as they are generated
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        stream = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=True
        )
        for chunk in stream:
            try:
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content
            except (IndexError, AttributeError):
                # Gracefully skip malformed chunks
                continue

    def chat(
        self,
        messages: list[dict],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False
    ):
        """
        Multi-turn chat completion

        Args:
            messages: List of message dicts with 'role' and 'content'
            temperature: Override default temperature
            max_tokens: Override default max tokens
            stream: Whether to stream the response

        Returns:
            Response text (or an iterator of text chunks if stream=True)
        """
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=stream
        )
        if stream:
            def stream_generator():
                for chunk in response:
                    try:
                        content = chunk.choices[0].delta.content
                        if content is not None:
                            yield content
                    except (IndexError, AttributeError):
                        # Gracefully skip malformed chunks
                        continue
            return stream_generator()
        return response.choices[0].message.content


def create_llm_client(
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.7,
    max_tokens: int = 2000
) -> InferenceProviderClient:
    """
    Factory function to create and return a configured LLM client

    Args:
        model: Model identifier
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        Configured InferenceProviderClient
    """
    return InferenceProviderClient(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens
    )


# Available models (commonly used for OSINT tasks)
AVAILABLE_MODELS = {
    "llama-3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama-3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
}


def get_model_identifier(model_name: str) -> str:
    """Get full model identifier from short name"""
    return AVAILABLE_MODELS.get(model_name, AVAILABLE_MODELS["llama-3.1-8b"])
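

# Example usage: a minimal sketch, assuming a valid HF_TOKEN is set in the
# environment and that the selected model is reachable through the Hugging Face
# Inference API. The prompts below are illustrative placeholders, not part of
# the original module.
if __name__ == "__main__":
    client = create_llm_client(model=get_model_identifier("llama-3.1-8b"))

    # One-shot generation
    print(client.generate(
        prompt="Summarize what OSINT means in one sentence.",
        system_prompt="You are a concise research assistant."
    ))

    # Streaming generation: chunks are printed as they arrive
    for chunk in client.generate_stream(prompt="List three public data sources."):
        print(chunk, end="", flush=True)
    print()

    # Multi-turn chat using explicit message dicts
    history = [
        {"role": "system", "content": "You are a concise research assistant."},
        {"role": "user", "content": "What is a reverse image search?"},
    ]
    print(client.chat(history))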