"""LLM client for Hugging Face Inference API"""
import os
from typing import Iterator, Optional
from huggingface_hub import InferenceClient
class InferenceProviderClient:
"""Client for Hugging Face Inference API"""
    def __init__(
        self,
        model: str = "meta-llama/Llama-3.1-8B-Instruct",
        api_key: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 800
    ):
        """
        Initialize the Inference client.

        Args:
            model: Model identifier (default: Llama-3.1-8B-Instruct)
            api_key: HuggingFace API token (defaults to HF_TOKEN env var)
            temperature: Sampling temperature (0.0 to 1.0)
            max_tokens: Maximum tokens to generate
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Get API key from parameter or environment
        api_key = api_key or os.getenv("HF_TOKEN")
        if not api_key:
            raise ValueError("HF_TOKEN environment variable must be set or api_key provided")

        # Initialize Hugging Face Inference Client
        self.client = InferenceClient(token=api_key)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            Generated text response
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            # Explicit None checks so an override of 0 / 0.0 is not silently ignored
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens
        )
        return response.choices[0].message.content
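
    # Example usage (illustrative sketch, not part of the API surface): assumes a
    # valid HF_TOKEN is set in the environment and network access to the
    # Inference API; the prompts below are placeholders.
    #
    #   client = InferenceProviderClient()
    #   answer = client.generate(
    #       "Summarize the role of embeddings in retrieval.",
    #       system_prompt="You are a concise technical assistant.",
    #   )
    #   print(answer)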

    def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Iterator[str]:
        """
        Generate a streaming response from the LLM.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Yields:
            Text chunks as they are generated
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        stream = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=True
        )
        for chunk in stream:
            try:
                delta = chunk.choices[0].delta
                if delta.content is not None:
                    yield delta.content
            except (IndexError, AttributeError):
                # Gracefully skip malformed chunks
                continue
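
    # Streaming example (illustrative, continuing the `client` from the example
    # above; same HF_TOKEN and network assumptions):
    #
    #   for chunk in client.generate_stream("Explain CLIP-style image-text embeddings."):
    #       print(chunk, end="", flush=True)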

    def chat(
        self,
        messages: list[dict],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False
    ):
        """
        Multi-turn chat completion.

        Args:
            messages: List of message dicts with 'role' and 'content'
            temperature: Override default temperature
            max_tokens: Override default max tokens
            stream: Whether to stream the response

        Returns:
            Response text (or an iterator of text chunks if stream=True)
        """
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=stream
        )
        if stream:
            def stream_generator():
                for chunk in response:
                    try:
                        delta = chunk.choices[0].delta
                        if delta.content is not None:
                            yield delta.content
                    except (IndexError, AttributeError):
                        # Gracefully skip malformed chunks
                        continue
            return stream_generator()
        return response.choices[0].message.content
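
# Multi-turn example for InferenceProviderClient.chat (illustrative sketch; the
# message format is the role/content convention expected by chat_completion,
# and the conversation below is a placeholder):
#
#   client = InferenceProviderClient()
#   history = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "What is an embedding?"},
#       {"role": "assistant", "content": "A dense vector representation of data."},
#       {"role": "user", "content": "How are embeddings compared?"},
#   ]
#   reply = client.chat(history)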


def create_llm_client(
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.7,
    max_tokens: int = 2000
) -> InferenceProviderClient:
    """
    Factory function to create and return a configured LLM client.

    Args:
        model: Model identifier
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        Configured InferenceProviderClient
    """
    return InferenceProviderClient(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens
    )


# Available models (commonly used for OSINT tasks)
AVAILABLE_MODELS = {
    "llama-3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama-3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
}


def get_model_identifier(model_name: str) -> str:
    """Get full model identifier from short name"""
    return AVAILABLE_MODELS.get(model_name, AVAILABLE_MODELS["llama-3.1-8b"])
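

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): running this module directly
    # requires a valid HF_TOKEN in the environment and network access to the
    # Hugging Face Inference API. The prompt is a placeholder.
    demo_client = create_llm_client(model=get_model_identifier("llama-3.1-8b"))
    print(demo_client.generate("Say hello in one short sentence."))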