|
|
import json |
|
|
import re |
|
|
from collections.abc import Sequence |
|
|
from typing import Union, Optional |
|
|
|
|
|
import partial_json_parser |
|
|
|
|
|
from vllm.entrypoints.openai.protocol import ( |
|
|
ChatCompletionRequest, |
|
|
DeltaFunctionCall, |
|
|
DeltaMessage, |
|
|
DeltaToolCall, |
|
|
ExtractedToolCallInformation, |
|
|
FunctionCall, |
|
|
ToolCall, |
|
|
) |
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
|
|
ToolParser, |
|
|
ToolParserManager, |
|
|
) |
|
|
from vllm.logger import init_logger |
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer |
|
|
from vllm.utils import random_uuid |
|
|
|
|
|
# Module-level logger for this parser, created via vLLM's init_logger and
# named after this module.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Tool parser for outputs that wrap a JSON list of calls in <TOOLCALL> tags.

    The expected shape of the model output is::

        optional text<TOOLCALL>[{"name": "fn", "arguments": {...}}, ...]</TOOLCALL>

    Only the first ``<TOOLCALL>...</TOOLCALL>`` block is parsed. Text before
    the opening tag is returned as content; text after the closing tag is not
    returned by either the streaming or the non-streaming path.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Generic streaming state. prev_tool_call_arr and
        # streamed_args_for_tool are appended to by
        # extract_tool_calls_streaming (one entry per emitted tool call);
        # current_tool_name_sent and current_tool_id are not used by this
        # class. Presumably other code (base class / serving layer) reads
        # these -- TODO confirm against the caller.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []
        # One id per tool-call index, so a call keeps the same id if it is
        # ever referenced again.
        self.tool_call_ids: list[str] = []

        # Number of tool calls already emitted in this stream, and the length
        # of the argument string emitted for each index.
        self.sent_tool_calls_count: int = 0
        self.sent_args_length: dict[int, int] = {}

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Matches one complete block. The methods below locate blocks with
        # _find_tool_block instead (same span); the attribute is kept in case
        # other code reads it.
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

    # ------------------------------------------------------------------ #
    # Shared helpers (used by both the streaming and non-streaming paths)
    # ------------------------------------------------------------------ #

    def _find_tool_block(self, text: str) -> Optional[tuple[int, int]]:
        """Locate the first complete tool-call block in ``text``.

        Returns:
            ``(body_start, body_end)`` slice bounds of the text between the
            first opening tag and the first closing tag after it, or ``None``
            if there is no opening tag or it has not been closed yet.
        """
        start = text.find(self.tool_call_start_token)
        if start == -1:
            return None
        body_start = start + len(self.tool_call_start_token)
        # Only look for the closing tag after the opening one, so a stray
        # closing tag earlier in the text is ignored.
        body_end = text.find(self.tool_call_end_token, body_start)
        if body_end == -1:
            return None
        return body_start, body_end

    @staticmethod
    def _parse_tool_call_list(body: str) -> list[dict]:
        """Parse the text inside the tool-call tags into a list of call dicts.

        Square brackets are added if the model omitted them, and
        ``partial_json_parser`` is used so that JSON which is cut off (e.g.
        when the closing tag never arrived) still yields whatever calls are
        recoverable. Entries that are not dicts are dropped, so the returned
        list indices are contiguous.

        Raises:
            ValueError: if the parsed value is not a list.
        """
        body = body.strip()
        if not body.startswith("["):
            body = "[" + body
        if not body.endswith("]"):
            body = body + "]"
        parsed = partial_json_parser.loads(body)
        if not isinstance(parsed, list):
            raise ValueError("Tool calls must be a list")
        return [item for item in parsed if isinstance(item, dict)]

    @staticmethod
    def _serialize_arguments(arguments: object) -> str:
        """Turn a call's ``arguments`` value into the string sent to clients.

        - missing / ``None``  -> ``"{}"``
        - ``str``             -> passed through unchanged (assumed to already
                                 be JSON -- the model emitted it as a string)
        - anything else       -> JSON-encoded (dicts, lists, numbers, ...),
                                 never a Python ``repr``.
        """
        if arguments is None:
            return "{}"
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments, ensure_ascii=False)

    @staticmethod
    def _content_delta(content: Optional[str]) -> Optional[DeltaMessage]:
        """Wrap non-empty ``content`` in a DeltaMessage; otherwise return None."""
        return DeltaMessage(content=content) if content else None

    # ------------------------------------------------------------------ #
    # Non-streaming
    # ------------------------------------------------------------------ #

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from non-streaming (complete) output.

        Args:
            model_output: The full generated text.
            request: The originating request (unused here).

        Returns:
            ExtractedToolCallInformation. If there is no opening tag, or the
            block cannot be parsed, ``tools_called`` is False and ``content``
            is the full ``model_output``. Otherwise ``content`` is the
            stripped text before the opening tag (``None`` if empty).
        """
        start = model_output.find(self.tool_call_start_token)
        if start == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            block = self._find_tool_block(model_output)
            if block is not None:
                body = model_output[block[0]:block[1]]
            else:
                # No closing tag (e.g. generation stopped early): take
                # everything after the opening tag and let the partial
                # parser recover what it can.
                body = model_output[start + len(self.tool_call_start_token):]
            raw_calls = self._parse_tool_call_list(body)
        except Exception:
            logger.exception("Error extracting tool calls. Response: %s", model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        tool_calls: list[ToolCall] = []
        for raw_call in raw_calls:
            try:
                tool_calls.append(ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_call.get("name", ""),
                        arguments=self._serialize_arguments(raw_call.get("arguments")),
                    ),
                ))
            except Exception as e:
                # Skip a single malformed call rather than failing the rest.
                logger.warning("Failed to parse tool call: %s", e)

        content = model_output[:start].strip()
        return ExtractedToolCallInformation(
            tools_called=bool(tool_calls),
            tool_calls=tool_calls,
            content=content or None,
        )

    # ------------------------------------------------------------------ #
    # Streaming
    # ------------------------------------------------------------------ #

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls from streaming output.

        Tool calls are not streamed token by token. Text before ``<TOOLCALL>``
        is forwarded as content; text inside the block is buffered; and on
        the delta that completes ``</TOOLCALL>`` every tool call is emitted
        whole (id, name and full arguments). If that same delta also carries
        the opening tag, the content in front of it is sent in the same
        message. Nothing after the closing tag is forwarded.

        Returns:
            A DeltaMessage with content and/or tool calls, or None when there
            is nothing to send for this delta.
        """
        start_token = self.tool_call_start_token

        # Before the opening tag appears, everything is ordinary content.
        if start_token not in current_text:
            return self._content_delta(delta_text)

        # On the delta that introduces the opening tag, the text in front of
        # it is still content and must be forwarded.
        content_before: Optional[str] = None
        if start_token in delta_text and start_token not in previous_text:
            content_before = delta_text.split(start_token, 1)[0]

        # The block was already completed (and handled) on an earlier delta.
        if self._find_tool_block(previous_text) is not None:
            return None

        block = self._find_tool_block(current_text)
        if block is None:
            # Inside an unfinished block: buffer until the closing tag.
            return self._content_delta(content_before)

        body_start, body_end = block
        delta_tool_calls: list[DeltaToolCall] = []
        # (index, name, raw arguments, serialized arguments) per new call;
        # state is only committed after every call was built successfully,
        # so a failure part-way through cannot mark calls as sent.
        new_calls: list[tuple[int, object, object, str]] = []
        try:
            raw_calls = self._parse_tool_call_list(current_text[body_start:body_end])
            for idx in range(self.sent_tool_calls_count, len(raw_calls)):
                raw_call = raw_calls[idx]
                while len(self.tool_call_ids) <= idx:
                    self.tool_call_ids.append(random_uuid())
                name = raw_call.get("name", "")
                raw_args = raw_call.get("arguments")
                args_str = self._serialize_arguments(raw_args)
                delta_tool_calls.append(DeltaToolCall(
                    index=idx,
                    id=self.tool_call_ids[idx],
                    type="function",
                    function=DeltaFunctionCall(
                        name=name,
                        arguments=args_str,
                    ),
                ))
                new_calls.append((idx, name, raw_args, args_str))
        except Exception as e:
            # The block is closed at this point, so a failure is a genuinely
            # malformed tool call, not an in-progress partial one.
            logger.warning("Failed to parse streamed tool call block: %s", e)
            return self._content_delta(content_before)

        if not delta_tool_calls:
            return self._content_delta(content_before)

        for idx, name, raw_args, args_str in new_calls:
            self.sent_args_length[idx] = len(args_str)
            self.prev_tool_call_arr.append({
                "name": name,
                "arguments": {} if raw_args is None else raw_args,
            })
            self.streamed_args_for_tool.append(args_str)
        self.sent_tool_calls_count = new_calls[-1][0] + 1

        if content_before:
            return DeltaMessage(content=content_before, tool_calls=delta_tool_calls)
        return DeltaMessage(tool_calls=delta_tool_calls)
|
|
|