import json
import re
from collections.abc import Sequence
from typing import Union, Optional

import partial_json_parser

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Parse Nemotron tool calls of the form ``<TOOLCALL>[ ... ]</TOOLCALL>``.

    The text between the tags is a JSON array of objects (a single bare
    object is also accepted) carrying a ``name`` and an ``arguments`` entry.
    Only the first tool-call block in an output is handled; any text that
    follows the closing tag is discarded.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping shared with the ToolParser base class.
        # NOTE(review): the serving layer appears to consult
        # prev_tool_call_arr / streamed_args_for_tool when a stream finishes;
        # confirm against the vLLM version in use.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []
        self.tool_call_ids: list[str] = []  # one generated ID per array index

        # What has already been streamed to the client.
        self.sent_tool_calls_count: int = 0
        self.sent_args_length: dict[int, int] = {}  # tool index -> len(args)
        # Set once the first complete block was processed so that later
        # deltas neither re-parse nor re-emit it.
        self._tool_block_done: bool = False

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"
        self.tool_call_regex = re.compile(
            r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL
        )

    @staticmethod
    def _as_json_array(payload: str, *, close: bool) -> str:
        """Return *payload* wrapped so that it reads as a JSON array.

        A missing leading ``[`` is always added. A missing trailing ``]`` is
        only added when ``close`` is true (the block is known to be complete)
        or when the payload is empty. For truncated input the bracket must
        not be appended: it would land in the middle of an unfinished value
        and corrupt it, whereas the partial parser closes open containers
        by itself.
        """
        if not payload.startswith("["):
            payload = "[" + payload
        if (close or payload == "[") and not payload.endswith("]"):
            payload += "]"
        return payload

    @staticmethod
    def _serialize_arguments(arguments: object) -> str:
        """Return tool-call arguments as the string sent to the client.

        Missing/``None`` arguments become ``"{}"``, strings are passed
        through unchanged, and everything else is JSON-encoded, so streaming
        and non-streaming responses agree.
        """
        if arguments is None:
            return "{}"
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments, ensure_ascii=False)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from non-streaming (complete) output.

        Args:
            model_output: Full text generated by the model.
            request: The originating request (unused here).

        Returns:
            The parsed tool calls plus the text that preceded the opening
            tag. If no opening tag is present, or parsing raises, the whole
            output is returned as plain content with ``tools_called=False``.
        """
        start_token = self.tool_call_start_token
        tag_pos = model_output.find(start_token)
        if tag_pos == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            match = self.tool_call_regex.search(model_output)
            if match is not None:
                # Complete <TOOLCALL>...</TOOLCALL> block.
                payload = self._as_json_array(
                    match.group(1).strip(), close=True
                )
            else:
                # Closing tag missing (e.g. truncated generation): take
                # everything after the opening tag and let the partial
                # parser complete it.
                tail = model_output[tag_pos + len(start_token):].strip()
                payload = self._as_json_array(tail, close=False)

            parsed = partial_json_parser.loads(payload)
            if not isinstance(parsed, list):
                raise ValueError("Tool calls must be a list")

            tool_calls = []
            for entry in parsed:
                if not isinstance(entry, dict):
                    continue
                try:
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=entry.get("name", ""),
                                arguments=self._serialize_arguments(
                                    entry.get("arguments")
                                ),
                            ),
                        )
                    )
                except Exception as exc:
                    # One malformed entry must not discard its siblings.
                    logger.warning("Failed to parse tool call: %s", exc)

            content = model_output[:tag_pos].strip()
            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=content or None,
            )
        except Exception:
            logger.exception(
                "Error extracting tool calls. Response: %s", model_output
            )
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def _build_tool_call_deltas(self, payload: str) -> list[DeltaToolCall]:
        """Parse a complete block payload into one delta per tool call.

        Each delta carries the ID, name and full argument string at once.
        Streaming state is only updated when the whole payload was processed
        successfully; on any failure an empty list is returned.
        """
        deltas: list[DeltaToolCall] = []
        calls: list[dict] = []
        streamed: list[str] = []
        args_lengths: dict[int, int] = {}
        try:
            parsed = json.loads(
                self._as_json_array(payload.strip(), close=True)
            )
            if not isinstance(parsed, list):
                return []

            for idx, entry in enumerate(parsed):
                if not isinstance(entry, dict):
                    continue

                # Keep a stable ID per array position.
                while len(self.tool_call_ids) <= idx:
                    self.tool_call_ids.append(random_uuid())

                name = entry.get("name", "")
                raw_args = entry.get("arguments")
                args_str = self._serialize_arguments(raw_args)

                deltas.append(
                    DeltaToolCall(
                        index=idx,
                        id=self.tool_call_ids[idx],
                        type="function",
                        function=DeltaFunctionCall(
                            name=name,
                            arguments=args_str,
                        ),
                    )
                )
                calls.append(
                    {
                        "name": name,
                        "arguments": {} if raw_args is None else raw_args,
                    }
                )
                streamed.append(args_str)
                args_lengths[idx] = len(args_str)
        except Exception as exc:
            # The block is complete at this point, so a failure means the
            # model produced invalid JSON; streaming must not crash on it.
            logger.warning(
                "Failed to parse complete tool call block: %s", exc
            )
            return []

        if deltas:
            self.prev_tool_call_arr = calls
            self.streamed_args_for_tool = streamed
            self.sent_args_length.update(args_lengths)
            self.sent_tool_calls_count = deltas[-1].index + 1
        return deltas

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls from streaming output.

        Text before the opening tag is forwarded as content. Nothing is
        emitted while the block is still open; once the closing tag has
        arrived the whole block is parsed and every tool call is sent in a
        single delta (ID, name and complete arguments together).

        The token-id arguments and ``request`` are not used.

        NOTE: tag detection is purely textual. If the opening tag is split
        across deltas, the leading fragment has already been forwarded as
        content by the time the tag is recognised.

        NOTE(review): the state kept on ``self`` is never reset, which
        assumes one parser instance per streamed response; confirm.

        Returns:
            A ``DeltaMessage`` with content and/or tool calls, or ``None``
            when there is nothing to send for this delta.
        """
        start_token = self.tool_call_start_token
        end_token = self.tool_call_end_token

        # Not in tool-call mode yet: plain content.
        if start_token not in current_text:
            return DeltaMessage(content=delta_text) if delta_text else None

        # The block was already handled on an earlier delta.
        if self._tool_block_done:
            return None

        # The opening tag arrived in this very delta: keep what precedes it.
        leading_content = None
        if start_token in delta_text and start_token not in previous_text:
            leading_content = delta_text.split(start_token, 1)[0] or None

        payload_start = current_text.find(start_token) + len(start_token)
        # Only a closing tag located after the opening tag ends the block.
        payload_end = current_text.find(end_token, payload_start)
        if payload_end == -1:
            # Block still open: send nothing but any leading content.
            if leading_content:
                return DeltaMessage(content=leading_content)
            return None

        self._tool_block_done = True
        delta_tool_calls = self._build_tool_call_deltas(
            current_text[payload_start:payload_end]
        )
        if delta_tool_calls:
            return DeltaMessage(
                content=leading_content,
                tool_calls=delta_tool_calls,
            )
        if leading_content:
            return DeltaMessage(content=leading_content)
        return None