# NVIDIA-Nemotron-Nano-9B-v2 / nemotron_toolcall_parser_streaming.py
# Streaming tool call parser for vLLM.
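#
# Usage sketch (an assumption, not part of the uploaded file): with a vLLM
# build that supports custom tool parser plugins, this parser can typically
# be loaded at serve time, for example:
#
#   vllm serve nvidia/NVIDIA-Nemotron-Nano-9B-v2 \
#       --enable-auto-tool-choice \
#       --tool-parser-plugin /path/to/nemotron_toolcall_parser_streaming.py \
#       --tool-call-parser nemotron_json
#
# Flag names and availability vary across vLLM versions; check
# `vllm serve --help` before relying on them.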
import json
import re
from collections.abc import Sequence
from typing import Union

import partial_json_parser

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)
@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming state tracking
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []
        self.tool_call_ids: list[str] = []  # Track IDs for each tool call

        # Track what we've sent so far in streaming
        self.sent_tool_calls_count: int = 0
        self.sent_args_length: dict[int, int] = {}  # tool_idx -> length of args sent

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
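
    # Expected Nemotron tool-call format (inferred from the regex and the JSON
    # handling below): a JSON array of {"name": ..., "arguments": {...}}
    # objects wrapped in <TOOLCALL>...</TOOLCALL>, e.g.
    #
    #   <TOOLCALL>[{"name": "get_weather",
    #               "arguments": {"city": "Paris"}}]</TOOLCALL>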
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from non-streaming (complete) output."""
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            # Try to extract complete <TOOLCALL>...</TOOLCALL> blocks
            tool_call_matches = self.tool_call_regex.findall(model_output)

            if tool_call_matches:
                # Complete tool call block found
                str_tool_calls = tool_call_matches[0].strip()
            else:
                # Incomplete - extract everything after <TOOLCALL>
                start_idx = model_output.find(self.tool_call_start_token) + len(self.tool_call_start_token)
                str_tool_calls = model_output[start_idx:].strip()

            # Ensure array brackets
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                str_tool_calls = str_tool_calls + "]"

            # Use partial JSON parser for incomplete JSON
            json_tool_calls = partial_json_parser.loads(str_tool_calls)

            if not isinstance(json_tool_calls, list):
                raise ValueError("Tool calls must be a list")

            tool_calls = []
            for tool_call in json_tool_calls:
                if not isinstance(tool_call, dict):
                    continue
                try:
                    tool_calls.append(ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tool_call.get("name", ""),
                            arguments=(
                                json.dumps(tool_call.get("arguments", {}), ensure_ascii=False)
                                if isinstance(tool_call.get("arguments"), dict)
                                else str(tool_call.get("arguments", ""))
                            ),
                        ),
                    ))
                except Exception as e:
                    logger.warning(f"Failed to parse tool call: {e}")
                    continue

            content = model_output[:model_output.find(self.tool_call_start_token)].strip()

            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(f"Error extracting tool calls. Response: {model_output}")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls from streaming output.

        This incrementally parses the <TOOLCALL> JSON as it streams in,
        sending delta updates for each tool call and its arguments.
        """
        # Check if we just started tool calling
        if self.tool_call_start_token in delta_text and self.tool_call_start_token not in previous_text:
            # First time seeing <TOOLCALL>, return content before it
            content_before = delta_text.split(self.tool_call_start_token)[0]
            if content_before:
                return DeltaMessage(content=content_before)
            # Start of tool call section - no delta yet
            return None

        # Check if we're not in tool call mode yet
        if self.tool_call_start_token not in current_text:
            # Regular content, no tool calls
            return DeltaMessage(content=delta_text) if delta_text else None

        # We're inside <TOOLCALL>...</TOOLCALL>.
        # For Nemotron, the entire TOOLCALL block is generated at once,
        # so we only parse once the complete </TOOLCALL> has arrived.
        if self.tool_call_end_token not in current_text:
            # Incomplete tool call, don't send deltas yet
            return None

        # We have the complete tool call block, parse it
        start_idx = current_text.find(self.tool_call_start_token) + len(self.tool_call_start_token)
        end_idx = current_text.find(self.tool_call_end_token)
        json_str = current_text[start_idx:end_idx].strip()

        # Parse the complete JSON
        try:
            # Ensure we have array brackets
            if not json_str.startswith("["):
                json_str = "[" + json_str
            if not json_str.endswith("]"):
                json_str = json_str + "]"

            # Parse complete JSON
            tool_calls_arr = json.loads(json_str)
            if not isinstance(tool_calls_arr, list):
                return None

            # Generate delta updates for new/updated tool calls
            delta_tool_calls = []

            for idx, tool_call in enumerate(tool_calls_arr):
                if not isinstance(tool_call, dict):
                    continue

                # Ensure we have a tool ID for this call
                while len(self.tool_call_ids) <= idx:
                    self.tool_call_ids.append(random_uuid())
                tool_id = self.tool_call_ids[idx]

                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("arguments", {})

                # Convert arguments to JSON string
                if isinstance(tool_args, dict):
                    args_str = json.dumps(tool_args, ensure_ascii=False)
                else:
                    args_str = str(tool_args)

                # Check if this is a new tool call
                if idx >= self.sent_tool_calls_count:
                    # New tool call - send ID, name, and complete arguments all at once.
                    # This matches how other models (Llama, etc.) send tool calls.
                    delta_tool_calls.append(DeltaToolCall(
                        index=idx,
                        id=tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=tool_name,
                            arguments=args_str,  # Send complete JSON string
                        ),
                    ))
                    self.sent_tool_calls_count = idx + 1
                    self.sent_args_length[idx] = len(args_str)

                # NOTE: We don't send incremental updates for arguments
                # because Nemotron generates complete tool calls in one shot,
                # unlike models that stream arguments token-by-token.

            if delta_tool_calls:
                return DeltaMessage(tool_calls=delta_tool_calls)

        except Exception as e:
            # JSON parsing failed (expected for incomplete JSON)
            logger.debug(f"Partial JSON parse failed (expected during streaming): {e}")

        # Check if we just completed the tool calls (end tag in this delta)
        if self.tool_call_end_token in delta_text and self.tool_call_end_token not in previous_text:
            # We just completed - reset state for next potential tool call
            self.sent_tool_calls_count = 0
            self.sent_args_length = {}
            self.tool_call_ids = []

        return None
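

# ---------------------------------------------------------------------------
# Offline smoke test (illustrative sketch, not part of the vLLM plugin API).
# Assumptions: the ToolParser base class only stores the tokenizer in
# __init__, and `request` is unused by extract_tool_calls, so both are stubbed
# with None here purely for a local check. Inside vLLM, real objects are
# supplied by the serving layer.
if __name__ == "__main__":
    sample_output = (
        "Let me check the weather."
        '<TOOLCALL>[{"name": "get_weather", '
        '"arguments": {"city": "Santa Clara", "unit": "celsius"}}]</TOOLCALL>'
    )
    parser = NemotronJSONToolParser(tokenizer=None)  # type: ignore[arg-type]
    info = parser.extract_tool_calls(sample_output, request=None)  # type: ignore[arg-type]
    print("tools_called:", info.tools_called)
    for call in info.tool_calls:
        print(call.function.name, call.function.arguments)
    print("content:", info.content)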