AlexCheema commited on
Commit
5560200
·
verified ·
1 Parent(s): d7942e5

Create encoder/encoding_dsv32.py

Browse files
Files changed (1) hide show
  1. encoder/encoding_dsv32.py +376 -0
encoder/encoding_dsv32.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union, Optional, Tuple
2
+ import copy
3
+ import json
4
+ import re
5
+
6
+ TOOLS_SYSTEM_TEMPLATE = """## Tools
7
+
8
+ You have access to a set of tools you can use to answer the user's question.
9
+ You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user:
10
+ <{dsml_token}function_calls>
11
+ <{dsml_token}invoke name="$FUNCTION_NAME">
12
+ <{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
13
+ ...
14
+ </{dsml_token}invoke>
15
+ <{dsml_token}invoke name="$FUNCTION_NAME2">
16
+ ...
17
+ </{dsml_token}invoke>
18
+ </{dsml_token}function_calls>
19
+
20
+ String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).
21
+
22
+ If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:
23
+
24
+ <{dsml_token}function_calls>
25
+ ...
26
+ </{dsml_token}function_calls>
27
+
28
+ <function_results>
29
+ ...
30
+ </function_results>
31
+
32
+ {thinking_start_token}...thinking about results{thinking_end_token}
33
+
34
+ Here are the functions available in JSONSchema format:
35
+ <functions>
36
+ {tool_schemas}
37
+ </functions>
38
+ """
39
+
40
+ bos_token: str = "<|begin▁of▁sentence|>"
41
+ eos_token: str = "<|end▁of▁sentence|>"
42
+ thinking_start_token: str = "<think>"
43
+ thinking_end_token: str = "</think>"
44
+ dsml_token: str = "|DSML|"
45
+ system_msg_template: str = "{content}"
46
+ user_msg_template: str = "<|User|>{content}<|Assistant|>"
47
+ assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>"
48
+ thinking_template = "{reasoning_content}"
49
+
50
+ response_format_template: str = (
51
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
52
+ )
53
+ tool_call_template: str = (
54
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
55
+ )
56
+ tool_calls_template = (
57
+ "<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>"
58
+ )
59
+
60
+ tool_output_template: str = (
61
+ "\n<result>{content}</result>"
62
+ )
63
+
64
+ def to_json(value: Any) -> str:
65
+ try:
66
+ return json.dumps(value, ensure_ascii=False)
67
+ except:
68
+ return json.dumps(value, ensure_ascii=True)
69
+
70
+ def tools_from_openai_format(tools):
71
+ return [tool["function"] for tool in tools]
72
+
73
+ def tool_calls_from_openai_format(tool_calls):
74
+ return [
75
+ {
76
+ "name": tool_call["function"]["name"],
77
+ "arguments": tool_call["function"]["arguments"],
78
+ }
79
+ for tool_call in tool_calls
80
+ ]
81
+
82
+ def tool_calls_to_openai_format(tool_calls):
83
+ return [
84
+ {
85
+ "type": "function",
86
+ "function": {
87
+ "name": tool_call["name"],
88
+ "arguments": tool_call["arguments"],
89
+ }
90
+ }
91
+ for tool_call in tool_calls
92
+ ]
93
+
94
+ def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
95
+ p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>"""
96
+ P_dsml_strs = []
97
+
98
+ arguments = json.loads(tool_call["arguments"])
99
+
100
+ for k, v in arguments.items():
101
+ p_dsml_str = p_dsml_template.format(
102
+ dsml_token=dsml_token,
103
+ key=k,
104
+ is_str="true" if isinstance(v, str) else "false",
105
+ value=v if isinstance(v, str) else to_json(v),
106
+ )
107
+
108
+ P_dsml_strs.append(p_dsml_str)
109
+
110
+ return "\n".join(P_dsml_strs)
111
+
112
+
113
+ def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
114
+ def _decode_value(key: str, value: str, string: str):
115
+ if string == "true":
116
+ value = to_json(value)
117
+ return f"{to_json(key)}: {value}"
118
+
119
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
120
+ return dict(name=tool_name, arguments=tool_args_json)
121
+
122
+ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
123
+ tools_json = [to_json(t) for t in tools]
124
+
125
+ return TOOLS_SYSTEM_TEMPLATE.format(
126
+ tool_schemas="\n".join(tools_json),
127
+ dsml_token=dsml_token,
128
+ thinking_start_token=thinking_start_token,
129
+ thinking_end_token=thinking_end_token,
130
+ )
131
+
132
+ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
133
+ last_user_index = -1
134
+ for idx in range(len(messages)-1, -1, -1):
135
+ if messages[idx].get("role") in ["user", "developer"]:
136
+ last_user_index = idx
137
+ break
138
+ return last_user_index
139
+
140
+ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str) -> str:
141
+ assert 0 <= index < len(messages)
142
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
143
+
144
+ prompt = ""
145
+ msg = messages[index]
146
+ last_user_idx = find_last_user_index(messages)
147
+
148
+ role = msg.get("role")
149
+ content = msg.get("content")
150
+ tools = msg.get("tools")
151
+ response_format = msg.get("response_format")
152
+ tool_calls = msg.get("tool_calls")
153
+ reasoning_content = msg.get("reasoning_content")
154
+
155
+ if tools:
156
+ tools = tools_from_openai_format(tools)
157
+ if tool_calls:
158
+ tool_calls = tool_calls_from_openai_format(tool_calls)
159
+
160
+ if role == "system":
161
+ prompt += system_msg_template.format(content=content or "")
162
+ if tools:
163
+ prompt += "\n\n" + render_tools(tools)
164
+
165
+ if response_format:
166
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
167
+
168
+ elif role == "developer":
169
+ assert content, f"Invalid message for role `{role}`: {msg}"
170
+ content_developer = ""
171
+ if tools:
172
+ content_developer += "\n\n" + render_tools(tools)
173
+
174
+ if response_format:
175
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
176
+
177
+ content_developer += "\n\n# The user's message is: {}".format(content)
178
+
179
+ prompt += user_msg_template.format(content=content_developer)
180
+ if index == last_user_idx and thinking_mode == "thinking":
181
+ prompt += thinking_start_token
182
+ else:
183
+ prompt += thinking_end_token
184
+
185
+ elif role == "user":
186
+ prompt += user_msg_template.format(content=content)
187
+
188
+ if index == last_user_idx and thinking_mode == "thinking":
189
+ prompt += thinking_start_token
190
+ else:
191
+ prompt += thinking_end_token
192
+
193
+ elif role == "tool":
194
+ prev_assistant_idx = index - 1
195
+ assistant_msg = messages[prev_assistant_idx]
196
+ while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool":
197
+ prev_assistant_idx -= 1
198
+ assistant_msg = messages[prev_assistant_idx]
199
+
200
+ assert index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant", f"Invalid messages at {index}:\n{assistant_msg}"
201
+
202
+ tool_call_order = index - prev_assistant_idx
203
+ assistant_tool_calls = assistant_msg.get("tool_calls")
204
+ assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, "No tool calls but found tool output"
205
+
206
+ if tool_call_order == 1:
207
+ prompt += "\n\n<function_results>"
208
+
209
+ prompt += tool_output_template.format(content=content)
210
+
211
+ if tool_call_order == len(assistant_tool_calls):
212
+ prompt += "\n</function_results>"
213
+
214
+ if index >= last_user_idx and thinking_mode == "thinking":
215
+ prompt += "\n\n" + thinking_start_token
216
+ else:
217
+ prompt += "\n\n" + thinking_end_token
218
+
219
+ elif role == "assistant":
220
+ prev_assistant_idx = index
221
+ thinking_part = ""
222
+
223
+ tool_calls_content = ""
224
+ if tool_calls:
225
+ tool_calls = [
226
+ tool_call_template.format(
227
+ dsml_token=dsml_token,
228
+ name=tool_call.get("name"),
229
+ arguments=encode_arguments_to_dsml(tool_call)
230
+ )
231
+ for tool_call in tool_calls
232
+ ]
233
+ tool_calls_content += "\n\n" + tool_calls_template.format(
234
+ dsml_token=dsml_token,
235
+ tool_calls="\n".join(tool_calls)
236
+ )
237
+
238
+ summary_content = content or ""
239
+
240
+ if thinking_mode == "thinking" and index > last_user_idx:
241
+ assert reasoning_content or tool_calls, f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
242
+ thinking_part = thinking_template.format(reasoning_content=reasoning_content or "") + thinking_end_token
243
+
244
+ prompt += assistant_msg_template.format(
245
+ reasoning=thinking_part,
246
+ content=summary_content,
247
+ tool_calls=tool_calls_content,
248
+ )
249
+ else:
250
+ raise NotImplementedError(f"Unknown role: {role}")
251
+
252
+ return prompt
253
+
254
+ def drop_thinking_messages(messages: List[Dict[str, Any]], last_user_idx: Optional[int]=None) -> List[Dict[str, Any]]:
255
+ messages_wo_thinking: List[Dict[str, Any]] = []
256
+ last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx
257
+ for idx, msg in enumerate(messages):
258
+ role = msg.get("role")
259
+ if role in ["user", "system", "tool"] or idx >= last_user_idx:
260
+ messages_wo_thinking.append(msg)
261
+ continue
262
+
263
+ elif role == "assistant":
264
+ msg_wo_thinking = copy.copy(msg)
265
+ msg_wo_thinking.pop("reasoning_content", None)
266
+ messages_wo_thinking.append(msg_wo_thinking)
267
+
268
+ return messages_wo_thinking
269
+
270
+ def encode_messages(messages: List[Dict[str, Any]], thinking_mode: str, context: Optional[List[Dict[str, Any]]] = None, drop_thinking: bool = True, add_default_bos_token: bool = True) -> str:
271
+ context = context if context else []
272
+ full_messages = context + messages
273
+
274
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
275
+
276
+ if thinking_mode == "thinking" and drop_thinking:
277
+ full_messages = drop_thinking_messages(full_messages)
278
+
279
+ for idx in range(len(messages)):
280
+ prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode)
281
+
282
+ return prompt
283
+
284
+ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
285
+ min_pos = len(text)
286
+ matched_stop = None
287
+
288
+ for s in stop:
289
+ pos = text.find(s, index)
290
+ if pos != -1 and pos < min_pos:
291
+ min_pos = pos
292
+ matched_stop = s
293
+
294
+ if matched_stop:
295
+ content = text[index:min_pos]
296
+ return min_pos + len(matched_stop), content, matched_stop
297
+ else:
298
+ content = text[index:]
299
+ return len(text), content, None
300
+
301
+ def parse_tool_calls(index: int, text: str):
302
+ tool_calls: List[Dict[str, Any]] = []
303
+ stop_token = None
304
+ tool_calls_end_token = f"</{dsml_token}function_calls>"
305
+
306
+ while index < len(text):
307
+ index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
308
+ assert _ == ">\n", "Tool call format error"
309
+
310
+ if stop_token == tool_calls_end_token:
311
+ break
312
+
313
+ assert stop_token is not None, "Missing special token"
314
+
315
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
316
+
317
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
318
+ assert len(p_tool_name) == 1, "Tool name format error"
319
+ tool_name = p_tool_name[0]
320
+
321
+ tool_args: Dict[str, Tuple[str, str]] = {}
322
+ while stop_token == f"<{dsml_token}parameter":
323
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
324
+
325
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
326
+ assert len(param_kv) == 1, "Parameter format error"
327
+ param_name, string, param_value = param_kv[0]
328
+
329
+ assert param_name not in tool_args, "Duplicate parameter name"
330
+ tool_args[param_name] = (param_value, string)
331
+
332
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
333
+ assert content == ">\n", "Parameter format error"
334
+
335
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
336
+ tool_calls.append(tool_call)
337
+
338
+ return index, stop_token, tool_calls
339
+
340
+ # NOTE: This function is designed to parse only correctly formatted string and will not attempt to correct malformed output that may be generated by the model.
341
+ def parse_message_from_completion_text(text: str, thinking_mode: str):
342
+ summary_content, reasoning_content, tool_calls = "", "", []
343
+ index, stop_token = 0, None
344
+ tool_calls_start_token = f"\n\n<{dsml_token}function_calls"
345
+
346
+ is_thinking, is_tool_calling = thinking_mode == "thinking", False
347
+
348
+ if is_thinking:
349
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
350
+ reasoning_content = content_delta
351
+ assert stop_token == thinking_end_token, "Invalid thinking format"
352
+
353
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
354
+ summary_content = content_delta
355
+ if stop_token == tool_calls_start_token:
356
+ is_tool_calling = True
357
+ else:
358
+ assert stop_token == eos_token, "Invalid summary format"
359
+
360
+ if is_tool_calling:
361
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
362
+
363
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
364
+ assert not tool_ends_text, "Unexpected content after tool calls"
365
+
366
+ assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
367
+
368
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
369
+ assert sp_token not in summary_content and sp_token not in reasoning_content, "Unexpected special token in content"
370
+
371
+ return {
372
+ "role": "assistant",
373
+ "content": summary_content,
374
+ "reasoning_content": reasoning_content,
375
+ "tool_calls": tool_calls_to_openai_format(tool_calls)
376
+ }