from __future__ import annotations

import json
from typing import Any, AsyncGenerator, Dict, List, Optional

from vanna.core.llm import (
    LlmService,
    LlmRequest,
    LlmResponse,
    LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema

from huggingface_hub import InferenceClient


class VannaHuggingFaceLlmService(LlmService):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        provider: Optional[str] = None,
        **extra_client_kwargs: Any,
    ) -> None:
| """Initialise le client Hugging Face InferenceClient.""" | |
        client_kwargs = extra_client_kwargs.copy()
        if model:
            client_kwargs["model"] = model
        if api_key:
            client_kwargs["api_key"] = api_key
        if provider:
            client_kwargs["provider"] = provider
        self.model = model
        self._client = InferenceClient(**client_kwargs)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming chat completion request to the Hugging Face Inference API and return the response."""
        payload = self._build_payload(request)
        # Call the API synchronously; this function is async but we can block here.
        resp = self._client.chat.completions.create(**payload, stream=False)
        if not resp.choices:
            return LlmResponse(content=None, tool_calls=None, finish_reason=None)
        choice = resp.choices[0]
        content: Optional[str] = getattr(choice.message, "content", None)
        tool_calls = self._extract_tool_calls_from_message(choice.message)
        usage: Dict[str, int] = {}
        if getattr(resp, "usage", None):
            usage = {
                k: int(v)
                for k, v in {
                    "prompt_tokens": getattr(resp.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(resp.usage, "completion_tokens", 0),
                    "total_tokens": getattr(resp.usage, "total_tokens", 0),
                }.items()
            }
        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=getattr(choice, "finish_reason", None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a chat completion request to the Hugging Face Inference API.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls are
        accumulated and emitted in a final chunk when the stream ends.
        """
        payload = self._build_payload(request)
        # Synchronous streaming iterator; iterate within async context.
        stream = self._client.chat.completions.create(**payload, stream=True)
        # Builders for streamed tool-calls (index -> partial)
        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
        last_finish: Optional[str] = None
        for event in stream:
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]
            delta = getattr(choice, "delta", None)
            if delta is None:
                # Some SDK versions use `event.choices[0].message` on the final packet
                last_finish = getattr(choice, "finish_reason", last_finish)
                continue
            # Text content
            content_piece: Optional[str] = getattr(delta, "content", None)
            if content_piece:
                yield LlmStreamChunk(content=content_piece)
            # Tool calls (streamed)
            streamed_tool_calls = getattr(delta, "tool_calls", None)
            if streamed_tool_calls:
                for tc in streamed_tool_calls:
                    idx = getattr(tc, "index", 0) or 0
                    b = tc_builders.setdefault(
                        idx, {"id": None, "name": None, "arguments": ""}
                    )
                    if getattr(tc, "id", None):
                        b["id"] = tc.id
                    fn = getattr(tc, "function", None)
                    if fn is not None:
                        if getattr(fn, "name", None):
                            b["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            b["arguments"] = (b["arguments"] or "") + fn.arguments
            last_finish = getattr(choice, "finish_reason", last_finish)
        # Emit final tool-calls chunk if any
        final_tool_calls: List[ToolCall] = []
        for b in tc_builders.values():
            if not b.get("name"):
                continue
            args_raw = b.get("arguments") or "{}"
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            final_tool_calls.append(
                ToolCall(
                    id=b.get("id") or "tool_call",
                    name=b["name"] or "tool",
                    arguments=args_dict,
                )
            )
        if final_tool_calls:
            yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
        else:
            # Still emit a terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic checks; the inference provider enforces further validation server-side.
        for t in tools:
            if not t.name or len(t.name) > 64:
                errors.append(f"Invalid tool name: {t.name!r}")
        return errors

    # Internal helpers

    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        messages: List[Dict[str, Any]] = []
        # Add system prompt as first message if provided
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})
        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
            if m.role == "tool" and m.tool_call_id:
                msg["tool_call_id"] = m.tool_call_id
            elif m.role == "assistant" and m.tool_calls:
                # Convert tool calls to OpenAI format
                tool_calls_payload = []
                for tc in m.tool_calls:
                    tool_calls_payload.append(
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                    )
                msg["tool_calls"] = tool_calls_payload
            messages.append(msg)
        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in request.tools
            ]
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if tools_payload:
            payload["tools"] = tools_payload
            payload["tool_choice"] = "auto"
        return payload

    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
        tool_calls: List[ToolCall] = []
        raw_tool_calls = getattr(message, "tool_calls", None) or []
        for tc in raw_tool_calls:
            fn = getattr(tc, "function", None)
            if not fn:
                continue
            args_raw = getattr(fn, "arguments", "{}")
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            tool_calls.append(
                ToolCall(
                    id=getattr(tc, "id", "tool_call"),
                    name=getattr(fn, "name", "tool"),
                    arguments=args_dict,
                )
            )
        return tool_calls
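

# --- Usage sketch (illustrative, not part of the service) --------------------
# A minimal sketch of how this service might be wired up. Assumptions: a valid
# Hugging Face token, an illustrative model id, and an `LlmRequest` built
# elsewhere by the vanna runtime (this service only reads `system_prompt`,
# `messages`, `tools`, and `max_tokens` from it), with `LlmResponse.content`
# and `LlmStreamChunk.content` readable as shown.

async def _example(request: LlmRequest) -> None:
    service = VannaHuggingFaceLlmService(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
        api_key="hf_...",  # or omit to use a locally configured token
    )

    # Non-streaming: a single LlmResponse with content and/or tool calls.
    response = await service.send_request(request)
    print(response.content)

    # Streaming: text deltas arrive first; tool calls come in the final chunk.
    async for chunk in service.stream_request(request):
        if chunk.content:
            print(chunk.content, end="", flush=True)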