# graphics-llm/src/vanna_huggingface_llm_service.py
from __future__ import annotations
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
from huggingface_hub import InferenceClient


class VannaHuggingFaceLlmService(LlmService):
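    """LlmService backed by Hugging Face's InferenceClient.

    Uses the client's OpenAI-compatible ``chat.completions`` interface, so the
    payloads built here follow the OpenAI chat-completion wire format.
    """
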
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
provider: Optional[str] = None,
**extra_client_kwargs: Any,
) -> None:
"""Initialise le client Hugging Face InferenceClient."""
client_kwargs = extra_client_kwargs.copy()
if model:
client_kwargs["model"] = model
if api_key:
client_kwargs["api_key"] = api_key
if provider:
client_kwargs["provider"] = provider
self.model = model
self._client = InferenceClient(**client_kwargs)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to the Hugging Face Inference API and return the response."""
        payload = self._build_payload(request)
        # The HF client is synchronous; this async method blocks while the call completes.
        resp = self._client.chat.completions.create(**payload, stream=False)
if not resp.choices:
return LlmResponse(content=None, tool_calls=None, finish_reason=None)
choice = resp.choices[0]
content: Optional[str] = getattr(choice.message, "content", None)
tool_calls = self._extract_tool_calls_from_message(choice.message)
usage: Dict[str, int] = {}
if getattr(resp, "usage", None):
usage = {
k: int(v)
for k, v in {
"prompt_tokens": getattr(resp.usage, "prompt_tokens", 0),
"completion_tokens": getattr(resp.usage, "completion_tokens", 0),
"total_tokens": getattr(resp.usage, "total_tokens", 0),
}.items()
}
return LlmResponse(
content=content,
tool_calls=tool_calls or None,
finish_reason=getattr(choice, "finish_reason", None),
usage=usage or None,
)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to the Hugging Face Inference API.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls
        are accumulated and emitted in a final chunk when the stream ends.
        """
payload = self._build_payload(request)
# Synchronous streaming iterator; iterate within async context.
stream = self._client.chat.completions.create(**payload, stream=True)
# Builders for streamed tool-calls (index -> partial)
tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
last_finish: Optional[str] = None
for event in stream:
if not getattr(event, "choices", None):
continue
choice = event.choices[0]
delta = getattr(choice, "delta", None)
if delta is None:
# Some SDK versions use `event.choices[0].message` on the final packet
last_finish = getattr(choice, "finish_reason", last_finish)
continue
# Text content
content_piece: Optional[str] = getattr(delta, "content", None)
if content_piece:
yield LlmStreamChunk(content=content_piece)
# Tool calls (streamed)
streamed_tool_calls = getattr(delta, "tool_calls", None)
if streamed_tool_calls:
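                # OpenAI-style streams send each call's id/name once and its
                # JSON arguments as string fragments keyed by `index`;
                # accumulate the fragments until the stream ends.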
for tc in streamed_tool_calls:
idx = getattr(tc, "index", 0) or 0
b = tc_builders.setdefault(
idx, {"id": None, "name": None, "arguments": ""}
)
if getattr(tc, "id", None):
b["id"] = tc.id
fn = getattr(tc, "function", None)
if fn is not None:
if getattr(fn, "name", None):
b["name"] = fn.name
if getattr(fn, "arguments", None):
b["arguments"] = (b["arguments"] or "") + fn.arguments
last_finish = getattr(choice, "finish_reason", last_finish)
# Emit final tool-calls chunk if any
final_tool_calls: List[ToolCall] = []
for b in tc_builders.values():
if not b.get("name"):
continue
args_raw = b.get("arguments") or "{}"
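            # Providers occasionally emit truncated or non-JSON argument text;
            # keep the raw string instead of raising.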
try:
loaded = json.loads(args_raw)
if isinstance(loaded, dict):
args_dict: Dict[str, Any] = loaded
else:
args_dict = {"args": loaded}
except Exception:
args_dict = {"_raw": args_raw}
final_tool_calls.append(
ToolCall(
id=b.get("id") or "tool_call",
name=b["name"] or "tool",
arguments=args_dict,
)
)
if final_tool_calls:
yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
else:
# Still emit a terminal chunk to signal completion
yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic checks only; the inference provider enforces further validation server-side.
for t in tools:
if not t.name or len(t.name) > 64:
errors.append(f"Invalid tool name: {t.name!r}")
return errors

    # Internal helpers

    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
messages: List[Dict[str, Any]] = []
# Add system prompt as first message if provided
if request.system_prompt:
messages.append({"role": "system", "content": request.system_prompt})
for m in request.messages:
msg: Dict[str, Any] = {"role": m.role, "content": m.content}
if m.role == "tool" and m.tool_call_id:
msg["tool_call_id"] = m.tool_call_id
elif m.role == "assistant" and m.tool_calls:
                # Convert tool calls to the OpenAI-compatible wire format
tool_calls_payload = []
for tc in m.tool_calls:
tool_calls_payload.append({
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments)
}
})
msg["tool_calls"] = tool_calls_payload
messages.append(msg)
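        # Advertise tools using the OpenAI-compatible function-tool schema.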
tools_payload: Optional[List[Dict[str, Any]]] = None
if request.tools:
tools_payload = [
{
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters,
},
}
for t in request.tools
]
payload: Dict[str, Any] = {
"model": self.model,
"messages": messages,
}
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if tools_payload:
payload["tools"] = tools_payload
payload["tool_choice"] = "auto"
return payload

    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
tool_calls: List[ToolCall] = []
raw_tool_calls = getattr(message, "tool_calls", None) or []
for tc in raw_tool_calls:
fn = getattr(tc, "function", None)
if not fn:
continue
            args_raw = getattr(fn, "arguments", None) or "{}"
try:
loaded = json.loads(args_raw)
if isinstance(loaded, dict):
args_dict: Dict[str, Any] = loaded
else:
args_dict = {"args": loaded}
except Exception:
args_dict = {"_raw": args_raw}
tool_calls.append(
ToolCall(
id=getattr(tc, "id", "tool_call"),
name=getattr(fn, "name", "tool"),
arguments=args_dict,
)
)
return tool_calls
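

# A minimal usage sketch, not part of the service itself. It assumes an HF token
# in the HF_TOKEN environment variable, a hosted model that supports chat
# completion (the model id below is illustrative), and that `vanna.core.llm`
# exposes a message class for `LlmRequest.messages`; the `LlmMessage` name is an
# assumption based on the fields (`role`, `content`) read in `_build_payload`.
if __name__ == "__main__":
    import asyncio
    import os

    from vanna.core.llm import LlmMessage  # assumed name for the message type

    async def _demo() -> None:
        service = VannaHuggingFaceLlmService(
            model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
            api_key=os.environ["HF_TOKEN"],
        )
        request = LlmRequest(
            system_prompt="You are a concise assistant.",
            messages=[LlmMessage(role="user", content="Say hello in one sentence.")],
        )
        # Non-streaming round trip.
        response = await service.send_request(request)
        print(response.content)
        # Streaming round trip: print text deltas as they arrive.
        async for chunk in service.stream_request(request):
            if chunk.content:
                print(chunk.content, end="", flush=True)
        print()

    asyncio.run(_demo())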