| """LLM client for Hugging Face Inference API""" | |
| import os | |
| from typing import Iterator, Optional | |
| from huggingface_hub import InferenceClient | |
| class InferenceProviderClient: | |
| """Client for Hugging Face Inference API""" | |
| def __init__( | |
| self, | |
| model: str = "meta-llama/Llama-3.1-8B-Instruct", | |
| api_key: Optional[str] = None, | |
| temperature: float = 0.3, | |
| max_tokens: int = 800 | |
| ): | |
| """ | |
| Initialize the Inference client | |
| Args: | |
| model: Model identifier (default: Llama-3.1-8B-Instruct) | |
| api_key: HuggingFace API token (defaults to HF_TOKEN env var) | |
| temperature: Sampling temperature (0.0 to 1.0) | |
| max_tokens: Maximum tokens to generate | |
| """ | |
| self.model = model | |
| self.temperature = temperature | |
| self.max_tokens = max_tokens | |
| # Get API key from parameter or environment | |
| api_key = api_key or os.getenv("HF_TOKEN") | |
| if not api_key: | |
| raise ValueError("HF_TOKEN environment variable must be set or api_key provided") | |
| # Initialize Hugging Face Inference Client | |
| self.client = InferenceClient(token=api_key) | |
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """
        Generate a single response from the LLM.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override the default temperature
            max_tokens: Override the default max tokens

        Returns:
            Generated text response
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
        )
        return response.choices[0].message.content
    def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Iterator[str]:
        """
        Generate a streaming response from the LLM.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Override the default temperature
            max_tokens: Override the default max tokens

        Yields:
            Text chunks as they are generated
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        stream = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=True,
        )
        for chunk in stream:
            try:
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
            except (IndexError, AttributeError):
                # Gracefully skip malformed chunks
                continue
    def chat(
        self,
        messages: list[dict],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False,
    ):
        """
        Multi-turn chat completion.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            temperature: Override the default temperature
            max_tokens: Override the default max tokens
            stream: Whether to stream the response

        Returns:
            Response text (or an iterator of text chunks if stream=True)
        """
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.max_tokens,
            stream=stream,
        )

        if stream:
            def stream_generator():
                for chunk in response:
                    try:
                        if chunk.choices and chunk.choices[0].delta.content is not None:
                            yield chunk.choices[0].delta.content
                    except (IndexError, AttributeError):
                        # Gracefully skip malformed chunks
                        continue

            return stream_generator()
        return response.choices[0].message.content
def create_llm_client(
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.7,
    max_tokens: int = 2000,
) -> InferenceProviderClient:
    """
    Factory function to create and return a configured LLM client.

    Args:
        model: Model identifier
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        Configured InferenceProviderClient
    """
    return InferenceProviderClient(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
    )
# Available models (commonly used for OSINT tasks)
AVAILABLE_MODELS = {
    "llama-3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama-3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
}


def get_model_identifier(model_name: str) -> str:
    """Get the full model identifier from a short name, falling back to Llama 3.1 8B."""
    return AVAILABLE_MODELS.get(model_name, AVAILABLE_MODELS["llama-3.1-8b"])
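

# Minimal usage sketch (illustrative only, not part of the module's public surface):
# it assumes HF_TOKEN is set in the environment and that the Hugging Face Inference
# API is reachable; the prompts below are placeholder examples.
if __name__ == "__main__":
    client = create_llm_client(
        model=get_model_identifier("llama-3.1-8b"),
        temperature=0.3,
        max_tokens=200,
    )

    # Single-shot generation
    answer = client.generate(
        prompt="Summarize what OSINT means in one sentence.",
        system_prompt="You are a concise research assistant.",
    )
    print(answer)

    # Streaming generation: print chunks as they arrive
    for chunk in client.generate_stream(prompt="List three public data sources."):
        print(chunk, end="", flush=True)
    print()

    # Multi-turn chat
    reply = client.chat(
        messages=[
            {"role": "system", "content": "You are a helpful analyst."},
            {"role": "user", "content": "What is a WHOIS lookup?"},
        ]
    )
    print(reply)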