Text Generation
Transformers
Safetensors
PyTorch
nvidia
conversational
README.md CHANGED
@@ -55,13 +55,11 @@ The supported languages include: English, German, Spanish, French, Italian, and
55
 
56
  This model is ready for commercial use.
57
 
58
- ## Feature Voting
59
-
60
- We want to hear from you! Share your ideas, vote on what matters, and help [shape the future of Nemotron](https://nemotron.ideas.nvidia.com/).
61
 
62
  ## License/Terms of Use
63
 
64
- Governing Terms: Use of this model is governed by the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
 
65
 
66
  ## Evaluation Results
67
 
 
55
 
56
  This model is ready for commercial use.
57
 
 
 
 
58
 
59
  ## License/Terms of Use
60
 
61
+ GOVERNING TERMS: This trial service is governed by the [NVIDIA API Trial Terms of Service](https://assets.ngc.nvidia.com/products/api-catalog/legal/NVIDIA%20API%20Trial%20Terms%20of%20Service.pdf). Use of this model is governed by the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
62
+
63
 
64
  ## Evaluation Results
65
 
modeling_nemotron_h.py CHANGED
@@ -1112,7 +1112,6 @@ class NemotronHPreTrainedModel(PreTrainedModel):
1112
  _no_split_modules = ["NemotronHBlock"]
1113
  supports_gradient_checkpointing = True
1114
  _is_stateful = True
1115
- _supports_flash_attn_2 = True
1116
 
1117
  def _init_weights(self, module):
1118
  """Initialize the weights."""
 
1112
  _no_split_modules = ["NemotronHBlock"]
1113
  supports_gradient_checkpointing = True
1114
  _is_stateful = True
 
1115
 
1116
  def _init_weights(self, module):
1117
  """Initialize the weights."""
nemotron_toolcall_parser_streaming.py DELETED
@@ -1,480 +0,0 @@
1
- import json
2
- from collections.abc import Sequence
3
- from random import choices
4
- from string import ascii_letters, digits
5
- from typing import Optional, Union
6
-
7
- import partial_json_parser
8
- import regex as re
9
- from partial_json_parser.core.options import Allow
10
- from pydantic import Field
11
-
12
- from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
13
- DeltaFunctionCall, DeltaMessage,
14
- DeltaToolCall,
15
- ExtractedToolCallInformation,
16
- FunctionCall, ToolCall)
17
- from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
- ToolParser, ToolParserManager)
19
- from vllm.logger import init_logger
20
- from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
21
-
22
- logger = init_logger(__name__)
23
-
24
- ALPHANUMERIC = ascii_letters + digits
25
-
26
-
27
- class NemotronToolCall(ToolCall):
28
- id: str = Field(
29
- default_factory=lambda: NemotronToolCall.generate_random_id())
30
-
31
- @staticmethod
32
- def generate_random_id():
33
- return "".join(choices(ALPHANUMERIC, k=9))
34
-
35
- @staticmethod
36
- def is_valid_id(id: str) -> bool:
37
- return id.isalnum() and len(id) == 9
38
-
39
-
40
- def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
41
- return isinstance(model_tokenizer, MistralTokenizer) \
42
- and model_tokenizer.version >= 11
43
-
44
-
45
- @ToolParserManager.register_module("nemotron_json")
46
- class NemotronToolParser(ToolParser):
47
- """
48
- Tool call parser for Nemotron-Nano-V2
49
-
50
- Used when --enable-auto-tool-choice --tool-call-parser nemotron_json are all set
51
- """
52
-
53
- def __init__(self, tokenizer: AnyTokenizer):
54
- super().__init__(tokenizer)
55
- # initialize properties used for state when parsing tool calls in
56
- # streaming mode
57
- self.prev_tool_call_arr: list[dict] = []
58
- self.current_tool_id: int = -1
59
- self.current_tool_name_sent: bool = False
60
- self.streamed_args_for_tool: list[str] = [
61
- ] # map what has been streamed for each tool so far to a list
62
- self.tool_args_emitted: list[bool] = []
63
- self.bot_token = "<TOOLCALL>"
64
- self.bot_token_id = self.vocab.get(self.bot_token)
65
- logger.info(f"Nemotron Tool Parser: bot_token: {self.bot_token}, bot_token_id: {self.bot_token_id}")
66
- self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
67
- if _is_fn_name_regex_support(self.model_tokenizer):
68
- self.fn_name_regex = re.compile(
69
- r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
70
- else:
71
- self.fn_name_regex = None
72
-
73
- # Buffer for partial tag sequences to disambiguate between normal content and
74
- # a forthcoming <TOOLCALL> or </TOOLCALL> tag in streaming.
75
- self._pending_tag_buffer: str = ""
76
-
77
- @staticmethod
78
- def _strip_trailing_auto_closers(chunk: str) -> str:
79
- """
80
- Remove parser auto-completed closing braces/brackets plus trailing whitespace.
81
- These should be flushed only when a tool call completes to avoid duplicate
82
- argument fragments.
83
- """
84
- idx = len(chunk)
85
- while idx > 0 and chunk[idx - 1] in " \t\r\n}]":
86
- idx -= 1
87
- # Remove trailing non-escaped double quotes (partial JSON auto-closes strings)
88
- while idx > 0 and chunk[idx - 1] == '"':
89
- # keep escaped quotes (\"), only strip bare ones
90
- if idx - 2 >= 0 and chunk[idx - 2] == '\\':
91
- break
92
- idx -= 1
93
- return chunk[:idx]
94
-
95
- @staticmethod
96
- def _common_prefix_len(left: str, right: str) -> int:
97
- """
98
- Return the length of the shared prefix between left and right strings.
99
- """
100
- max_len = min(len(left), len(right))
101
- idx = 0
102
- while idx < max_len and left[idx] == right[idx]:
103
- idx += 1
104
- return idx
105
-
106
- def _compute_arguments_delta(self, cur_arguments_json: str,
107
- end_of_call: bool) -> str:
108
- """
109
- Determine the incremental suffix to stream for the current tool call.
110
- Ensures we only emit monotonic chunks by trimming our tracked prefix to
111
- the longest common prefix with the latest JSON snapshot.
112
- """
113
- tool_idx = self.current_tool_id
114
- if tool_idx < 0 or tool_idx >= len(self.streamed_args_for_tool):
115
- return ""
116
-
117
- streamed_prefix = self.streamed_args_for_tool[tool_idx]
118
- had_any = (self.tool_args_emitted[tool_idx]
119
- if tool_idx < len(self.tool_args_emitted) else False)
120
-
121
- lcp_len = self._common_prefix_len(cur_arguments_json,
122
- streamed_prefix)
123
- if lcp_len != len(streamed_prefix):
124
- streamed_prefix = streamed_prefix[:lcp_len]
125
- self.streamed_args_for_tool[tool_idx] = streamed_prefix
126
-
127
- if (not had_any and not end_of_call and lcp_len == 0
128
- and cur_arguments_json.endswith('": ""}')
129
- and '": ""' in cur_arguments_json):
130
- closing_pos = cur_arguments_json.rfind('": ""}')
131
- if closing_pos != -1:
132
- arguments_delta = cur_arguments_json[:closing_pos + 4]
133
- else:
134
- arguments_delta = cur_arguments_json
135
- else:
136
- arguments_delta = cur_arguments_json[lcp_len:]
137
-
138
- if not arguments_delta:
139
- return ""
140
-
141
- if not end_of_call:
142
- arguments_delta = self._strip_trailing_auto_closers(
143
- arguments_delta)
144
-
145
- if (not had_any and not end_of_call and arguments_delta
146
- and arguments_delta.endswith('}')):
147
- arguments_delta = arguments_delta[:-1]
148
- if arguments_delta.endswith('"'):
149
- arguments_delta = arguments_delta[:-1]
150
-
151
- return arguments_delta
152
-
153
- def _visible_delta_outside_tool(self, delta_text: str,
154
- start_token: Optional[str],
155
- end_token: Optional[str]) -> str:
156
- """
157
- Consume characters that could begin a tool tag. Only suppress the exact
158
- <TOOLCALL> / </TOOLCALL> sequences, and let everything else (e.g. </think>)
159
- pass through untouched.
160
- """
161
- if not delta_text:
162
- return delta_text
163
-
164
- visible: list[str] = []
165
- for ch in delta_text:
166
- if self._pending_tag_buffer or ch == '<':
167
- self._pending_tag_buffer += ch
168
-
169
- if start_token and start_token.startswith(self._pending_tag_buffer):
170
- if self._pending_tag_buffer == start_token:
171
- self._pending_tag_buffer = ""
172
- continue
173
-
174
- if end_token and end_token.startswith(self._pending_tag_buffer):
175
- if self._pending_tag_buffer == end_token:
176
- self._pending_tag_buffer = ""
177
- continue
178
-
179
- # Not a tool tag; flush buffered characters as normal content.
180
- visible.append(self._pending_tag_buffer)
181
- self._pending_tag_buffer = ""
182
- else:
183
- visible.append(ch)
184
-
185
- return "".join(visible)
186
-
187
- def adjust_request(
188
- self, request: ChatCompletionRequest) -> ChatCompletionRequest:
189
- if not isinstance(
190
- self.model_tokenizer, MistralTokenizer
191
- ) and request.tools and request.tool_choice != 'none':
192
- # Do not skip special tokens when using chat template
193
- # with Mistral parser as TOOL_CALL token is needed
194
- # for tool detection.
195
- # Note: we don't want skip_special_tokens=False
196
- # with MistralTokenizer as it is incompatible
197
- request.skip_special_tokens = False
198
- return request
199
-
200
- def extract_tool_calls(
201
- self,
202
- model_output: str,
203
- request: ChatCompletionRequest,
204
- ) -> ExtractedToolCallInformation:
205
- """
206
- Extract the tool calls from a complete model response. Requires
207
- find-and-replacing single quotes with double quotes for JSON parsing,
208
- make sure your tool call arguments don't ever include quotes!
209
- """
210
-
211
- # case -- if a tool call token is not present, return a text response
212
- if self.bot_token not in model_output:
213
- return ExtractedToolCallInformation(tools_called=False,
214
- tool_calls=[],
215
- content=model_output)
216
-
217
- # first remove the BOT token
218
- tool_content = model_output.replace(self.bot_token, "").strip()
219
-
220
- try:
221
- # we first try to directly load the json as parsing very nested
222
- # jsons is difficult
223
- try:
224
- if self.fn_name_regex:
225
- matches = self.fn_name_regex.findall(tool_content)
226
-
227
- function_call_arr = []
228
- for match in matches:
229
- fn_name = match[0]
230
- args = match[1]
231
-
232
- # fn_name is encoded outside serialized json dump
233
- # only arguments are serialized
234
- function_call_arr.append({
235
- "name": fn_name,
236
- "arguments": json.loads(args)
237
- })
238
- else:
239
- function_call_arr = json.loads(tool_content)
240
- except json.JSONDecodeError:
241
- # use a regex to find the part corresponding to the tool call.
242
- # NOTE: This use case should not happen if the model is trained
243
- # correctly. It's a easy possible fix so it's included, but
244
- # can be brittle for very complex / highly nested tool calls
245
- raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
246
- function_call_arr = json.loads(raw_tool_call)
247
-
248
- # Tool Call
249
- tool_calls: list[NemotronToolCall] = [
250
- NemotronToolCall(
251
- type="function",
252
- function=FunctionCall(
253
- name=raw_function_call["name"],
254
- # function call args are JSON but as a string
255
- arguments=json.dumps(raw_function_call["arguments"],
256
- ensure_ascii=False)))
257
- for raw_function_call in function_call_arr
258
- ]
259
-
260
- # get any content before the tool call
261
- content = model_output.split(self.bot_token)[0]
262
- return ExtractedToolCallInformation(
263
- tools_called=True,
264
- tool_calls=tool_calls,
265
- content=content if len(content) > 0 else None)
266
-
267
- except Exception:
268
- logger.exception("Error in extracting tool call from response.")
269
- # return information to just treat the tool call as regular JSON
270
- return ExtractedToolCallInformation(tools_called=False,
271
- tool_calls=[],
272
- content=tool_content)
273
-
274
- def extract_tool_calls_streaming(
275
- self,
276
- previous_text: str,
277
- current_text: str,
278
- delta_text: str,
279
- previous_token_ids: Sequence[int],
280
- current_token_ids: Sequence[int],
281
- delta_token_ids: Sequence[int],
282
- request: ChatCompletionRequest,
283
- ) -> Union[DeltaMessage, None]:
284
- # if candidates tool call tokens are in the tokens generated so far, that
285
- # means we're parsing as tool calls now. Suppress streaming if we are
286
- # currently generating any prefix of the start or end tag.
287
- visible_delta_text = delta_text
288
- try:
289
- start_token = self.bot_token
290
- end_token = f"</{self.bot_token[1:]}" if self.bot_token.startswith('<') else None
291
-
292
- visible_delta_text = self._visible_delta_outside_tool(
293
- delta_text, start_token, end_token)
294
- except Exception:
295
- # Fallback to conservative checks in case of any issues
296
- if current_text.endswith('<') or current_text.endswith('<T') or current_text.endswith('<TO') or current_text.endswith('<TOOL') or current_text.endswith('<TOOLCALL'):
297
- return None
298
-
299
- # if the tool call token is not in the tokens generated so far, append
300
- # output to contents since it's not a tool
301
- if self.bot_token not in current_text:
302
- if visible_delta_text:
303
- return DeltaMessage(content=visible_delta_text)
304
- # still waiting on a potential tag, so emit nothing yet
305
- return None
306
-
307
- # bit mask flags for partial JSON parsing. If the name hasn't been
308
- # sent yet, don't allow sending
309
- # an incomplete string since OpenAI only ever (as far as I have
310
- # seen) allows sending the entire tool/ function name at once.
311
- flags = Allow.ALL if self.current_tool_name_sent \
312
- else Allow.ALL & ~Allow.STR
313
- end_of_call: bool = False
314
- try:
315
-
316
- # replace BOT token with empty string, and convert single quotes
317
- # to double to allow parsing as JSON since mistral uses single
318
- # quotes instead of double for tool calls
319
- parsable_arr = current_text.split(self.bot_token)[-1]
320
-
321
- # Check if we're at the end of the tool call
322
- if '</TOOLCALL>' in parsable_arr:
323
- end_of_call = True
324
- parsable_arr = parsable_arr.split('</TOOLCALL>')[0]
325
-
326
- # tool calls are generated in an array, so do partial JSON
327
- # parsing on the entire array
328
- try:
329
- tool_call_arr: list[dict] = partial_json_parser.loads(
330
- parsable_arr, flags)
331
- except (partial_json_parser.core.exceptions.MalformedJSON,
332
- json.JSONDecodeError, ValueError):
333
- return None
334
-
335
- current_tool_call: dict = tool_call_arr[self.current_tool_id] \
336
- if len(tool_call_arr) > 0 else {}
337
-
338
- # case -- if no tokens have been streamed for the tool, e.g.
339
- # only the array brackets, stream nothing
340
- if len(tool_call_arr) == 0:
341
- return None
342
-
343
- # case: we are starting a new tool in the array
344
- # -> array has > 0 length AND length has moved past cursor
345
- elif (len(tool_call_arr) > 0
346
- and len(tool_call_arr) > self.current_tool_id + 1):
347
-
348
- # if we're moving on to a new call, first make sure we
349
- # haven't missed anything in the previous one that was
350
- # auto-generated due to JSON completions, but wasn't
351
- # streamed to the client yet.
352
- if self.current_tool_id >= 0:
353
- diff: Union[str, None] = current_tool_call.get("arguments")
354
-
355
- if diff:
356
- diff = json.dumps(diff, ensure_ascii=False).replace(
357
- self.streamed_args_for_tool[self.current_tool_id],
358
- "")
359
- delta = DeltaMessage(tool_calls=[
360
- DeltaToolCall(index=self.current_tool_id,
361
- function=DeltaFunctionCall(
362
- arguments=diff).model_dump(
363
- exclude_none=True))
364
- ])
365
- self.streamed_args_for_tool[
366
- self.current_tool_id] += diff
367
- else:
368
- delta = None
369
- else:
370
- delta = None
371
- # re-set stuff pertaining to progress in the current tool
372
- self.current_tool_id = len(tool_call_arr) - 1
373
- self.current_tool_name_sent = False
374
- self.streamed_args_for_tool.append("")
375
- self.tool_args_emitted.append(False)
376
- return delta
377
-
378
- # case: update an existing tool - this is handled below
379
-
380
- # if the current tool name hasn't been sent, send if available
381
- # - otherwise send nothing
382
- if not self.current_tool_name_sent:
383
- function_name = current_tool_call.get("name")
384
- if function_name:
385
-
386
- delta = DeltaMessage(tool_calls=[
387
- DeltaToolCall(index=self.current_tool_id,
388
- type="function",
389
- id=NemotronToolCall.generate_random_id(),
390
- function=DeltaFunctionCall(
391
- name=function_name).model_dump(
392
- exclude_none=True))
393
- ])
394
- self.current_tool_name_sent = True
395
- else:
396
- delta = None
397
-
398
- # now we know we're on the same tool call and we're streaming
399
- # arguments
400
- else:
401
-
402
- prev_arguments = self.prev_tool_call_arr[
403
- self.current_tool_id].get("arguments")
404
- cur_arguments = current_tool_call.get("arguments")
405
-
406
- if not cur_arguments and not prev_arguments:
407
-
408
- delta = None
409
- elif not cur_arguments and prev_arguments:
410
- logger.error(
411
- "INVARIANT - impossible to have arguments reset "
412
- "mid-arguments")
413
- delta = None
414
- elif cur_arguments:
415
- cur_arguments_json = json.dumps(cur_arguments,
416
- ensure_ascii=False)
417
- arguments_delta = self._compute_arguments_delta(
418
- cur_arguments_json, end_of_call)
419
- if arguments_delta:
420
- delta = DeltaMessage(tool_calls=[
421
- DeltaToolCall(index=self.current_tool_id,
422
- function=DeltaFunctionCall(
423
- arguments=arguments_delta).
424
- model_dump(exclude_none=True))
425
- ])
426
- self.streamed_args_for_tool[
427
- self.current_tool_id] += arguments_delta
428
- self.tool_args_emitted[
429
- self.current_tool_id] = True
430
- else:
431
- # Do not flush final JSON here; let the serving layer
432
- # compute a minimal remaining suffix on finish.
433
- delta = None
434
- else:
435
- # End-of-call or equal state; do not force a final flush here.
436
- delta = None
437
-
438
- # check to see if the name is defined and has been sent. if so,
439
- # stream the name - otherwise keep waiting
440
- # finish by setting old and returning None as base case
441
- self.prev_tool_call_arr = tool_call_arr
442
- # If we've reached the end of a tool call, flush any remaining
443
- # suffix (including a final '}') that hasn't been streamed yet.
444
- if end_of_call and self.current_tool_id >= 0:
445
- try:
446
- cur_arguments = current_tool_call.get("arguments")
447
- if cur_arguments is not None:
448
- cur_args_json = json.dumps(cur_arguments,
449
- ensure_ascii=False)
450
- remaining_suffix = self._compute_arguments_delta(
451
- cur_args_json, end_of_call=True)
452
-
453
- # Only send remaining suffix if it's non-empty and contains meaningful content
454
- # (not just whitespace or single characters like closing braces)
455
- if remaining_suffix and remaining_suffix.strip():
456
- extra = DeltaToolCall(
457
- index=self.current_tool_id,
458
- function=DeltaFunctionCall(
459
- arguments=remaining_suffix).model_dump(
460
- exclude_none=True))
461
- if delta is None:
462
- delta = DeltaMessage(tool_calls=[extra])
463
- else:
464
- if getattr(delta, "tool_calls", None):
465
- delta.tool_calls.append(extra)
466
- else:
467
- delta.tool_calls = [extra]
468
- self.streamed_args_for_tool[
469
- self.current_tool_id] += remaining_suffix
470
- self.tool_args_emitted[self.current_tool_id] = True
471
- else:
472
- pass
473
- except Exception:
474
- pass
475
-
476
- return delta
477
-
478
- except Exception:
479
- logger.exception("Error trying to handle streaming tool call.")
480
- return None