import json
import re
from collections.abc import Sequence
from typing import Union, Optional
import partial_json_parser
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# Module-level logger; the parser below uses it to report skipped tool calls
# and parse failures.
logger = init_logger(__name__)
@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Parse Nemotron tool calls emitted as a tagged JSON block.

    The model output is expected to contain free-form content followed by a
    block delimited by ``tool_call_start_token`` / ``tool_call_end_token``.
    The block holds a JSON array of objects (a single bare object is also
    accepted), each carrying a ``name`` and an ``arguments`` entry.

    NOTE(review): in the reviewed source both tag literals were empty strings
    and the regex was ``(.*?)``. With empty tags every ``in`` check is true,
    so all content was discarded and no tool call was ever extracted. They are
    restored here to ``<TOOLCALL>`` / ``</TOOLCALL>``, matching the in-file
    comments that refer to the "TOOLCALL block" -- confirm against the model's
    chat template.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialise the tag literals and per-request streaming state."""
        super().__init__(tokenizer)

        # Streaming state tracking.
        # NOTE(review): current_tool_name_sent, prev_tool_call_arr,
        # current_tool_id and streamed_args_for_tool are initialised here but
        # never updated by this class; confirm whether the serving layer
        # expects the parser to maintain them.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []
        # One generated ID per position in the parsed tool-call array.
        self.tool_call_ids: list[str] = []

        # What has already been emitted while streaming.
        self.sent_tool_calls_count: int = 0
        # tool index -> length of the argument string that was sent.
        self.sent_args_length: dict[int, int] = {}

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"
        # Built from the tokens so the three can never drift apart.
        self.tool_call_regex = re.compile(
            re.escape(self.tool_call_start_token)
            + r"(.*?)"
            + re.escape(self.tool_call_end_token),
            re.DOTALL,
        )

    @staticmethod
    def _as_json_array(payload: str, *, complete: bool) -> str:
        """Return *payload* wrapped so that it parses as a JSON array.

        Args:
            payload: Raw text found after the start tag.
            complete: True when the closing tag was seen. Only then is a
                missing ``]`` appended; for truncated text the partial JSON
                parser closes open containers itself, and appending ``]``
                would end up inside an unterminated string or after an
                unclosed object.
        """
        text = payload.strip()
        if not text:
            return "[]"
        if not text.startswith("["):
            text = "[" + text
        if complete and not text.endswith("]"):
            text += "]"
        return text

    @staticmethod
    def _serialize_arguments(arguments: object) -> str:
        """Return tool-call arguments as a JSON string.

        Missing / ``None`` arguments become ``"{}"``, strings are passed
        through unchanged, and every other value is JSON-encoded (instead of
        ``str()``, which would produce Python repr such as ``"None"``).
        """
        if arguments is None:
            return "{}"
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments, ensure_ascii=False)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from non-streaming (complete) output.

        Args:
            model_output: Full text generated by the model.
            request: The originating request (unused here).

        Returns:
            The parsed tool calls plus the stripped text preceding the start
            tag as ``content`` (``None`` when empty). When no start tag is
            present, or when parsing fails, the whole output is returned as
            content and ``tools_called`` is False.
        """
        start_token = self.tool_call_start_token
        start_pos = model_output.find(start_token)
        if start_pos == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            match = self.tool_call_regex.search(model_output)
            if match is not None:
                # Complete block: only the first one is considered.
                payload = self._as_json_array(match.group(1), complete=True)
            else:
                # Closing tag missing: take everything after the start tag
                # and let the partial parser recover what it can.
                payload = self._as_json_array(
                    model_output[start_pos + len(start_token):],
                    complete=False,
                )

            parsed = partial_json_parser.loads(payload)
            if not isinstance(parsed, list):
                raise ValueError("Tool calls must be a list")

            tool_calls = []
            for entry in parsed:
                if not isinstance(entry, dict):
                    continue
                try:
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=entry.get("name", ""),
                                arguments=self._serialize_arguments(
                                    entry.get("arguments")
                                ),
                            ),
                        )
                    )
                except Exception as e:
                    # Skip the invalid entry but keep the remaining ones.
                    logger.warning("Failed to parse tool call: %s", e)

            content = model_output[:start_pos].strip()
            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=content or None,
            )
        except Exception:
            # Boundary handler: never fail the request because of malformed
            # model output; hand the raw text back as content instead.
            logger.exception(
                "Error extracting tool calls. Response: %s", model_output
            )
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Optional[DeltaMessage]:
        """Extract tool calls from streaming output.

        Text before the start tag is forwarded as content. Nothing is emitted
        while the tagged block is still open; once the closing tag is present
        the block is parsed and every tool call is sent in a single delta
        (ID, name and complete argument string). Later calls emit nothing.

        Args:
            previous_text: Text accumulated before this delta.
            current_text: Text accumulated including this delta.
            delta_text: Text added by this delta.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: The originating request (unused here).

        Returns:
            A ``DeltaMessage`` carrying content and/or tool calls, or ``None``
            when there is nothing to send for this delta.
        """
        start_token = self.tool_call_start_token
        end_token = self.tool_call_end_token

        start_pos = current_text.find(start_token)
        if start_pos == -1:
            # Regular content, no tool calls yet.
            return DeltaMessage(content=delta_text) if delta_text else None

        # Content that shares this delta with a newly arrived start tag.
        content_before = ""
        if start_token in delta_text and start_token not in previous_text:
            content_before = delta_text.split(start_token, 1)[0]

        # Look for the closing tag only after the opening one.
        payload_start = start_pos + len(start_token)
        end_pos = current_text.find(end_token, payload_start)

        # Block still open, or its tool calls were already emitted.
        if end_pos == -1 or self.sent_tool_calls_count > 0:
            return (
                DeltaMessage(content=content_before) if content_before else None
            )

        payload = self._as_json_array(
            current_text[payload_start:end_pos], complete=True
        )
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as e:
            logger.debug("Failed to parse completed tool call block: %s", e)
            parsed = None

        delta_tool_calls = []
        if isinstance(parsed, list):
            for idx, entry in enumerate(parsed):
                if not isinstance(entry, dict):
                    continue

                # Ensure there is a tool ID for this array position.
                while len(self.tool_call_ids) <= idx:
                    self.tool_call_ids.append(random_uuid())

                args_str = self._serialize_arguments(entry.get("arguments"))
                try:
                    delta = DeltaToolCall(
                        index=idx,
                        id=self.tool_call_ids[idx],
                        type="function",
                        function=DeltaFunctionCall(
                            name=entry.get("name", ""),
                            arguments=args_str,  # complete JSON string
                        ),
                    )
                except Exception as e:
                    # Same policy as the non-streaming path: drop the bad
                    # entry, keep the others.
                    logger.warning("Failed to parse tool call: %s", e)
                    continue

                delta_tool_calls.append(delta)
                self.sent_tool_calls_count = idx + 1
                self.sent_args_length[idx] = len(args_str)

        if delta_tool_calls:
            if content_before:
                return DeltaMessage(
                    content=content_before, tool_calls=delta_tool_calls
                )
            return DeltaMessage(tool_calls=delta_tool_calls)
        return DeltaMessage(content=content_before) if content_before else None