Text Generation
Transformers
Safetensors
PyTorch
nvidia
conversational

Add Streaming Tool Calling support

#26
nemotron_toolcall_parser_no_streaming.py CHANGED
@@ -44,6 +44,13 @@ class NemotronJSONToolParser(ToolParser):
44
 
45
  self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
46
 
 
 
 
 
 
 
 
47
  def extract_tool_calls(
48
  self,
49
  model_output: str,
@@ -95,6 +102,41 @@ class NemotronJSONToolParser(ToolParser):
95
  tool_calls=[],
96
  content=model_output,
97
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def extract_tool_calls_streaming(
100
  self,
@@ -106,5 +148,65 @@ class NemotronJSONToolParser(ToolParser):
106
  delta_token_ids: Sequence[int],
107
  request: ChatCompletionRequest,
108
  ) -> Union[DeltaMessage, None]:
 
 
 
 
 
 
 
 
 
109
 
110
- raise NotImplementedError("Tool calling is not supported in streaming mode!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
46
 
47
+ self.reasoning_end_token: str = "</think>"
48
+ self.special_token_buffer: str = ""
49
+ self.parsing_special_token: bool = False
50
+ self.parsing_tool_call: bool = False
51
+ self.parsing_reasoning: bool = True
52
+
53
+
54
  def extract_tool_calls(
55
  self,
56
  model_output: str,
 
102
  tool_calls=[],
103
  content=model_output,
104
  )
105
+
106
+
def _parse_tool_xml(self, xml: str):
    """Parse a complete ``<TOOLCALL>[...]</TOOLCALL>`` span into tool-call deltas.

    Args:
        xml: Text containing one full tool-call span. The JSON array between
            the tags may be emitted by the model without its surrounding
            brackets, which are restored here before decoding.

    Returns:
        A list of ``DeltaToolCall`` objects, one per well-formed entry;
        malformed entries (missing ``name``/``arguments``) are skipped.

    Raises:
        IndexError: if *xml* contains no ``<TOOLCALL>...</TOOLCALL>`` span.
        json.JSONDecodeError: if the payload is not valid JSON.
    """
    tool_calls = []
    str_tool_calls = self.tool_call_regex.findall(xml)[0].strip()
    # The model sometimes drops the array brackets; restore them.
    if not str_tool_calls.startswith("["):
        str_tool_calls = "[" + str_tool_calls
    if not str_tool_calls.endswith("]"):
        # Bug fix: the closing bracket must be APPENDED — the original
        # prepended it ("]" + ...), producing invalid JSON.
        str_tool_calls = str_tool_calls + "]"
    json_tool_calls = json.loads(str_tool_calls)
    for i, tool_call in enumerate(json_tool_calls):
        try:
            arguments = tool_call["arguments"]
            if isinstance(arguments, dict):
                arguments = json.dumps(arguments, ensure_ascii=False)
            tool_calls.append(DeltaToolCall(
                type="function",
                index=i,
                # Single quotes inside the f-string: nested double quotes
                # are a SyntaxError before Python 3.12.
                id=f"{tool_call['name']}-{i}",
                function=DeltaFunctionCall(name=tool_call["name"], arguments=arguments),
            ))
        except (KeyError, TypeError):
            # Entry lacks "name"/"arguments" or has an unexpected shape; skip it.
            continue
    return tool_calls
126
+
127
+ def _split(self,k:str,v:str):
128
+ idx = v.find(k)
129
+ return v[:idx],v[idx:]
130
+
131
+ def _splitr(self,k:str,v:str):
132
+ idx = v.find(k) + len(k)
133
+ return v[:idx],v[idx:]
134
+
135
+ def _partial_match(self,s1:str,s2:str):
136
+ for i in range(min(len(s1),len(s2))):
137
+ if s1[i] != s2[i]:
138
+ return False
139
+ return True
140
 
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Streaming tool-call (and, as a stopgap, reasoning) parser for the
    Nemotron format ``<TOOLCALL>[{"name": ..., "arguments": {...}}]</TOOLCALL>``.

    Also parses ``</think>`` reasoning termination until vllm accepts
    reasoning plugins. Runs once per token; plain text is forwarded
    immediately, while anything after a ``<`` is buffered until it can be
    classified as a tool-call start, a reasoning end, or ordinary text.

    Returns:
        A DeltaMessage with content / reasoning_content / tool_calls, or
        None while more tokens are needed to disambiguate.
    """

    def send_message(content: str):
        # While still inside the reasoning span, text is reasoning content.
        if self.parsing_reasoning:
            return DeltaMessage(reasoning_content=content)
        return DeltaMessage(content=content)

    if not self.parsing_special_token:
        if "<" in delta_text:
            # "<" could open <TOOLCALL> or </think>; start buffering from it.
            self.parsing_special_token = True
            before, after = self._split("<", delta_text)
            self.special_token_buffer += after
            return send_message(before)
        return send_message(delta_text)

    self.special_token_buffer += delta_text

    partial_match_tool_token = self._partial_match(self.tool_call_start_token, self.special_token_buffer)
    partial_match_reasoning_token = self._partial_match(self.reasoning_end_token, self.special_token_buffer)

    if partial_match_tool_token and partial_match_reasoning_token:
        # Ambiguous prefix (e.g. just "<"); keep buffering.
        return None
    elif partial_match_tool_token:
        # Prefix matching guarantees we are on a tool call; wait for the
        # end token before emitting anything.
        if self.tool_call_end_token in self.special_token_buffer:
            before, after = self._splitr(self.tool_call_end_token, self.special_token_buffer)
            tool_calls = self._parse_tool_xml(before)
            self.special_token_buffer = ""
            self.parsing_special_token = False
            return DeltaMessage(tool_calls=tool_calls, content=after)
    elif partial_match_reasoning_token:
        if self.reasoning_end_token in self.special_token_buffer:
            before, after = self._splitr(self.reasoning_end_token, self.special_token_buffer)
            self.special_token_buffer = ""
            self.parsing_reasoning = False
            self.parsing_special_token = False
            return DeltaMessage(reasoning_content=before, content=after)
    else:
        # Neither token matches: the buffered "<..." was ordinary text.
        content = self.special_token_buffer
        self.special_token_buffer = ""
        # Bug fix: the original reset parsing_tool_call (a flag that is
        # never read) and left parsing_special_token True, so the parser
        # stayed stuck buffering every subsequent token after one false "<".
        self.parsing_special_token = False
        self.parsing_tool_call = False
        # Bug fix: route through send_message like the other branches
        # (the original bypassed it, mislabeling reasoning text as content).
        return send_message(content)

    return None