Upload 32 files
- ASRDataset.py +56 -9
- README.md +536 -0
- chat_template.json +1 -3
- config.json +1 -1
- deploy/__pycache__/main.cpython-310.pyc +0 -0
- deploy/log.txt +0 -0
- deploy/main.py +198 -0
- deploy/testapi.ipynb +172 -0
- eval.py +22 -23
- eval_multiturn_textonly.py +131 -0
- generation_config.json +13 -0
- model-00001-of-00003.safetensors +2 -2
- model-00002-of-00003.safetensors +2 -2
- model-00003-of-00003.safetensors +2 -2
- model.safetensors.index.json +0 -0
- preprocessing_gemma3omni.py +1 -13
- run.sh +4 -0
- template.json +1 -0
- test.ipynb +977 -0
- test.jinja +51 -0
- tmp.sh +3 -0
- training.py +60 -128
ASRDataset.py
CHANGED
@@ -427,6 +427,47 @@ class TWCostumData(BaseAudioDataset):
             self.instruction,
             answer_text
         )
+
+class TWCostumDataTasks(BaseAudioDataset):
+
+    def __init__(self, processor, split="train", sampling_rate=16000, json_path="", debug=False):
+        super().__init__(processor, split, sampling_rate, debug)
+        import pandas as pd
+        from datasets import Dataset, Audio
+
+        with open(json_path) as f:
+            js_data = json.load(f)
+
+        raw_data = {
+            "audio": [],
+            "sentence": []
+        }
+        for conv in js_data:
+            for mess in conv['conversations']:
+                if 'audio_path' in mess:
+                    raw_data['audio'].append(mess['audio_path'])
+                    raw_data['sentence'].append(mess["value"])
+
+
+        self.set_dataset_name("TWCostumDataTasks"+json_path)
+        self.data = Dataset.from_dict(raw_data).cast_column("audio", Audio(sampling_rate=16000))
+
+        # Instruction Setting
+        self.instruction = random.choice(INSTRUCTION["asr"])
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data = self.data[idx]
+
+        answer_text = data["sentence"]
+        return self.prepare_model_inputs(
+            data["audio"]["array"],
+            self.instruction,
+            answer_text
+        )
+
 def covost_collate_fn(batch):
     input_ids_list = []
     labels_list = []

@@ -556,11 +597,12 @@ class MultiturnAudioDataset(BaseAudioDataset):
         self.text_only=text_only
         with open(json_path) as f:
             js_data = json.load(f)
+        test_len = min(len(js_data)*0.2,200)
         if split=='train':
             self.train=True
-            js_data = js_data[:int(len(js_data)…
+            js_data = js_data[:int(len(js_data)-test_len)]
         else:
-            js_data = js_data[-…
+            js_data = js_data[-test_len:]
         for conv in js_data:
             for mess in conv['conversations']:
                 if 'audio_path' in mess:

@@ -570,7 +612,7 @@ class MultiturnAudioDataset(BaseAudioDataset):
             format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
             format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
             format_system=StringFormatter(slots=["{{content}}\n\n"]),
-            format_function=FunctionFormatter(slots=["{{content}}"…
+            format_function=FunctionFormatter(slots=["{{content}}<end_of_turn>\n"], tool_format="default"),
             format_tools = ToolFormatter(tool_format="default"),
             format_observation=StringFormatter(
                 slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]

@@ -583,6 +625,7 @@ class MultiturnAudioDataset(BaseAudioDataset):
             format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
             stop_words=["<end_of_turn>"],
             mm_plugin=get_mm_plugin(name="base"),
+            enable_thinking=False
         )

         self.set_dataset_name(f"MultiturnCostumData")

@@ -603,9 +646,10 @@ class MultiturnAudioDataset(BaseAudioDataset):
                 if mess['from']=='human':
                     tmp['messages'].append(conv['conversations'][i+1])
                     d = deepcopy(tmp)
-                    …
-                    …
-                    …
+                    if not self.text_only and 'audio_path' in mess:
+                        d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
+                        self.data.append(d)
+                    else:
                         self.text_only_data.append(deepcopy(tmp))
                     tmp['messages'].pop()
                 elif mess['from']=='observation':

@@ -632,14 +676,16 @@ class MultiturnAudioDataset(BaseAudioDataset):
         if system or tools:
             tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
             system_text = self.template.format_system.apply(content=(system + tool_text))[0]
+            elements += system_text

         if message["from"] == "human":
             if i==len(messages)-2 and not self.text_only:
                 user_transcribe = message["value"]
-                elements += self.template.format_user.apply(content=…
+                elements += self.template.format_user.apply(content='<start_of_audio>')
             else:
-                elements += self.template.format_user.apply(content=…
-                …
+                elements += self.template.format_user.apply(content=message["value"])
+                if not self.text_only:
+                    audio_paths.append(message['audio_path'])
         elif message["from"] == "gpt":
             elements += self.template.format_assistant.apply(content=message["value"])
         elif message["from"] == "observation":

@@ -713,6 +759,7 @@ class MultiturnAudioDataset(BaseAudioDataset):
                 'input_audio_embeds': inputs.input_audio_embeds,
                 'audio_embed_sizes': inputs.audio_embed_sizes,
                 'input_modes': inputs.input_modes,
+
             }
         else:
             return {
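For reference, a minimal sketch of the JSON layout that the new `TWCostumDataTasks` loader consumes, inferred from the loop in its `__init__` above; the file path and sample turns are hypothetical:

```python
# Hypothetical example of the expected JSON structure: a list of conversations,
# where audio turns carry an 'audio_path' and their transcript in 'value'.
# Only turns that have an 'audio_path' key are collected by the loader.
import json

example = [
    {
        "conversations": [
            {"from": "human", "value": "請問現在幾點", "audio_path": "/data/audio/0001.wav"},
            {"from": "gpt", "value": "現在是下午三點"},
        ]
    }
]

with open("asr_tasks.json", "w") as f:
    json.dump(example, f, ensure_ascii=False)

# dataset = TWCostumDataTasks(processor, split="train", json_path="asr_tasks.json")
```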
README.md
ADDED
@@ -0,0 +1,536 @@
---
license: gemma
library_name: transformers
pipeline_tag: image-text-to-text
extra_gated_heading: Access Gemma on Hugging Face
extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and
  agree to Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
base_model: google/gemma-3-4b-pt
---

# Gemma 3 model card

**Model Page**: [Gemma](https://ai.google.dev/gemma/docs/core)

**Resources and Technical Documentation**:

* [Gemma 3 Technical Report][g3-tech-report]
* [Responsible Generative AI Toolkit][rai-toolkit]
* [Gemma on Kaggle][kaggle-gemma]
* [Gemma on Vertex Model Garden][vertex-mg-gemma3]

**Terms of Use**: [Terms][terms]

**Authors**: Google DeepMind

## Model Information

Summary description and brief definition of inputs and outputs.

### Description

Gemma is a family of lightweight, state-of-the-art open models from Google,
built from the same research and technology used to create the Gemini models.
Gemma 3 models are multimodal, handling text and image input and generating text
output, with open weights for both pre-trained variants and instruction-tuned
variants. Gemma 3 has a large, 128K context window, multilingual support in over
140 languages, and is available in more sizes than previous versions. Gemma 3
models are well-suited for a variety of text generation and image understanding
tasks, including question answering, summarization, and reasoning. Their
relatively small size makes it possible to deploy them in environments with
limited resources such as laptops, desktops or your own cloud infrastructure,
democratizing access to state-of-the-art AI models and helping foster innovation
for everyone.

### Inputs and outputs

- **Input:**
    - Text string, such as a question, a prompt, or a document to be summarized
    - Images, normalized to 896 x 896 resolution and encoded to 256 tokens
      each
    - Total input context of 128K tokens for the 4B, 12B, and 27B sizes, and
      32K tokens for the 1B size

- **Output:**
    - Generated text in response to the input, such as an answer to a
      question, analysis of image content, or a summary of a document
    - Total output context of 8192 tokens

### Usage

Below are some code snippets to help you quickly get started running the model. First, install the Transformers library. Gemma 3 is supported starting from transformers 4.50.0.

```sh
$ pip install -U transformers
```

Then, copy the snippet from the section that is relevant for your use case.

#### Running with the `pipeline` API

You can initialize the model and processor for inference with `pipeline` as follows.

```python
from transformers import pipeline
import torch

pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device="cuda",
    torch_dtype=torch.bfloat16
)
```

With instruction-tuned models, you need to use chat templates to process your inputs first. Then, you can pass them to the pipeline.

```python
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    }
]

output = pipe(text=messages, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
# Okay, let's take a look!
# Based on the image, the animal on the candy is a **turtle**.
# You can see the shell shape and the head and legs.
```

#### Running the model on a single/multi GPU

```python
# pip install accelerate

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "google/gemma-3-4b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
```

### Citation

```none
@article{gemma_2025,
    title={Gemma 3},
    url={https://goo.gle/Gemma3Report},
    publisher={Kaggle},
    author={Gemma Team},
    year={2025}
}
```

## Model Data

Data used for model training and how the data was processed.

### Training Dataset

These models were trained on a dataset of text data that includes a wide variety
of sources. The 27B model was trained with 14 trillion tokens, the 12B model was
trained with 12 trillion tokens, the 4B model with 4 trillion tokens, and the 1B
with 2 trillion tokens. Here are the key components:

- Web Documents: A diverse collection of web text ensures the model is
  exposed to a broad range of linguistic styles, topics, and vocabulary. The
  training dataset includes content in over 140 languages.
- Code: Exposing the model to code helps it to learn the syntax and
  patterns of programming languages, which improves its ability to generate
  code and understand code-related questions.
- Mathematics: Training on mathematical text helps the model learn logical
  reasoning, symbolic representation, and to address mathematical queries.
- Images: A wide range of images enables the model to perform image
  analysis and visual data extraction tasks.

The combination of these diverse data sources is crucial for training a powerful
multimodal model that can handle a wide variety of different tasks and data
formats.

### Data Preprocessing

Here are the key data cleaning and filtering methods applied to the training
data:

- CSAM Filtering: Rigorous CSAM (Child Sexual Abuse Material) filtering
  was applied at multiple stages in the data preparation process to ensure
  the exclusion of harmful and illegal content.
- Sensitive Data Filtering: As part of making Gemma pre-trained models
  safe and reliable, automated techniques were used to filter out certain
  personal information and other sensitive data from training sets.
- Additional methods: Filtering based on content quality and safety in
  line with [our policies][safety-policies].

## Implementation Information

Details about the model internals.

### Hardware

Gemma was trained using [Tensor Processing Unit (TPU)][tpu] hardware (TPUv4p,
TPUv5p and TPUv5e). Training vision-language models (VLMs) requires significant
computational power. TPUs, designed specifically for matrix operations common in
machine learning, offer several advantages in this domain:

- Performance: TPUs are specifically designed to handle the massive
  computations involved in training VLMs. They can speed up training
  considerably compared to CPUs.
- Memory: TPUs often come with large amounts of high-bandwidth memory,
  allowing for the handling of large models and batch sizes during training.
  This can lead to better model quality.
- Scalability: TPU Pods (large clusters of TPUs) provide a scalable
  solution for handling the growing complexity of large foundation models.
  You can distribute training across multiple TPU devices for faster and more
  efficient processing.
- Cost-effectiveness: In many scenarios, TPUs can provide a more
  cost-effective solution for training large models compared to CPU-based
  infrastructure, especially when considering the time and resources saved
  due to faster training.
- These advantages are aligned with
  [Google's commitments to operate sustainably][sustainability].

### Software

Training was done using [JAX][jax] and [ML Pathways][ml-pathways].

JAX allows researchers to take advantage of the latest generation of hardware,
including TPUs, for faster and more efficient training of large models. ML
Pathways is Google's latest effort to build artificially intelligent systems
capable of generalizing across multiple tasks. This is especially suitable for
foundation models, including large language models like these ones.

Together, JAX and ML Pathways are used as described in the
[paper about the Gemini family of models][gemini-2-paper]; *"the 'single
controller' programming model of Jax and Pathways allows a single Python
process to orchestrate the entire training run, dramatically simplifying the
development workflow."*

## Evaluation

Model evaluation metrics and results.

### Benchmark Results

These models were evaluated against a large collection of different datasets and
metrics to cover different aspects of text generation:

#### Reasoning and factuality

| Benchmark                      | Metric         | Gemma 3 PT 1B   | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
| ------------------------------ |----------------|:--------------:|:-------------:|:--------------:|:--------------:|
| [HellaSwag][hellaswag]         | 10-shot        | 62.3            | 77.2          | 84.2           | 85.6           |
| [BoolQ][boolq]                 | 0-shot         | 63.2            | 72.3          | 78.8           | 82.4           |
| [PIQA][piqa]                   | 0-shot         | 73.8            | 79.6          | 81.8           | 83.3           |
| [SocialIQA][socialiqa]         | 0-shot         | 48.9            | 51.9          | 53.4           | 54.9           |
| [TriviaQA][triviaqa]           | 5-shot         | 39.8            | 65.8          | 78.2           | 85.5           |
| [Natural Questions][naturalq]  | 5-shot         | 9.48            | 20.0          | 31.4           | 36.1           |
| [ARC-c][arc]                   | 25-shot        | 38.4            | 56.2          | 68.9           | 70.6           |
| [ARC-e][arc]                   | 0-shot         | 73.0            | 82.4          | 88.3           | 89.0           |
| [WinoGrande][winogrande]       | 5-shot         | 58.2            | 64.7          | 74.3           | 78.8           |
| [BIG-Bench Hard][bbh]          | few-shot       | 28.4            | 50.9          | 72.6           | 77.7           |
| [DROP][drop]                   | 1-shot         | 42.4            | 60.1          | 72.2           | 77.2           |

[hellaswag]: https://arxiv.org/abs/1905.07830
[boolq]: https://arxiv.org/abs/1905.10044
[piqa]: https://arxiv.org/abs/1911.11641
[socialiqa]: https://arxiv.org/abs/1904.09728
[triviaqa]: https://arxiv.org/abs/1705.03551
[naturalq]: https://github.com/google-research-datasets/natural-questions
[arc]: https://arxiv.org/abs/1911.01547
[winogrande]: https://arxiv.org/abs/1907.10641
[bbh]: https://paperswithcode.com/dataset/bbh
[drop]: https://arxiv.org/abs/1903.00161

#### STEM and code

| Benchmark                      | Metric         | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
| ------------------------------ |----------------|:-------------:|:--------------:|:--------------:|
| [MMLU][mmlu]                   | 5-shot         | 59.6          | 74.5           | 78.6           |
| [MMLU][mmlu] (Pro COT)         | 5-shot         | 29.2          | 45.3           | 52.2           |
| [AGIEval][agieval]             | 3-5-shot       | 42.1          | 57.4           | 66.2           |
| [MATH][math]                   | 4-shot         | 24.2          | 43.3           | 50.0           |
| [GSM8K][gsm8k]                 | 8-shot         | 38.4          | 71.0           | 82.6           |
| [GPQA][gpqa]                   | 5-shot         | 15.0          | 25.4           | 24.3           |
| [MBPP][mbpp]                   | 3-shot         | 46.0          | 60.4           | 65.6           |
| [HumanEval][humaneval]         | 0-shot         | 36.0          | 45.7           | 48.8           |

[mmlu]: https://arxiv.org/abs/2009.03300
[agieval]: https://arxiv.org/abs/2304.06364
[math]: https://arxiv.org/abs/2103.03874
[gsm8k]: https://arxiv.org/abs/2110.14168
[gpqa]: https://arxiv.org/abs/2311.12022
[mbpp]: https://arxiv.org/abs/2108.07732
[humaneval]: https://arxiv.org/abs/2107.03374

#### Multilingual

| Benchmark                            | Gemma 3 PT 1B | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
| ------------------------------------ |:-------------:|:-------------:|:--------------:|:--------------:|
| [MGSM][mgsm]                         | 2.04          | 34.7          | 64.3           | 74.3           |
| [Global-MMLU-Lite][global-mmlu-lite] | 24.9          | 57.0          | 69.4           | 75.7           |
| [WMT24++][wmt24pp] (ChrF)            | 36.7          | 48.4          | 53.9           | 55.7           |
| [FloRes][flores]                     | 29.5          | 39.2          | 46.0           | 48.8           |
| [XQuAD][xquad] (all)                 | 43.9          | 68.0          | 74.5           | 76.8           |
| [ECLeKTic][eclektic]                 | 4.69          | 11.0          | 17.2           | 24.4           |
| [IndicGenBench][indicgenbench]       | 41.4          | 57.2          | 61.7           | 63.4           |

[mgsm]: https://arxiv.org/abs/2210.03057
[flores]: https://arxiv.org/abs/2106.03193
[xquad]: https://arxiv.org/abs/1910.11856v3
[global-mmlu-lite]: https://huggingface.co/datasets/CohereForAI/Global-MMLU-Lite
[wmt24pp]: https://arxiv.org/abs/2502.12404v1
[eclektic]: https://arxiv.org/abs/2502.21228
[indicgenbench]: https://arxiv.org/abs/2404.16816

#### Multimodal

| Benchmark                      | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
| ------------------------------ |:-------------:|:--------------:|:--------------:|
| [COCOcap][coco-cap]            | 102           | 111            | 116            |
| [DocVQA][docvqa] (val)         | 72.8          | 82.3           | 85.6           |
| [InfoVQA][info-vqa] (val)      | 44.1          | 54.8           | 59.4           |
| [MMMU][mmmu] (pt)              | 39.2          | 50.3           | 56.1           |
| [TextVQA][textvqa] (val)       | 58.9          | 66.5           | 68.6           |
| [RealWorldQA][realworldqa]     | 45.5          | 52.2           | 53.9           |
| [ReMI][remi]                   | 27.3          | 38.5           | 44.8           |
| [AI2D][ai2d]                   | 63.2          | 75.2           | 79.0           |
| [ChartQA][chartqa]             | 63.6          | 74.7           | 76.3           |
| [VQAv2][vqav2]                 | 63.9          | 71.2           | 72.9           |
| [BLINK][blinkvqa]              | 38.0          | 35.9           | 39.6           |
| [OKVQA][okvqa]                 | 51.0          | 58.7           | 60.2           |
| [TallyQA][tallyqa]             | 42.5          | 51.8           | 54.3           |
| [SpatialSense VQA][ss-vqa]     | 50.9          | 60.0           | 59.4           |
| [CountBenchQA][countbenchqa]   | 26.1          | 17.8           | 68.0           |

[coco-cap]: https://cocodataset.org/#home
[docvqa]: https://www.docvqa.org/
[info-vqa]: https://arxiv.org/abs/2104.12756
[mmmu]: https://arxiv.org/abs/2311.16502
[textvqa]: https://textvqa.org/
[realworldqa]: https://paperswithcode.com/dataset/realworldqa
[remi]: https://arxiv.org/html/2406.09175v1
[ai2d]: https://allenai.org/data/diagrams
[chartqa]: https://arxiv.org/abs/2203.10244
[vqav2]: https://visualqa.org/index.html
[blinkvqa]: https://arxiv.org/abs/2404.12390
[okvqa]: https://okvqa.allenai.org/
[tallyqa]: https://arxiv.org/abs/1810.12440
[ss-vqa]: https://arxiv.org/abs/1908.02660
[countbenchqa]: https://github.com/google-research/big_vision/blob/main/big_vision/datasets/countbenchqa/

## Ethics and Safety

Ethics and safety evaluation approach and results.

### Evaluation Approach

Our evaluation methods include structured evaluations and internal red-teaming
testing of relevant content policies. Red-teaming was conducted by a number of
different teams, each with different goals and human evaluation metrics. These
models were evaluated against a number of different categories relevant to
ethics and safety, including:

- **Child Safety**: Evaluation of text-to-text and image-to-text prompts
  covering child safety policies, including child sexual abuse and
  exploitation.
- **Content Safety:** Evaluation of text-to-text and image-to-text prompts
  covering safety policies, including harassment, violence and gore, and hate
  speech.
- **Representational Harms**: Evaluation of text-to-text and image-to-text
  prompts covering safety policies, including bias, stereotyping, and harmful
  associations or inaccuracies.

In addition to development level evaluations, we conduct "assurance
evaluations" which are our 'arms-length' internal evaluations for responsibility
governance decision making. They are conducted separately from the model
development team, to inform decision making about release. High level findings
are fed back to the model team, but prompt sets are held-out to prevent
overfitting and preserve the results' ability to inform decision making.
Assurance evaluation results are reported to our Responsibility & Safety Council
as part of release review.

### Evaluation Results

For all areas of safety testing, we saw major improvements in the categories of
child safety, content safety, and representational harms relative to previous
Gemma models. All testing was conducted without safety filters to evaluate the
model capabilities and behaviors. For both text-to-text and image-to-text, and
across all model sizes, the model produced minimal policy violations, and showed
significant improvements over previous Gemma models' performance with respect
to ungrounded inferences. A limitation of our evaluations was that they included
only English-language prompts.

## Usage and Limitations

These models have certain limitations that users should be aware of.

### Intended Usage

Open vision-language models (VLMs) have a wide range of applications
across various industries and domains. The following list of potential uses is
not comprehensive. The purpose of this list is to provide contextual information
about the possible use-cases that the model creators considered as part of model
training and development.

- Content Creation and Communication
    - Text Generation: These models can be used to generate creative text
      formats such as poems, scripts, code, marketing copy, and email drafts.
    - Chatbots and Conversational AI: Power conversational interfaces
      for customer service, virtual assistants, or interactive applications.
    - Text Summarization: Generate concise summaries of a text corpus,
      research papers, or reports.
    - Image Data Extraction: These models can be used to extract,
      interpret, and summarize visual data for text communications.
- Research and Education
    - Natural Language Processing (NLP) and VLM Research: These
      models can serve as a foundation for researchers to experiment with VLM
      and NLP techniques, develop algorithms, and contribute to the
      advancement of the field.
    - Language Learning Tools: Support interactive language learning
      experiences, aiding in grammar correction or providing writing practice.
    - Knowledge Exploration: Assist researchers in exploring large
      bodies of text by generating summaries or answering questions about
      specific topics.

### Limitations

- Training Data
    - The quality and diversity of the training data significantly
      influence the model's capabilities. Biases or gaps in the training data
      can lead to limitations in the model's responses.
    - The scope of the training dataset determines the subject areas
      the model can handle effectively.
- Context and Task Complexity
    - Models are better at tasks that can be framed with clear
      prompts and instructions. Open-ended or highly complex tasks might be
      challenging.
    - A model's performance can be influenced by the amount of context
      provided (longer context generally leads to better outputs, up to a
      certain point).
- Language Ambiguity and Nuance
    - Natural language is inherently complex. Models might struggle
      to grasp subtle nuances, sarcasm, or figurative language.
- Factual Accuracy
    - Models generate responses based on information they learned
      from their training datasets, but they are not knowledge bases. They
      may generate incorrect or outdated factual statements.
- Common Sense
    - Models rely on statistical patterns in language. They might
      lack the ability to apply common sense reasoning in certain situations.

### Ethical Considerations and Risks

The development of vision-language models (VLMs) raises several ethical
concerns. In creating an open model, we have carefully considered the following:

- Bias and Fairness
    - VLMs trained on large-scale, real-world text and image data can
      reflect socio-cultural biases embedded in the training material. These
      models underwent careful scrutiny, with input data pre-processing
      described and posterior evaluations reported in this card.
- Misinformation and Misuse
    - VLMs can be misused to generate text that is false, misleading,
      or harmful.
    - Guidelines are provided for responsible use with the model, see the
      [Responsible Generative AI Toolkit][rai-toolkit].
- Transparency and Accountability:
    - This model card summarizes details on the models' architecture,
      capabilities, limitations, and evaluation processes.
    - A responsibly developed open model offers the opportunity to
      share innovation by making VLM technology accessible to developers and
      researchers across the AI ecosystem.

Risks identified and mitigations:

- **Perpetuation of biases**: It's encouraged to perform continuous
  monitoring (using evaluation metrics, human review) and the exploration of
  de-biasing techniques during model training, fine-tuning, and other use
  cases.
- **Generation of harmful content**: Mechanisms and guidelines for content
  safety are essential. Developers are encouraged to exercise caution and
  implement appropriate content safety safeguards based on their specific
  product policies and application use cases.
- **Misuse for malicious purposes**: Technical limitations and developer
  and end-user education can help mitigate against malicious applications of
  VLMs. Educational resources and reporting mechanisms for users to flag
  misuse are provided. Prohibited uses of Gemma models are outlined in the
  [Gemma Prohibited Use Policy][prohibited-use].
- **Privacy violations**: Models were trained on data filtered for removal
  of certain personal information and other sensitive data. Developers are
  encouraged to adhere to privacy regulations with privacy-preserving
  techniques.

### Benefits

At the time of release, this family of models provides high-performance open
vision-language model implementations designed from the ground up for
responsible AI development compared to similarly sized models.

Using the benchmark evaluation metrics described in this document, these models
have been shown to provide superior performance to other, comparably-sized open
model alternatives.

[g3-tech-report]: https://goo.gle/Gemma3Report
[rai-toolkit]: https://ai.google.dev/responsible
[kaggle-gemma]: https://www.kaggle.com/models/google/gemma-3
[vertex-mg-gemma3]: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemma3
[terms]: https://ai.google.dev/gemma/terms
[safety-policies]: https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf
[prohibited-use]: https://ai.google.dev/gemma/prohibited_use_policy
[tpu]: https://cloud.google.com/tpu/docs/intro-to-tpu
[sustainability]: https://sustainability.google/operating-sustainably/
[jax]: https://github.com/jax-ml/jax
[ml-pathways]: https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/
[gemini-2-paper]: https://arxiv.org/abs/2312.11805
chat_template.json
CHANGED
@@ -1,3 +1 @@
-{
-"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'audio' -%}\n                {{ '<start_of_audio>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
-}
+{"chat_template": "{{ bos_token }}\n{%- if tools %}\n    {% set tools_json = tools %}\n    You have access to the following tools:\n    {% for tool in tools_json %}\n    > Tool Name: {{ tool.name }}\n    Tool Description: {{ tool.description }}\n    Tool Args:\n    {% if tool.parameters.properties %}\n    {% for prop_name, prop_data in tool.parameters.properties.items() %}\n    - {{ prop_name }} ({{ prop_data.type }}, {% if prop_name in tool.parameters.required %}required{% else %}optional{% endif %}): {{ prop_data.description }}\n    {% endfor %}\n    {% else %}\n    {% endif %}\n\n    {% endfor %}, where each item should be object\n    Use the following format if using a tool:\n    ```\n    Action: tool name (one of [{{ tools_json | map(attribute='name') | join(', ') }}])\n    Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{\"input\": \"hello world\", \"num_beams\": 5}```)\n    ```\n    {{'\n'}}\n{%- endif -%}\n\n\n{%- for message in messages -%}\n    {%- if message['from'] == \"human\" -%}\n        {%- if message['value'] is string -%}\n            {{'<start_of_turn>user\n' + message['value'] + '<end_of_turn>\n'}}\n        {%- elif message['value'] is iterable -%}\n            {{'<start_of_turn>user\n'}}\n            {%- for item in message['value'] -%}\n                {%- if item['type'] == 'image' -%}\n                    {{ '<start_of_image>' }}\n                {%- elif item['type'] == 'audio' -%}\n                    {{ '<start_of_audio>' }}\n                {%- elif item['type'] == 'text' -%}\n                    {{ item['text'] | trim }}\n                {%- endif -%}\n                {{'<end_of_turn>\n'}}\n            {%- endfor -%}\n        {%- endif -%}\n    {%- elif message['from'] == \"gpt\" -%}\n        {{'<start_of_turn>model\n' + message['value'] + '<end_of_turn>\n'}} \n    {%- elif message['from'] == \"function_call\" -%}\n        {%- set func_call = message['value'] -%}\n        {{'<start_of_turn>model\n' + 'Action:' + func_call.name + '\n' + 'Action Input:' + func_call.arguments + '<end_of_turn>\n'}} \n    {%- elif message['from'] == \"observation\" -%}\n        {{'<start_of_turn>tool\n' + message['value'] + '<end_of_turn>\n'}} \n    {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"}
config.json
CHANGED
@@ -70,7 +70,7 @@
         "layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
         "lora_alpha": 16,
         "r": 8,
-        "use_rslora": …
+        "use_rslora": false
     },
     "text_config": {
         "attention_bias": false,
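As background on the flag being pinned here (general context, not stated in the commit): standard LoRA scales the adapter update by α/r, while rank-stabilized LoRA (rsLoRA) scales it by α/√r. With the values above:

```latex
\gamma_{\text{LoRA}} = \frac{\alpha}{r} = \frac{16}{8} = 2,
\qquad
\gamma_{\text{rsLoRA}} = \frac{\alpha}{\sqrt{r}} = \frac{16}{\sqrt{8}} \approx 5.66
```

So with `"use_rslora": false`, the adapter contribution is multiplied by 2 rather than roughly 5.66.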
deploy/__pycache__/main.cpython-310.pyc
ADDED
Binary file (6.94 kB).
deploy/log.txt
ADDED
The diff for this file is too large to render.
deploy/main.py
ADDED
@@ -0,0 +1,198 @@
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from typing import List, Dict, Any, Optional
import json
import os
from transformers import AutoProcessor, AutoModel
import torch, torchaudio
import os

import copy
from rapidfuzz import process, fuzz
from pypinyin import pinyin, Style

def correct_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=50):
    pinyin_dict = {}
    for location in location_dict:
        pinyin_name = ''.join([item[0] for item in pinyin(location, style=Style.NORMAL)])
        pinyin_dict[pinyin_name] = location

    user_pinyin_sentence = ''.join([item[0] for item in pinyin(user_input_sentence, style=Style.NORMAL)])

    best_match_pinyin = process.extractOne(
        query=user_pinyin_sentence,
        choices=list(pinyin_dict.keys()),  # match against the pinyin forms
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff
    )

    if best_match_pinyin:
        best_pinyin_name = best_match_pinyin[0]
        corrected_location_name = pinyin_dict[best_pinyin_name]

        best_user_substring = None
        max_substring_score = 0

        for i in range(len(user_input_sentence)):
            for j in range(i + 2, min(i + 16, len(user_input_sentence) + 1)):
                substring = user_input_sentence[i:j]

                score = fuzz.ratio(substring, corrected_location_name)

                if score > max_substring_score:
                    max_substring_score = score
                    best_user_substring = substring

        if best_user_substring and max_substring_score > score_cutoff:
            return user_input_sentence.replace(best_user_substring, corrected_location_name, 1)
        else:
            return user_input_sentence
    return user_input_sentence

class InferenceClass:
    def __init__(self,model_id):
        self.model = AutoModel.from_pretrained(
            model_id, device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="eager"
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )
        self.remove_words_signs = lambda x:x.replace('User transcribe is :','').replace('GPT output is :','').replace('\n','').\
                                            replace(' ','').replace('?','').replace('?','').replace('!','').replace('。','').\
                                            replace('!','')
    def call_gpt(self,inputs_tensor):
        with torch.inference_mode():
            inputs = {k:inputs_tensor[k].to('cuda') for k in inputs_tensor}
            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]
            model_output = self.processor.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        return model_output
    def call_function_fake(self,messages=[],obs=""):
        messages.append({'from': 'observation', 'value': obs})
        return messages
    def generate(self,chat_history,tools="",audio_path=None):
        '''
        input:
            audio_path : str
            chat_history : dict
        return:
            model_output : dict
        '''
        chat_history = copy.deepcopy(chat_history)
        if type(audio_path)!=type(None):
            chat_history.append({'from': 'human',
                                 'value': [{'type': 'audio',
                                            'audio': audio_path}]})
        words_from_poi = []
        for hist in chat_history:
            if hist['from']=='observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
                tmp = json.loads(hist['value'])
                for i,poi in enumerate(tmp['poi']):
                    words_from_poi.append(poi['name'])
        for hist in chat_history:
            if hist['from']=='human' and type(hist['value'])==str:
                hist['value'] = correct_sentence_with_pinyin(hist['value'],words_from_poi)
            elif hist['from']=='function_call' and "arguments" in hist['value'] and 'keyword' in hist['value']["arguments"]:
                hist['value']["arguments"] = eval(hist['value']["arguments"])
                if 'keyword' in hist['value']["arguments"]:
                    hist['value']["arguments"]['keyword'] = correct_sentence_with_pinyin(hist['value']["arguments"]['keyword'],words_from_poi)
                hist['value']["arguments"] = str(hist['value']["arguments"])
        # model_input_history = copy.deepcopy(chat_history)
        # num2ch = {1:'一',2:'二',3:'三',4:'四',5:'五',6:'六'}
        # for hist in model_input_history:
        #     if hist['from']=='observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
        #         tmp = json.loads(hist['value'])
        #         new_poi = []
        #         for i,poi in enumerate(tmp['poi']):
        #             new_poi.append('第{}個 : '.format(num2ch[i+1])+str(poi))
        #         tmp['poi'] = new_poi
        #         hist['value'] = json.dumps(tmp, ensure_ascii=False)

        inputs_text = self.processor.apply_chat_template(
            chat_history, add_generation_prompt=True, tokenize=False,
            return_dict=True, return_tensors="pt", tools=json.loads(tools)
        )
        inputs_tensor = self.processor(text=inputs_text,
                                       audio=[torchaudio.load(audio_path)[0]] if type(audio_path)!=type(None) else None,
                                       add_special_tokens=False,
                                       return_tensors='pt'
                                       )
        model_output = self.call_gpt(inputs_tensor)
        if chat_history[-1]['from']=='observation':
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output,words_from_poi)})
            return chat_history
        if ((not ';\n' in model_output) or (not 'User transcribe is :' in model_output) or (not 'GPT output is :' in model_output)\
            or len(model_output.split(';\n'))<2 ):
            if chat_history[-1]['value']!="抱歉我聽不清楚 能麻煩您再說一次嗎":
                chat_history.append({'from': 'human',
                                     'value': 'HUMAN_VOICE_IS_NOT_RECOGNIZED'}),
                chat_history.append({'from': 'gpt', 'value': '抱歉我聽不清楚 能麻煩您再說一次嗎'})
            return chat_history
        output_t,output_o = model_output.split(';\n')[:2]
        output_t,output_o = self.remove_words_signs(output_t),self.remove_words_signs(output_o)
        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t,words_from_poi)
        if 'Action:' in output_o and 'ActionInput:' in output_o: # function calling
            function_name,function_arg = output_o.split('ActionInput:')
            function_name = function_name.replace('Action:','')
            if "keyword" in function_arg:
                function_arg = json.loads(function_arg)
                if "keyword" in function_arg:
                    function_arg["keyword"] = correct_sentence_with_pinyin(function_arg["keyword"],words_from_poi)
            chat_history.append({'from': 'function_call', 'value': {"name": function_name, "arguments": str(function_arg)}})
        else: # gpt response
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o,words_from_poi)})
        return chat_history



model_id = "/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
pipeline = InferenceClass(model_id)
app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of dictionaries.",
)
import json
dataset = json.load(open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json'))
tools = dataset[0]['tools']

@app.post("/audio_llm/")
async def process_audio_and_data(
    audio_file: Optional[UploadFile] = File(None, description="The audio file to be processed."),
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.")
) -> List[Dict[str, Any]]:

    try:
        input_data_list = json.loads(data)
        if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
            raise ValueError("The provided data is not a list of dictionaries.")

    except json.JSONDecodeError:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        )
    except ValueError as e:
        raise HTTPException(
            status_code=422,
            detail=str(e)
        )
    temp_file_path=None
    if audio_file:
        temp_file_path = f"./audio_path/temp_{audio_file.filename}"
        with open(temp_file_path, "wb") as buffer:
            buffer.write(await audio_file.read())
        print(f"Audio file saved to {temp_file_path}")

    output_data = pipeline.generate(input_data_list,tools=tools,audio_path=temp_file_path)
    print(output_data)
    return output_data

# uvicorn main:app --host 0.0.0.0 --port 8087 --log-level info --workers 1 >> ./log.txt
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
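As a quick illustration of the pinyin-correction step above (a standalone sketch; the POI name comes from the notebook output below, while the misheard sentence and scores are hypothetical):

```python
# Assumes correct_sentence_with_pinyin (with its rapidfuzz/pypinyin imports)
# from deploy/main.py is in scope. The function swaps a near-homophone span
# in the user's sentence for the closest known POI name, matching in pinyin
# space so that characters with the same pronunciation still line up.
poi_names = ["台中太平逸境"]
print(correct_sentence_with_pinyin("我在找台中太平一景", poi_names))
# -> "我在找台中太平逸境"  (assuming the fuzzy scores clear the cutoff of 50)
```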
deploy/testapi.ipynb
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "import os\n",
    "\n",
    "def callAPI(input_data, audio_file=None, API_URL=\"http://0.0.0.0:8087/audio_llm/\"):\n",
    "    json_data_string = json.dumps(input_data)\n",
    "\n",
    "    files = None\n",
    "    if audio_file:\n",
    "        files = {\n",
    "            'audio_file': (os.path.basename(audio_file), open(audio_file, 'rb'), 'audio/wav')\n",
    "        }\n",
    "    data = {\n",
    "        'data': json_data_string\n",
    "    }\n",
    "    try:\n",
    "        response = requests.post(API_URL, files=files, data=data)\n",
    "        if response.status_code == 200:\n",
    "            return response.json()\n",
    "        print(f\"Request failed, status code: {response.status_code}\")\n",
    "        print(f\"Error message: {response.text}\")\n",
    "    except requests.exceptions.ConnectionError as e:\n",
    "        print(f\"Could not connect to the server. Make sure the FastAPI application is running. Error: {e}\")\n",
    "    return input_data  # fall back to the unmodified history when the call fails"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "data = json.load(open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"[{'from': 'human', 'value': '我在找台東太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台東太平逸境'}\"}}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡���?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市��平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '這裡好像不對'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '這裡好像不對'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '我有印象他們有大片玻璃門面'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區大片玻璃門面'}\"}}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到���台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '這裡好像不對'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '我有印象他們有大片玻璃門面'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區大片玻璃門面'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"雞丁炸舖\", \"address\": \"台東市中華路一段210號\", \"latitude\": 22.7575, \"longitude\": 121.1457, \"fromMemory\": false}, {\"name\": \"即刻炸\", \"address\": \"桃園市桃園區中正路400號\", \"latitude\": 24.9933, \"longitude\": 121.3121, \"fromMemory\": false}, {\"name\": \"BOOM炸雞\", \"address\": \"台南市永康區中華路88號\", \"latitude\": 23.0349, \"longitude\": 120.25, \"fromMemory\": false}, {\"name\": \"雞咔咔速食站\", \"address\": \"台中市南區建成路160號\", \"latitude\": 24.1217, \"longitude\": 120.6852, \"fromMemory\": false}, {\"name\": \"炸雞一下\", \"address\": \"彰化市民生路188號\", \"latitude\": 24.0801, \"longitude\": 120.5429, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到以下與「澳洲大堡道acher韓的區」相關的地點,請問您要導航到哪一個地點?'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區��的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '這裡好像不對'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '我有印象他們有大片玻璃門面'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區大片玻璃門面'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"雞丁炸舖\", \"address\": \"台東市中華路一段210號\", \"latitude\": 22.7575, \"longitude\": 121.1457, \"fromMemory\": false}, {\"name\": \"即刻炸\", \"address\": \"桃園市桃園區中正路400號\", \"latitude\": 24.9933, \"longitude\": 121.3121, \"fromMemory\": false}, {\"name\": \"BOOM炸雞\", \"address\": \"台南市永康區中華路88號\", \"latitude\": 23.0349, \"longitude\": 120.25, \"fromMemory\": false}, {\"name\": \"雞咔咔速食站\", \"address\": \"台中市南區建成路160號\", \"latitude\": 24.1217, \"longitude\": 120.6852, \"fromMemory\": false}, {\"name\": \"炸雞一下\", \"address\": \"彰化市民生路188號\", \"latitude\": 24.0801, \"longitude\": 120.5429, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到以下與「澳洲大堡道acher韓的區」相關的地點,請問您要導航到哪一個地點?'}, {'from': 'human', 'value': '還是算了,去比較好'}, {'from': 'gpt', 'value': '已為您取消導航清單'}]\n",
"[{'from': 'human', 'value': '我在找台中太平逸境'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '台中太平逸境'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'}, {'from': 'human', 'value': '去第一個'}, {'from': 'function_call', 'value': {'name': 'nav_start', 'arguments': \"{'keyword': '台中太平逸境', 'name': '台中太平逸境', 'address': '台中市太平區新仁路一段88號', 'latitude': 24.109912, 'longitude': 120.737201, 'fromMemory': False}\"}}, {'from': 'observation', 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'}, {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'}, {'from': 'human', 'value': '請帶我去澳洲大堡道acher韓的區'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'}, {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'}, {'from': 'human', 'value': '這好像不像吧'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '這裡好像不對'}, {'from': 'gpt', 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道acher韓的區」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'}, {'from': 'human', 'value': '我有印象他們有大片玻璃門面'}, {'from': 'function_call', 'value': {'name': 'search_and_show_place', 'arguments': \"{'keyword': '澳洲大堡道acher韓的區大片玻璃門面'}\"}}, {'from': 'observation', 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"雞丁炸舖\", \"address\": \"台東市中華路一段210號\", \"latitude\": 22.7575, \"longitude\": 121.1457, \"fromMemory\": false}, {\"name\": \"即刻炸\", \"address\": \"桃園市桃園區中正路400號\", \"latitude\": 24.9933, \"longitude\": 121.3121, \"fromMemory\": false}, {\"name\": \"BOOM炸雞\", \"address\": \"台南市永康區中華路88號\", \"latitude\": 23.0349, \"longitude\": 120.25, \"fromMemory\": false}, {\"name\": \"雞咔咔速食站\", \"address\": \"台中市南區建成路160號\", \"latitude\": 24.1217, \"longitude\": 120.6852, \"fromMemory\": false}, {\"name\": \"炸雞一下\", \"address\": \"彰化市民���路188號\", \"latitude\": 24.0801, \"longitude\": 120.5429, \"fromMemory\": false}]}'}, {'from': 'gpt', 'value': '我找到以下與「澳洲大堡道acher韓的區」相關的地點,請問您要導航到哪一個地點?'}, {'from': 'human', 'value': '還是算了,去比較好'}, {'from': 'gpt', 'value': '已為您取消導航清單'}, {'from': 'human', 'value': '還是算了,不去比較好'}, {'from': 'gpt', 'value': '已為您取消導航清單'}]\n"
     ]
    }
   ],
   "source": [
    "for all_conv in data:\n",
    "    history = []\n",
    "    tools = all_conv['tools']\n",
    "    for idx, conv in enumerate(all_conv['conversations']):\n",
    "        if conv['from'] == 'function_call' or conv['from'] == 'gpt': continue\n",
    "        elif conv['from'] == 'human':\n",
    "            history = callAPI(history, audio_file=conv['audio_path'])\n",
    "        elif conv['from'] == 'observation':\n",
    "            history.append(conv)\n",
    "            history = callAPI(history)\n",
    "        print(history)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'from': 'human', 'value': '我在找台東太平逸境'},\n",
       " {'from': 'function_call',\n",
       "  'value': {'name': 'search_and_show_place',\n",
       "   'arguments': '{\"keyword\":\"台東太平逸境\"}'}},\n",
       " {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'},\n",
       " {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'},\n",
       " {'from': 'human', 'value': '去第一個'},\n",
       " {'from': 'function_call',\n",
       "  'value': {'name': 'nav_start',\n",
       "   'arguments': '{\"keyword\":\"台東太平逸境\",\"name\":\"台中太平逸境\",\"address\":\"台中市太平區新仁路一段88號\",\"latitude\":24.109912,\"longitude\":120.737201,\"fromMemory\":false}'}},\n",
       " {'from': 'observation',\n",
       "  'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'},\n",
       " {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'},\n",
       " {'from': 'human', 'value': '請帶我去澳洲大堡道科韓的區政府店'},\n",
       " {'from': 'function_call',\n",
       "  'value': {'name': 'search_and_show_place',\n",
       "   'arguments': '{\"keyword\":\"澳洲大堡道科韓的區政府店\"}'}},\n",
       " {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'},\n",
       " {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'},\n",
       " {'from': 'human', 'value': '這好像不像吧'},\n",
       " {'from': 'gpt',\n",
       "  'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道科韓的區政府店」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
       " {'from': 'human', 'value': '這裡好像不對'},\n",
       " {'from': 'gpt',\n",
       "  'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道科韓的區政府店」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
       " {'from': 'human', 'value': '我有印象他們有大片玻璃門面'},\n",
       " {'from': 'function_call',\n",
       "  'value': {'name': 'search_and_show_place',\n",
       "   'arguments': '{\"keyword\":\"澳洲大堡道科韓的區政府店大片玻璃門面\"}'}},\n",
       " {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"雞丁炸舖\", \"address\": \"台東市中華路一段210號\", \"latitude\": 22.7575, \"longitude\": 121.1457, \"fromMemory\": false}, {\"name\": \"即刻炸\", \"address\": \"桃園市桃園區中正路400號\", \"latitude\": 24.9933, \"longitude\": 121.3121, \"fromMemory\": false}, {\"name\": \"BOOM炸雞\", \"address\": \"台南市永康區中華路88號\", \"latitude\": 23.0349, \"longitude\": 120.25, \"fromMemory\": false}, {\"name\": \"雞咔咔速食站\", \"address\": \"台中市南區建成路160號\", \"latitude\": 24.1217, \"longitude\": 120.6852, \"fromMemory\": false}, {\"name\": \"炸雞一下\", \"address\": \"彰化市民生路188號\", \"latitude\": 24.0801, \"longitude\": 120.5429, \"fromMemory\": false}]}'},\n",
       " {'from': 'gpt', 'value': '我找到以下與「澳洲大堡道科韓的區政府店」相關的地點,請問您要導航到哪一個地點?'},\n",
       " {'from': 'human', 'value': '還是算了,去比較好'},\n",
       " {'from': 'gpt', 'value': '已為您取消導航清單'},\n",
       " {'from': 'human', 'value': '還是算了,不去比較好'},\n",
       " {'from': 'gpt', 'value': '已為您取消導航清單'}]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
eval.py
CHANGED
|
@@ -18,14 +18,13 @@ from torch.utils.data import Dataset, DataLoader
 import soundfile as sf
 import re
 from pathlib import Path
-
-converter = opencc.OpenCC('s2tw.json')
+
 normalizer = {
     "en_us" : EnglishTextNormalizer(),
     "other" : BasicTextNormalizer()
 }

-model_id = "/
+model_id = "/home/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
 revision = "main" #"v1.0"

 model = AutoModel.from_pretrained(
@@ -146,7 +145,7 @@ class LibriSpeechDataset(BaseAudioDataset):
             subset,
             split=split,
             trust_remote_code=True,
-            cache_dir=Path("/
+            cache_dir=Path("/home/jeff/codes/llm/InCar/data")
         )

         # (Optional) Audio length Filtering
@@ -185,7 +184,7 @@ class CommonVoiceDataset(BaseAudioDataset):
             source_lang,
             split=split,
             trust_remote_code=True,
-            cache_dir=Path("/
+            cache_dir=Path("/home/jeff/codes/llm/InCar/data")
         )
         def prepare_dataset(batch):
             """Function to preprocess the dataset with the .map method"""
@@ -202,7 +201,7 @@ class CommonVoiceDataset(BaseAudioDataset):
             batch["sentence"] = transcription

             return batch
-        self.data
+        self.data.map(prepare_dataset, desc="preprocess dataset")

         # (Optional) Audio length Filtering
         self.data = self.filter_by_audio_length(self.data, "audio")
@@ -249,7 +248,7 @@ class FleursDataset(BaseAudioDataset):
             source_lang,
             split=split,
             trust_remote_code=True,
-            cache_dir=Path("/
+            cache_dir=Path("/home/jeff/codes/llm/InCar/data")
         )
         def prepare_dataset(batch):
             import opencc
@@ -263,7 +262,7 @@ class FleursDataset(BaseAudioDataset):

             return batch
         if (source_lang=="cmn_hans_cn" and not self.ast) or (self.ast and target_lang=="cmn_hans_cn"):
-            self.data
+            self.data.map(prepare_dataset, desc="preprocess dataset")

         # (Optional) Audio length Filtering
         self.data = self.filter_by_audio_length(self.data, "audio")
@@ -281,7 +280,7 @@ class FleursDataset(BaseAudioDataset):
             target_lang,
             split=split,
             trust_remote_code=True,
-            cache_dir=Path("/
+            cache_dir=Path("/home/jeff/codes/llm/InCar/data")
         )

         source_dict = {item['id']: item for item in self.data}
@@ -435,6 +434,8 @@ def save_results(results, dataset_name, task, source_lang, target_lang=None, sam
     return filepath

 def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
+    import opencc
+    converter = opencc.OpenCC('s2tw.json')
     task_type = "asr" if is_asr else "translation"
     eval_lang = source_lang if is_asr else target_lang
     if eval_lang in normalizer:
@@ -482,7 +483,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
                 "prediction": converter.convert(prediction)
             }
             sample_results.append(sample_result)
-
+
             if (batch_idx + 1) % 10 == 0:
                 temp_results = []

@@ -496,6 +497,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size
             else:
                 temp_item = item.copy()
                 try:
+
                     ref = eval_normalizer(item["reference"])
                     pred = eval_normalizer(item["prediction"])

@@ -506,7 +508,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size

                     metrics = {
                         "bleu": utt_bleu,
-                        "cer":
+                        "cer": utt_cer,
                         "wer": utt_wer
                     }

@@ -544,7 +546,7 @@ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size

             item.update({
                 "bleu": utt_bleu,
-                "cer":
+                "cer": utt_cer,
                 "wer": utt_wer
             })

@@ -581,7 +583,7 @@ if __name__ == "__main__":
     ]

     num_samples = -1
-    batch_size =
+    batch_size = 2

     for source_lang, target_lang in zip(source_languages, target_languages):
         print(f"\n===== {source_lang[0]} ASR =====")
@@ -590,21 +592,12 @@ if __name__ == "__main__":

         datasets = []

-
-
         commonvoice_speech_tw = CommonVoiceDataset(
             processor=processor,
             source_lang="zh-TW",
             split=split
         )
         datasets.append(commonvoice_speech_tw)
-        fleurs = FleursDataset(
-            processor=processor,
-            split=split,
-            source_lang="en_us", # English
-            mode="asr"
-        )
-        datasets.append(fleurs)

         # Libri Speech Clean ASR mode (English -> English text)
         # libri_speech_clean = LibriSpeechDataset(
@@ -623,7 +616,13 @@ if __name__ == "__main__":
         # datasets.append(libri_speech_other)

         # Fleurs ASR mode (English -> English text)
-
+        fleurs = FleursDataset(
+            processor=processor,
+            split=split,
+            source_lang="en_us", # English
+            mode="asr"
+        )
+        datasets.append(fleurs)

         for dataset in datasets:
             # ASR
eval_multiturn_textonly.py
ADDED
|
@@ -0,0 +1,131 @@
from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature, Gemma3ForCausalLM, Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *

# converter = opencc.OpenCC('s2tw.json')

model_id = "./"
revision = "main" #"v1.0"

processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)

results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# os.makedirs(results_dir, exist_ok=True)

def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0):
    res = {'label': [], "pred": [], 'cer': []}
    func_error = 0
    total_func_call = 0
    total_error = 0
    all_output_text = []
    # strip role markers, whitespace and punctuation before comparing strings
    remove_sign = lambda x: x.replace('User transcribe is', '').replace('GPT output is', '').replace('\n', '').\
        replace(' ', '').replace('?', '').replace('?', '').replace('!', '').replace('。', '').\
        replace('.', '').replace('!', '')
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if batch_idx <= start_idx: continue
        batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
        try:
            with torch.inference_mode():
                if not with_input_mode: batch.pop('input_modes')
                generate_ids = model.generate(**batch,
                    max_new_tokens=256,
                    temperature=0.001, top_p=0.95, top_k=64, do_sample=True
                )
            batch_inputs = processor.batch_decode(
                generate_ids[:, :batch['input_ids'].shape[1]], skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            batch_predictions = processor.batch_decode(
                generate_ids[:, batch['input_ids'].shape[1]:], skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            batch_references = processor.batch_decode(
                batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            for inp, label, output in zip(batch_inputs, batch_references, batch_predictions):
                cer_o = min(100, round(cer(re.sub(r"\s+", "", label), re.sub(r"\s+", "", output)) * 100, 2))
                res['label'].append(batch_references)
                res['pred'].append(batch_predictions)
                res['cer'].append(cer_o)
                all_output_text.append({
                    'input': inp,
                    'label': label,
                    'output': output,
                    'cer': cer_o,
                })
                if 'Action:' in label:
                    func_error += (remove_sign(label) != remove_sign(output))
                    total_func_call += 1
            if batch_idx % 100 == 0:
                with open(save_path, 'w', encoding='utf-8') as f:
                    json.dump(all_output_text, f, ensure_ascii=False, indent=4)
                avg_cer = sum(a['cer'] for a in all_output_text) / len(all_output_text)
                total_error = sum(a['cer'] != 0 for a in all_output_text)
                print('total', len(all_output_text))
                print('total_error & rate', total_error, total_error / len(all_output_text))
                print('avg_cer', avg_cer)
                print('total_func_call', total_func_call)
                # guard against division by zero when no tool calls were seen yet
                print('func_error & rate', func_error, ',', func_error / max(1, total_func_call))
        except Exception:
            print("error at ", batch_idx)
            time.sleep(2)
    avg_cer = sum(a['cer'] for a in all_output_text) / len(all_output_text)
    total_error = sum(a['cer'] != 0 for a in all_output_text)
    print('total', len(all_output_text))
    print('total_error & rate', total_error, total_error / len(all_output_text))
    print('avg_cer', avg_cer)
    print('total_func_call', total_func_call)
    print('func_error & rate', func_error, ',', func_error / max(1, total_func_call))
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_output_text, f, ensure_ascii=False, indent=4)
    return res, all_output_text


nav_data = MultiturnAudioDataset(split='eval', text_only=True, processor=processor, json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json')
ctrl_data = MultiturnAudioDataset(split='eval', text_only=True, processor=processor, json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json')
ctrl_dataloader = DataLoader(ctrl_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
nav_dataloader = DataLoader(nav_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)


from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id_org = "google/gemma-3-4b-it"

model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org, device_map="auto", attn_implementation="eager"
).eval()

from peft import PeftModel
model_org = PeftModel.from_pretrained(model_org, '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270')


res_org_nav, output_org_nav = eval_text(model_org, nav_dataloader, save_path='./output_org_nav_{}.json'.format(str(datetime.now())[:16]))
res_org_ctrl, output_org_ctrl = eval_text(model_org, ctrl_dataloader, save_path='./output_org_ctrl_{}.json'.format(str(datetime.now())[:16]))
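`eval_text` counts a function-call error whenever the label and the model output still differ after `remove_sign` strips the role markers, whitespace, and punctuation. A small illustration of that normalization with made-up strings:

```python
remove_sign = lambda x: x.replace('User transcribe is', '').replace('GPT output is', '') \
    .replace('\n', '').replace(' ', '').replace('?', '').replace('!', '') \
    .replace('。', '').replace('.', '')

label = 'Action: nav_start\nAction Input: {"keyword": "台中太平逸境"}'
output = 'Action: nav_start\nAction Input: {"keyword": "台中太平逸境"} '
assert remove_sign(label) == remove_sign(output)  # whitespace/punctuation differences are ignored
```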
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
{
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "do_sample": true,
  "eos_token_id": [
    1,
    106
  ],
  "pad_token_id": 0,
  "top_k": 64,
  "top_p": 0.95,
  "transformers_version": "4.51.3"
}
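Because `generation_config.json` now ships with the checkpoint, `model.generate()` picks these values up as defaults. A quick way to confirm (a sketch; it assumes you run from the repository root):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./")  # reads generation_config.json
print(gen_cfg.do_sample, gen_cfg.top_k, gen_cfg.top_p)  # True 64 0.95
```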
model-00001-of-00003.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c729778bd29ddfd0e5a97b55113d4d0b70ae902159c36601234824502556777e
+size 4983859800
model-00002-of-00003.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:101b7c0d33ed4b221f1c043768fc805b8d183f117746ef782df74018094ce0d9
+size 4997727608
model-00003-of-00003.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:28aa17a269742a7258379fe548d426d2418a98d8535590f2f6cc63767883a0f9
+size 741700120
model.safetensors.index.json
CHANGED
|
The diff for this file is too large to render.
preprocessing_gemma3omni.py
CHANGED
|
@@ -381,24 +381,12 @@ class Gemma3OmniProcessor(ProcessorMixin):
         if audio is not None:
             full_audio_sequences = []
             audio_inputs = self.feature_extractor(audio)
-            def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
-                parts = prompt.split(boa_token)
-                result = ""
-                for i in range(len(parts) - 1):
-                    result += parts[i]
-                    if i < len(audio_sequences):
-                        result += audio_sequences[i]
-                    else:
-                        result += boa_token
-                result += parts[-1]
-                return result
             for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
                 audio_tokens_expanded = "".join([self.audio_token] * embed_size)
                 full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
                 full_audio_sequences.append(full_audio_sequence)

-            text = [
-            #text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
+            text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]

             text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
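The change above drops the unused `replace_tokens_sequentially` helper and substitutes each prompt's begin-of-audio placeholder with its expanded audio-token sequence directly. A standalone sketch of that mechanic with dummy token strings (all names here are illustrative, not the processor's actual attributes):

```python
boa, eoa, audio_tok = "<start_of_audio>", "<end_of_audio>", "<audio>"

prompts = ["<start_of_turn>user\n" + boa + "<end_of_turn>\n"]
audio_embed_sizes = [4]  # one clip occupying 4 audio-embedding slots

full_audio_sequences = [
    f"\n\n{boa}{''.join([audio_tok] * size)}{eoa}\n\n" for size in audio_embed_sizes
]

# One placeholder replaced per prompt, mirroring the new line in the diff.
text = [p.replace(boa, seq) for p, seq in zip(prompts, full_audio_sequences)]
print(text[0])
```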
run.sh
ADDED
|
@@ -0,0 +1,4 @@
export RDMAV_FORK_SAFE=1
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"
python training.py
template.json
ADDED
|
@@ -0,0 +1 @@
{"chat_template": "{{ bos_token }}\n{%- if tools %}\n {% set tools_json = tools %}\n You have access to the following tools:\n {% for tool in tools_json %}\n > Tool Name: {{ tool.name }}\n Tool Description: {{ tool.description }}\n Tool Args:\n {% if tool.parameters.properties %}\n {% for prop_name, prop_data in tool.parameters.properties.items() %}\n - {{ prop_name }} ({{ prop_data.type }}, {% if prop_name in tool.parameters.required %}required{% else %}optional{% endif %}): {{ prop_data.description }}\n {% endfor %}\n {% else %}\n {% endif %}\n\n {% endfor %}, where each item should be object\n Use the following format if using a tool:\n ```\n Action: tool name (one of [{{ tools_json | map(attribute='name') | join(', ') }}])\n Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{\"input\": \"hello world\", \"num_beams\": 5}```)\n ```\n {{'\n'}}\n{%- endif -%}\n\n\n{%- for message in messages -%}\n {%- if message['from'] == \"human\" -%}\n {%- if message['value'] is string -%}\n {{'<start_of_turn>user\n' + message['value'] + '<end_of_turn>\n'}}\n {%- elif message['value'] is iterable -%}\n {{'<start_of_turn>user\n'}}\n {%- for item in message['value'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'audio' -%}\n {{ '<start_of_audio>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {{'<end_of_turn>\n'}}\n {%- endfor -%}\n {%- endif -%}\n {%- elif message['from'] == \"gpt\" -%}\n {{'<start_of_turn>model\n' + message['value'] + '<end_of_turn>\n'}} \n {%- elif message['from'] == \"function_call\" -%}\n {%- set func_call = message['value'] -%}\n {{'<start_of_turn>model\n' + 'Action:' + func_call.name + '\n' + 'Action Input:' + func_call.arguments + '<end_of_turn>\n'}} \n {%- elif message['from'] == \"observation\" -%}\n {{'<start_of_turn>tool\n' + message['value'] + '<end_of_turn>\n'}} \n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"}
|
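`template.json` stores the chat template as a single Jinja string keyed by `chat_template`, using the `from`/`value` message convention seen in the notebooks above. A minimal rendering sketch (not part of the repo; the tool and messages below are made up):

```python
import json
from jinja2 import Template

tmpl = Template(json.load(open("template.json"))["chat_template"])

tools = [{
    "name": "nav_start",
    "description": "Start navigation to a destination.",
    "parameters": {
        "properties": {"keyword": {"type": "string", "description": "destination keyword"}},
        "required": ["keyword"],
    },
}]
messages = [
    {"from": "human", "value": "請帶我去台中太平逸境"},
    {"from": "function_call", "value": {"name": "nav_start", "arguments": '{"keyword": "台中太平逸境"}'}},
    {"from": "observation", "value": '{"name": "nav_start", "status": "success"}'},
]
print(tmpl.render(bos_token="<bos>", tools=tools, messages=messages, add_generation_prompt=True))
```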
test.ipynb
ADDED
|
@@ -0,0 +1,977 @@
| 1 |
+
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
]
}
],
"source": [
"from transformers import AutoProcessor, AutoModel\n",
"import torch\n",
"import os\n",
"\n",
"os.environ['TORCH_USE_CUDA_DSA']=\"1\"\n",
"os.environ['CUDA_LAUNCH_BLOCKING']=\"1\"\n",
"os.environ['TORCH_DISABLE_SDPA'] = '1'\n",
"\n",
"model_id = \"/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni\"\n",
"revision = \"main\"\n",
"\n",
"# model = AutoModel.from_pretrained(\n",
"#     model_id, device_map=\"cuda\", \n",
"#     revision = revision, trust_remote_code=True,\n",
"# ).eval()\n",
"\n",
"processor = AutoProcessor.from_pretrained(\n",
"    model_id, revision = revision, trust_remote_code=True\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/torchaudio/_backend/utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n",
" warnings.warn(\n",
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/torchaudio/_backend/ffmpeg.py:88: UserWarning: torio.io._streaming_media_decoder.StreamingMediaDecoder has been deprecated. This deprecation is part of a large refactoring effort to transition TorchAudio into a maintenance phase. The decoding and encoding capabilities of PyTorch for both audio and video are being consolidated into TorchCodec. Please see https://github.com/pytorch/audio/issues/3902 for more information. It will be removed from the 2.9 release. \n",
" s = torchaudio.io.StreamReader(src, format, None, buffer_size)\n"
]
}
],
"source": [
"from ASRDataset import *\n",
"datasets = MultiturnAudioDataset(processor=processor,json_path='/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'我在找台中太平逸境'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from rapidfuzz import process, fuzz\n",
"from pypinyin import pinyin, Style\n",
"\n",
"def correct_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=50):\n",
"    # map every candidate location name to its toneless pinyin string\n",
"    pinyin_dict = {}\n",
"    for location in location_dict:\n",
"        pinyin_name = ''.join([item[0] for item in pinyin(location, style=Style.NORMAL)])\n",
"        pinyin_dict[pinyin_name] = location\n",
"\n",
"    user_pinyin_sentence = ''.join([item[0] for item in pinyin(user_input_sentence, style=Style.NORMAL)])\n",
"\n",
"    # fuzzy-match the whole utterance against the candidates in pinyin space\n",
"    best_match_pinyin = process.extractOne(\n",
"        query=user_pinyin_sentence,\n",
"        choices=list(pinyin_dict.keys()),  # pass the pinyin strings in as the search targets\n",
"        scorer=fuzz.token_set_ratio,\n",
"        score_cutoff=score_cutoff\n",
"    )\n",
"\n",
"    if best_match_pinyin:\n",
"        best_pinyin_name = best_match_pinyin[0]\n",
"        corrected_location_name = pinyin_dict[best_pinyin_name]\n",
"\n",
"        best_user_substring = None\n",
"        max_substring_score = 0\n",
"\n",
"        # brute-force scan for the user substring closest to the matched name\n",
"        for i in range(len(user_input_sentence)):\n",
"            for j in range(i + 2, min(i + 16, len(user_input_sentence) + 1)):\n",
"                substring = user_input_sentence[i:j]\n",
"\n",
"                score = fuzz.ratio(substring, corrected_location_name)\n",
"\n",
"                if score > max_substring_score:\n",
"                    max_substring_score = score\n",
"                    best_user_substring = substring\n",
"\n",
"        if best_user_substring and max_substring_score > score_cutoff:\n",
"            return user_input_sentence.replace(best_user_substring, corrected_location_name, 1)\n",
"        else:\n",
"            return user_input_sentence\n",
"    return user_input_sentence\n",
"\n",
"large_location_dict = [\n",
"    \"台北太平逸\",\n",
"    \"台中太平逸境\",\n",
"    \"台北信義區\",\n",
"    \"高雄駁二藝術特區\",\n",
"    \"台南安平古堡\",\n",
"    \"台中逢甲夜市\",\n",
"    \"台北101\",\n",
"    \"淡水老街\"\n",
"]\n",
"\n",
"user_input_1 = \"我在找台東太平逸境\"\n",
"correct_sentence_with_pinyin(user_input_1, large_location_dict, score_cutoff=50)"
]
},
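{
"cell_type": "markdown",
"metadata": {},
"source": [
"`correct_sentence_with_pinyin` works in two stages: it first fuzzy-matches the whole utterance against the candidate names in pinyin space (so homophone ASR errors such as 台東 vs. 台中 still score well), then scans substrings of the utterance for the one closest to the matched name and splices the canonical name back in. A minimal usage sketch follows; the inputs are made up for illustration and are not part of the test data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical inputs: repair a homophone ASR error ('逸' heard as '一',\n",
"# both pronounced 'yi') against POI names returned by an earlier search.\n",
"poi_names = ['台中太平逸境', '台北101']\n",
"asr_text = '帶我去台中太平一境'\n",
"correct_sentence_with_pinyin(asr_text, poi_names, score_cutoff=50)\n",
"# expected result: '帶我去台中太平逸境'"
]
},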
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"data = json.load(open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json'))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"class InferenceClass:\n",
"    def __init__(self,model_id):\n",
"        self.model = AutoModel.from_pretrained(\n",
"            model_id, device_map=\"cuda\", \n",
"            torch_dtype=torch.bfloat16,\n",
"            trust_remote_code=True,\n",
"            attn_implementation=\"eager\"\n",
"        ).eval()\n",
"\n",
"        self.processor = AutoProcessor.from_pretrained(\n",
"            model_id, trust_remote_code=True\n",
"        )\n",
"        # strip the template markers and punctuation from raw model output\n",
"        self.remove_words_signs = lambda x:x.replace('User transcribe is :','').replace('GPT output is :','').replace('\\n','').\\\n",
"            replace(' ','').replace('?','').replace('?','').replace('!','').replace('。','').\\\n",
"            replace('!','')\n",
"    def call_gpt(self,inputs_tensor):\n",
"        with torch.inference_mode():\n",
"            inputs = {k:inputs_tensor[k].to('cuda') for k in inputs_tensor}\n",
"            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)\n",
"            generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]\n",
"            model_output = self.processor.batch_decode(\n",
"                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
"            )[0]\n",
"        return model_output\n",
"    def call_function_fake(self,messages=[],obs=\"\"):\n",
"        # append a simulated tool observation to the running history\n",
"        messages.append({'from': 'observation', 'value': obs})\n",
"        return messages\n",
"    def generate(self,chat_history,tools=\"\",audio_path=None):\n",
"        '''\n",
"        input:\n",
"            audio_path : str\n",
"            chat_history : list of dict\n",
"        return:\n",
"            chat_history : list of dict\n",
"        '''\n",
"        chat_history = copy.deepcopy(chat_history)\n",
"        if type(audio_path)!=type(None):\n",
"            chat_history.append({'from': 'human',\n",
"                                 'value': [{'type': 'audio',\n",
"                                            'audio': audio_path}]})\n",
"        # collect POI names from earlier successful searches; they drive the pinyin correction\n",
"        words_from_poi = []\n",
"        for hist in chat_history:\n",
"            if hist['from']=='observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:\n",
"                tmp = json.loads(hist['value'])\n",
"                for i,poi in enumerate(tmp['poi']):\n",
"                    words_from_poi.append(poi['name'])\n",
"        for hist in chat_history:\n",
"            if hist['from']=='human' and type(hist['value'])==str:\n",
"                hist['value'] = correct_sentence_with_pinyin(hist['value'],words_from_poi)\n",
"            elif hist['from']=='function_call' and \"arguments\" in hist['value'] and 'keyword' in hist['value'][\"arguments\"]:\n",
"                # arguments hold a Python-dict string (single quotes), hence eval() rather than json.loads()\n",
"                hist['value'][\"arguments\"] = eval(hist['value'][\"arguments\"])\n",
"                if 'keyword' in hist['value'][\"arguments\"]:\n",
"                    hist['value'][\"arguments\"]['keyword'] = correct_sentence_with_pinyin(hist['value'][\"arguments\"]['keyword'],words_from_poi)\n",
"                hist['value'][\"arguments\"] = str(hist['value'][\"arguments\"])\n",
"        # model_input_history = copy.deepcopy(chat_history)\n",
"        # num2ch = {1:'一',2:'二',3:'三',4:'四',5:'五',6:'六'}\n",
"        # for hist in model_input_history:\n",
"        #     if hist['from']=='observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:\n",
"        #         tmp = json.loads(hist['value'])\n",
"        #         new_poi = []\n",
"        #         for i,poi in enumerate(tmp['poi']):\n",
"        #             new_poi.append('第{}個 : '.format(num2ch[i+1])+str(poi))\n",
"        #         tmp['poi'] = new_poi\n",
"        #         hist['value'] = json.dumps(tmp, ensure_ascii=False)\n",
"\n",
"        inputs_text = self.processor.apply_chat_template(\n",
"            chat_history, add_generation_prompt=True, tokenize=False,\n",
"            return_dict=True, return_tensors=\"pt\", tools=json.loads(tools)\n",
"        )\n",
"        inputs_tensor = self.processor(text=inputs_text, \n",
"                                       audio=[torchaudio.load(audio_path)[0]] if type(audio_path)!=type(None) else None, \n",
"                                       add_special_tokens=False, \n",
"                                       return_tensors='pt'\n",
"                                       )\n",
"        model_output = self.call_gpt(inputs_tensor)\n",
"        if chat_history[-1]['from']=='observation':\n",
"            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output,words_from_poi)})\n",
"            return chat_history\n",
"        if ((not ';\\n' in model_output) or (not 'User transcribe is :' in model_output) or (not 'GPT output is :' in model_output)\\\n",
"            or len(model_output.split(';\\n'))<2 ):\n",
"            # malformed output: ask the user to repeat instead of guessing\n",
"            # NOTE: the recorded data uses the tag 'HUMAN_VOICE_CANNOT_RECOGNIZE'; this tag differs\n",
"            if chat_history[-1]['value']!=\"抱歉我聽不清楚 能麻煩您再說一次嗎\":\n",
"                chat_history.append({'from': 'human',\n",
"                                     'value': 'HUMAN_VOICE_IS_NOT_RECOGNIZED'})\n",
"            chat_history.append({'from': 'gpt', 'value': '抱歉我聽不清楚 能麻煩您再說一次嗎'})\n",
"            return chat_history\n",
"        output_t,output_o = model_output.split(';\\n')[:2]\n",
"        output_t,output_o = self.remove_words_signs(output_t),self.remove_words_signs(output_o)\n",
"        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t,words_from_poi)\n",
"        if 'Action:' in output_o and 'ActionInput:' in output_o: # function calling\n",
"            function_name,function_arg = output_o.split('ActionInput:')\n",
"            function_name = function_name.replace('Action:','')\n",
"            if \"keyword\" in function_arg:\n",
"                function_arg = json.loads(function_arg)\n",
"                if \"keyword\" in function_arg:\n",
"                    function_arg[\"keyword\"] = correct_sentence_with_pinyin(function_arg[\"keyword\"],words_from_poi)\n",
"            chat_history.append({'from': 'function_call', 'value': {\"name\": function_name, \"arguments\": str(function_arg)}})\n",
"        else: # gpt response\n",
"            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o,words_from_poi)})\n",
"        return chat_history\n"
]
},
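{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an audio turn, the fine-tuned model is expected to emit two segments separated by `';\\n'`: a `User transcribe is :` segment with the ASR text, and a `GPT output is :` segment holding either a plain reply or an `Action:`/`ActionInput:` pair for a tool call; `generate` falls back to a clarification reply whenever that contract is violated. A standalone sketch of the happy-path parsing (the sample string is illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the two-segment output contract parsed inside generate();\n",
"# the real method additionally strips punctuation and applies the pinyin fix.\n",
"sample_output = 'User transcribe is : 我在找台中太平逸境;\\nGPT output is : 我找到「台中太平逸境」,請問是要到這裡嗎?'\n",
"transcript, reply = sample_output.split(';\\n')[:2]\n",
"print(transcript.replace('User transcribe is :', '').strip())\n",
"print(reply.replace('GPT output is :', '').strip())"
]
},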
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jeff/.cache/huggingface/modules/transformers_modules/gemma-3-4b-it-omni/speech_conformer_encoder.py:2775: FutureWarning: Please specify CheckpointImpl.NO_REENTRANT as CheckpointImpl.REENTRANT will soon be removed as the default and eventually deprecated.\n",
" lambda i: encoder_checkpoint_wrapper(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"######################## speech lora #############\n",
"######################## text lora #############\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00, 1.22s/it]\n"
]
}
],
"source": [
"pipeline = InferenceClass(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
" warnings.warn(\n",
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `64` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
" warnings.warn(\n",
"/home/jeff/miniconda3/envs/py10/lib/python3.10/site-packages/torch/utils/checkpoint.py:85: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '請開啟電滑門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_SLIDING_DOOR_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"true\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '請開啟電滑門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_SLIDING_DOOR_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"true\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電滑門已開啟'}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '請開啟電滑門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_SLIDING_DOOR_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"true\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電滑門已開啟'}, {'from': 'human', 'value': '中間這排左邊關空調'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_AC_ON\",\"areaId\":\"SEAT_ROW_2_LEFT\",\"operation\":\"set\",\"value\":\"false\"}]}'}}]\n",
"[{'from': 'human', 'value': '馬上關充電口'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"EV_CHARGE_PORT_OPEN\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '充電口已關閉'}, {'from': 'human', 'value': '立刻關掉電尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': '請第2排中間靠背向後'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SEAT_BACKREST_ANGLE_1_POS\",\"areaId\":\"SEAT_ROW_2_CENTER\",\"operation\":\"decrease\",\"value\":\"5\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排中座椅背已向後調整'}, {'from': 'human', 'value': '幫我關閉電動尾門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_POWER_TAILGATE_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電動尾門已關閉'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '直接將最後排左邊溫度調弱到31度'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_TEMPERATURE_SET\",\"areaId\":\"SEAT_ROW_3_LEFT\",\"operation\":\"set\",\"value\":\"31\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第三排左座空調已設定為31度'}, {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'}, {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'}, {'from': 'human', 'value': '請開啟電滑門'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"SRDC_SLIDING_DOOR_ON\",\"areaId\":\"GLOBAL\",\"operation\":\"set\",\"value\":\"true\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '電滑門已開啟'}, {'from': 'human', 'value': '中間這排左邊關空調'}, {'from': 'function_call', 'value': {'name': 'control_car_properties', 'arguments': '{\"properties\":[{\"propertyId\":\"HVAC_AC_ON\",\"areaId\":\"SEAT_ROW_2_LEFT\",\"operation\":\"set\",\"value\":\"false\"}]}'}}, {'from': 'observation', 'value': '{\"name\": \"control_car_properties\", \"status\": \"success\", \"message\": \"控制指令執行完成\"}'}, {'from': 'gpt', 'value': '第二排左座已關閉空調'}]\n"
]
}
],
"source": [
"import json\n",
"data = json.load(open('/home/jeff/jeff/codes/llm/InCar/data/test_data/ctrl_0730_noisy.json'))\n",
"# replay each recorded conversation: regenerate human turns from the noisy\n",
"# audio and inject recorded observations as fake tool results\n",
"for all_conv in data:\n",
"    history = []\n",
"    tools = all_conv['tools']\n",
"    for idx,conv in enumerate(all_conv['conversations']):\n",
"        if conv['from']=='function_call' or conv['from']=='gpt':continue\n",
"        elif conv['from']=='human':\n",
"            history = pipeline.generate(history,tools=tools,audio_path=conv['audio_path'])\n",
"        elif conv['from']=='observation':\n",
"            history = pipeline.call_function_fake(history,conv['value'])\n",
"            history = pipeline.generate(history,tools=tools)\n",
"        print(history)\n",
"    break  # only the first conversation is replayed here"
]
},
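{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above replays only the first recorded conversation (note the trailing `break`). A hedged sketch for scoring the whole test file follows; `replay_and_score` and the exact-match metric are assumptions for illustration, not part of this repo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: replay every conversation and report the fraction of\n",
"# recorded function_call turns that the pipeline reproduces verbatim.\n",
"def replay_and_score(data, pipeline):\n",
"    hits = total = 0\n",
"    for all_conv in data:\n",
"        history, tools = [], all_conv['tools']\n",
"        for conv in all_conv['conversations']:\n",
"            if conv['from'] == 'human':\n",
"                history = pipeline.generate(history, tools=tools, audio_path=conv['audio_path'])\n",
"            elif conv['from'] == 'observation':\n",
"                history = pipeline.call_function_fake(history, conv['value'])\n",
"                history = pipeline.generate(history, tools=tools)\n",
"        ref = [c['value'] for c in all_conv['conversations'] if c['from'] == 'function_call']\n",
"        hyp = [c['value'] for c in history if c['from'] == 'function_call']\n",
"        total += len(ref)\n",
"        hits += sum(str(r) == str(h) for r, h in zip(ref, hyp))\n",
"    return hits / max(total, 1)"
]
},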
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'我在找台中太平逸境'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'keyword': '台東太平逸境'}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval(history[1]['value'][\"arguments\"])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"name\": \"search_and_show_place\", \"arguments\": {\"keyword\": \"台中太平逸境\"}}'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all_conv['conversations'][1]['value']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'from': 'human', 'value': '我在找台東太平逸境'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"台東太平逸境\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"台中太平逸境\", \"address\": \"台中市太平區新仁路一段88號\", \"latitude\": 24.109912, \"longitude\": 120.737201, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到「台中太平逸境」,請問是要到這裡嗎?'},\n",
" {'from': 'human', 'value': '去第一個'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'nav_start',\n",
" 'arguments': '{\"keyword\":\"台東太平逸境\",\"name\":\"台中太平逸境\",\"address\":\"台中市太平區新仁路一段88號\",\"latitude\":24.109912,\"longitude\":120.737201,\"fromMemory\":false}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'},\n",
" {'from': 'gpt', 'value': '導航已設定,往「台中太平逸境」出發。'},\n",
" {'from': 'human', 'value': '請帶我去澳洲大堡道科韓的書店'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"澳洲大堡道科韓的書店\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"咔啦咔啦雞\", \"address\": \"南投市中山路120號\", \"latitude\": 23.9133, \"longitude\": 120.685, \"fromMemory\": true}, {\"name\": \"炸物車站\", \"address\": \"新竹市香山區中華路六段88號\", \"latitude\": 24.776, \"longitude\": 120.9331, \"fromMemory\": true}]}'},\n",
" {'from': 'gpt', 'value': '上次您是去「咔啦咔啦雞」,這次還要去這裡嗎?'},\n",
" {'from': 'human', 'value': '這好像不像吧'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道科韓的書店」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '這裡好像不對'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「澳洲大堡道科韓的書店」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '我有印象他們有大片玻璃門面'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"澳洲大堡道科韓的書店大片玻璃門面\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"雞丁炸舖\", \"address\": \"台東市中華路一段210號\", \"latitude\": 22.7575, \"longitude\": 121.1457, \"fromMemory\": false}, {\"name\": \"即刻炸\", \"address\": \"桃園市桃園區中正路400號\", \"latitude\": 24.9933, \"longitude\": 121.3121, \"fromMemory\": false}, {\"name\": \"BOOM炸雞\", \"address\": \"台南市永康區中華路88號\", \"latitude\": 23.0349, \"longitude\": 120.25, \"fromMemory\": false}, {\"name\": \"雞咔咔速食站\", \"address\": \"台中市南區建成路160號\", \"latitude\": 24.1217, \"longitude\": 120.6852, \"fromMemory\": false}, {\"name\": \"炸雞一下\", \"address\": \"彰化市民生路188號\", \"latitude\": 24.0801, \"longitude\": 120.5429, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「澳洲大堡道科韓的書店」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': '還是算了,去比較好'},\n",
" {'from': 'gpt', 'value': '已為您取消導航清單'},\n",
" {'from': 'human', 'value': '還是算了,不去比較好'},\n",
" {'from': 'gpt', 'value': '已為您取消導航清單'}]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = []\n",
"import json\n",
"json.loads(str(a))"
]
},
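{
"cell_type": "markdown",
"metadata": {},
"source": [
"`json.loads(str(a))` only succeeds above because `str([])` happens to be `'[]'`, which is also valid JSON. In general Python's `str`/`repr` of a container uses single quotes, which the JSON parser rejects, so `json.dumps` is the safe way to serialize:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"print(json.loads(str([])))            # ok: '[]' is valid JSON\n",
"print(json.loads(json.dumps(['a'])))  # safe round-trip\n",
"try:\n",
"    json.loads(str(['a']))            # \"['a']\" uses single quotes -> rejected\n",
"except json.JSONDecodeError as e:\n",
"    print('as expected:', e)"
]
},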
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'},\n",
" {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'},\n",
" {'from': 'human', 'value': '請帶我去義式鄉村風的海灘式漢堡'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Benny’s漢堡坊\", \"address\": \"新北市三重區重新路四段201號\", \"latitude\": 25.0601, \"longitude\": 121.4876, \"fromMemory\": true}, {\"name\": \"Chubby’s Burger\", \"address\": \"台中市西屯區黎明路三段122號\", \"latitude\": 24.1812, \"longitude\": 120.6422, \"fromMemory\": true}]}'},\n",
" {'from': 'gpt', 'value': '上次您是去「Benny’s漢堡坊」,這次還要去這裡嗎?'},\n",
" {'from': 'human', 'value': '這個看起來不符合'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「義式鄉村風的海灘式漢堡」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '店名叫大和水셰店'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡大和水셰店\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Classic Burger屋\", \"address\": \"高雄市鼓山區裕誠路350號\", \"latitude\": 22.6503, \"longitude\": 120.2911, \"fromMemory\": false}, {\"name\": \"Burger Kingdom\", \"address\": \"花蓮市中正路310號\", \"latitude\": 23.9782, \"longitude\": 121.6023, \"fromMemory\": false}, {\"name\": \"大口咬咬漢堡店\", \"address\": \"台北市大安區安和路二段95號\", \"latitude\": 25.03, \"longitude\": 121.5509, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「義式鄉村風的海灘式漢堡」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'},\n",
" {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'},\n",
" {'from': 'human', 'value': '全錯'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「義式鄉村風的海灘式漢堡」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '地點選在松山文創園區旁邊,特色為日式壽司'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡松山文創園區旁邊日式壽司\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Station 13漢堡屋\", \"address\": \"新北市土城區中央路三段260號\", \"latitude\": 24.979, \"longitude\": 121.4456, \"fromMemory\": false}, {\"name\": \"牛仔漢堡屋\", \"address\": \"台南市永康區中正南路150號\", \"latitude\": 23.0271, \"longitude\": 120.2485, \"fromMemory\": false}, {\"name\": \"Happy Cow美式餐館\", \"address\": \"高雄市左營區自由三路88號\", \"latitude\": 22.6845, \"longitude\": 120.3077, \"fromMemory\": false}, {\"name\": \"Chef John’s Burger\", \"address\": \"台中市太平區太平路250號\", \"latitude\": 24.1275, \"longitude\": 120.728, \"fromMemory\": false}, {\"name\": \"Rolling Burger\", \"address\": \"嘉義市東區忠孝路88號\", \"latitude\": 23.4866, \"longitude\": 120.4511, \"fromMemory\": false}, {\"name\": \"Wagyu Burger\", \"address\": \"新竹市東區金山街101號\", \"latitude\": 24.8001, \"longitude\": 121.0102, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「義式鄉村風的海灘式漢堡」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': '沒有選中'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「義式鄉村風的海灘式漢堡」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '地點應該是板橋大遠百對面,那家好像叫什麼懸的'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡板橋大遠百對面懸\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Classic Burger屋\", \"address\": \"高雄市鼓山區裕誠路350號\", \"latitude\": 22.6503, \"longitude\": 120.2911, \"fromMemory\": false}, {\"name\": \"大漢堡俱樂部\", \"address\": \"台中市北區雙十路二段170號\", \"latitude\": 24.1578, \"longitude\": 120.6842, \"fromMemory\": false}, {\"name\": \"Burger Bank\", \"address\": \"新竹市北區湳雅街55號\", \"latitude\": 24.8154, \"longitude\": 120.9766, \"fromMemory\": false}, {\"name\": \"熊堡堡\", \"address\": \"高雄市楠梓區德賢路89號\", \"latitude\": 22.7261, \"longitude\": 120.3319, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「義式鄉村風的海灘式漢堡」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'},\n",
" {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'},\n",
" {'from': 'human', 'value': '沒看到對的'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「義式鄉村風的海灘式漢堡」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '位在公館商圈,有西班牙燉飯'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡公館商圈西班牙燉飯\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Rolling Burger\", \"address\": \"嘉義市東區忠孝路88號\", \"latitude\": 23.4866, \"longitude\": 120.4511, \"fromMemory\": false}, {\"name\": \"Barrel & Beef\", \"address\": \"宜蘭市女中路三段122號\", \"latitude\": 24.7576, \"longitude\": 121.7523, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「義式鄉村風的海灘式漢堡」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': '我找不到我要的選項'},\n",
" {'from': 'gpt',\n",
" 'value': '請告訴我目的地附近的標誌性景點或描述「義式鄉村風的海灘式漢堡」的具體特徵,這樣我可以更精準地幫你搜尋並導航到正確的地址'},\n",
" {'from': 'human', 'value': '印象中是什麼B開頭的店名'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'search_and_show_place',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡B\"}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"search_and_show_place\", \"status\": \"success\", \"message\": \"地點查詢成功\", \"poi\": [{\"name\": \"Station 13漢堡屋\", \"address\": \"新北市土城區中央路三段260號\", \"latitude\": 24.979, \"longitude\": 121.4456, \"fromMemory\": false}, {\"name\": \"Burger Bank\", \"address\": \"新竹市北區湳雅街55號\", \"latitude\": 24.8154, \"longitude\": 120.9766, \"fromMemory\": false}, {\"name\": \"Big Jack漢堡屋\", \"address\": \"台北市中山區民生東路三段56號\", \"latitude\": 25.0563, \"longitude\": 121.5387, \"fromMemory\": false}, {\"name\": \"Happy Cow美式餐館\", \"address\": \"高雄市左營區自由三路88號\", \"latitude\": 22.6845, \"longitude\": 120.3077, \"fromMemory\": false}, {\"name\": \"美式胖堡\", \"address\": \"台南市中西區西門路二段210號\", \"latitude\": 22.9987, \"longitude\": 120.1976, \"fromMemory\": false}, {\"name\": \"大口咬咬漢堡店\", \"address\": \"台北市大安區安和路二段95號\", \"latitude\": 25.03, \"longitude\": 121.5509, \"fromMemory\": false}]}'},\n",
" {'from': 'gpt', 'value': '我找到以下與「義式鄉村風的海灘式漢堡」相關的地點,請問您要導航到哪一個地點?'},\n",
" {'from': 'human', 'value': '我記得去了'},\n",
" {'from': 'gpt', 'value': '已為您取消導航清單'},\n",
" {'from': 'human', 'value': '直奔第五個'},\n",
" {'from': 'function_call',\n",
" 'value': {'name': 'nav_start',\n",
" 'arguments': '{\"keyword\":\"義式鄉村風的海灘式漢堡\",\"name\":\"美式胖堡\",\"address\":\"台南市中西區西門路二段210號\",\"latitude\":22.9987,\"longitude\":120.1976,\"fromMemory\":false}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"nav_start\", \"status\": \"success\", \"message\": \"導航已啟動\"}'},\n",
" {'from': 'gpt', 'value': '導航已設定,往「美式胖堡」出發,祝您用餐愉快!'},\n",
" {'from': 'human', 'value': 'HUMAN_VOICE_CANNOT_RECOGNIZE'},\n",
" {'from': 'gpt', 'value': '抱歉我聽不清楚能麻煩您再說一次嗎'},\n",
" {'from': 'human', 'value': '請終止目前的路線安排'},\n",
" {'from': 'function_call', 'value': {'name': 'nav_stop', 'arguments': '{}'}},\n",
" {'from': 'observation',\n",
" 'value': '{\"name\": \"nav_stop\", \"status\": \"success\", \"message\": \"導航已取消\"}'},\n",
" {'from': 'gpt', 'value': '已為您取消導航,有需要再出發的話可以隨時告訴我喔!'}]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data[1]['conversations']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(pipeline.processor.apply_chat_template(\n",
"    history, add_generation_prompt=True, tokenize=False,\n",
"    return_dict=True, return_tensors=\"pt\", tools=json.loads(tools)\n",
"    ))"
]
},
|
| 611 |
+
{
|
| 612 |
+
"cell_type": "code",
|
| 613 |
+
"execution_count": null,
|
| 614 |
+
"metadata": {},
|
| 615 |
+
"outputs": [],
|
| 616 |
+
"source": [
|
| 617 |
+
"history = [{'from': 'function_call', 'value': {\"name\": json.loads(tmp['value'])[\"name\"], \"arguments\": str(json.loads(tmp['value'])[\"arguments\"])}} \n",
|
| 618 |
+
" if tmp['from']=='function_call' else tmp for tmp in data[0]['conversations']]"
|
| 619 |
+
]
|
| 620 |
+
},
|
| 621 |
+
{
|
| 622 |
+
"cell_type": "code",
|
| 623 |
+
"execution_count": null,
|
| 624 |
+
"metadata": {},
|
| 625 |
+
"outputs": [],
|
| 626 |
+
"source": [
|
| 627 |
+
"pipeline.generate(history[:4],tools=data[0]['tools'],audio_path='/home/jeff/jeff/codes/llm/InCar/data/test_data/audio_noisy/ctrl_toolcall_train 2-00015.wav')"
|
| 628 |
+
]
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"cell_type": "code",
|
| 632 |
+
"execution_count": null,
|
| 633 |
+
"metadata": {},
|
| 634 |
+
"outputs": [],
|
| 635 |
+
"source": []
|
| 636 |
+
},
|
| 637 |
+
{
|
| 638 |
+
"cell_type": "code",
|
| 639 |
+
"execution_count": null,
|
| 640 |
+
"metadata": {},
|
| 641 |
+
"outputs": [],
|
| 642 |
+
"source": [
|
| 643 |
+
"import json\n",
|
| 644 |
+
"data = json.load(open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json'))\n",
|
| 645 |
+
"data[0]['conversations'][-2] = {'from':'human',\n",
|
| 646 |
+
" 'value':[{\"type\": \"audio\", \"audio\": '/home/jeff/jeff/codes/llm/InCar/data/test_data/audio_noisy/nav_toolcall_train_0730-00005.wav'}]\n",
|
| 647 |
+
"}\n",
|
| 648 |
+
"{\n",
|
| 649 |
+
" \"role\": \"user\",\n",
|
| 650 |
+
" \"content\": [\n",
|
| 651 |
+
" # ans is what_is_shown_in_this_image\n",
|
| 652 |
+
" {\"type\": \"audio\", \"audio\": '/home/jeff/jeff/codes/llm/InCar/data/test_data/audio_noisy/nav_toolcall_train_0730-00003_purenoisy.wav'},\n",
|
| 653 |
+
" {\"type\": \"text\", \"text\": \"Transcribe this audio clip into text.\"}\n",
|
| 654 |
+
" ]\n",
|
| 655 |
+
" }\n",
|
| 656 |
+
"for conv in data[0]['conversations']:\n",
|
| 657 |
+
" if conv['from']=='function_call':\n",
|
| 658 |
+
" conv['value']=json.loads(conv['value'])\n",
|
| 659 |
+
" conv['value']['arguments'] = str(conv['value']['arguments'])\n",
|
| 660 |
+
"print(processor.apply_chat_template(data[0]['conversations'], add_generation_prompt=True, tokenize=False,\n",
|
| 661 |
+
" return_dict=True,tools=json.loads(data[0]['tools'])\n",
|
| 662 |
+
" ))"
|
| 663 |
+
]
|
| 664 |
+
},
|
| 665 |
+
{
|
| 666 |
+
"cell_type": "code",
|
| 667 |
+
"execution_count": null,
|
| 668 |
+
"metadata": {},
|
| 669 |
+
"outputs": [],
|
| 670 |
+
"source": [
|
| 671 |
+
"audio_path = '/home/jeff/jeff/codes/llm/InCar/data/test_data/audio_noisy/nav_toolcall_train_0730-00005.wav'\n",
|
| 672 |
+
"messages = [\n",
|
| 673 |
+
" {\n",
|
| 674 |
+
" \"from\": \"human\",\n",
|
| 675 |
+
" \"value\": [\n",
|
| 676 |
+
" # ans is what_is_shown_in_this_image\n",
|
| 677 |
+
" {\"type\": \"audio\", \"audio\": audio_path},\n",
|
| 678 |
+
" {\"type\": \"text\", \"text\": \"Transcribe this audio clip into text.\"}\n",
|
| 679 |
+
" ]\n",
|
| 680 |
+
" }\n",
|
| 681 |
+
"]\n",
|
| 682 |
+
"\n",
|
| 683 |
+
"inputs_text = processor.apply_chat_template(\n",
|
| 684 |
+
" data[0]['conversations'][:-1], add_generation_prompt=True, tokenize=False,\n",
|
| 685 |
+
" return_dict=True, return_tensors=\"pt\", tools={}\n",
|
| 686 |
+
")\n",
|
| 687 |
+
"inputs = processor(text=inputs_text, \n",
|
| 688 |
+
" audio=[torchaudio.load(audio_path)[0]], \n",
|
| 689 |
+
" add_special_tokens=False, \n",
|
| 690 |
+
" return_tensors='pt'\n",
|
| 691 |
+
" )\n",
|
| 692 |
+
"\n",
|
| 693 |
+
"with torch.inference_mode():\n",
|
| 694 |
+
" inputs = {k:inputs[k].to('cuda') for k in inputs}\n",
|
| 695 |
+
" generate_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)\n",
|
| 696 |
+
" generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]\n",
|
| 697 |
+
" response = processor.batch_decode(\n",
|
| 698 |
+
" generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
| 699 |
+
" )[0]\n",
|
| 700 |
+
"print(response)\n"
|
| 701 |
+
]
|
| 702 |
+
},
|
| 703 |
+
{
|
| 704 |
+
"cell_type": "code",
|
| 705 |
+
"execution_count": null,
|
| 706 |
+
"metadata": {},
|
| 707 |
+
"outputs": [],
|
| 708 |
+
"source": []
|
| 709 |
+
},
|
| 710 |
+
{
|
| 711 |
+
"cell_type": "code",
|
| 712 |
+
"execution_count": null,
|
| 713 |
+
"metadata": {},
|
| 714 |
+
"outputs": [],
|
| 715 |
+
"source": [
|
| 716 |
+
"inputs_text"
|
| 717 |
+
]
|
| 718 |
+
},
|
| 719 |
+
{
|
| 720 |
+
"cell_type": "code",
|
| 721 |
+
"execution_count": null,
|
| 722 |
+
"metadata": {},
|
| 723 |
+
"outputs": [],
|
| 724 |
+
"source": [
|
| 725 |
+
"inputs = processor.apply_chat_template(\n",
|
| 726 |
+
" messages, add_generation_prompt=True, tokenize=True,\n",
|
| 727 |
+
" return_dict=True, return_tensors=\"pt\", tools=json.loads(data[0]['tools'])\n",
|
| 728 |
+
")"
|
| 729 |
+
]
|
| 730 |
+
},
|
| 731 |
+
{
|
| 732 |
+
"cell_type": "code",
|
| 733 |
+
"execution_count": null,
|
| 734 |
+
"metadata": {},
|
| 735 |
+
"outputs": [],
|
| 736 |
+
"source": []
|
| 737 |
+
},
|
| 738 |
+
{
|
| 739 |
+
"cell_type": "code",
|
| 740 |
+
"execution_count": null,
|
| 741 |
+
"metadata": {},
|
| 742 |
+
"outputs": [],
|
| 743 |
+
"source": [
|
| 744 |
+
"from transformers import AutoProcessor, AutoModel,AutoModelForCausalLM\n",
|
| 745 |
+
"import torch\n",
|
| 746 |
+
"model_id = \"/home/jeff/codes/llm/InCar/Phi-4-multimodal-instruct\"\n",
|
| 747 |
+
"revision = \"main\"\n",
|
| 748 |
+
"model2 = AutoModelForCausalLM.from_pretrained(#AutoModel.from_pretrained(\n",
|
| 749 |
+
" model_id, device_map=\"cpu\", \n",
|
| 750 |
+
" revision = revision, trust_remote_code=True,\n",
|
| 751 |
+
" _attn_implementation='flash_attention_2'\n",
|
| 752 |
+
" # torch_dtype=torch.float16\n",
|
| 753 |
+
").eval()"
|
| 754 |
+
]
|
| 755 |
+
},
|
| 756 |
+
{
|
| 757 |
+
"cell_type": "code",
|
| 758 |
+
"execution_count": null,
|
| 759 |
+
"metadata": {},
|
| 760 |
+
"outputs": [],
|
| 761 |
+
"source": [
|
| 762 |
+
"messages = [\n",
|
| 763 |
+
" {\n",
|
| 764 |
+
" \"role\": \"user\",\n",
|
| 765 |
+
" \"content\": [\n",
|
| 766 |
+
" # ans is what_is_shown_in_this_image\n",
|
| 767 |
+
" {\"type\": \"audio\", \"audio\": \"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav\"},\n",
|
| 768 |
+
" {\"type\": \"text\", \"text\": \"Transcribe this audio clip into text.\"}\n",
|
| 769 |
+
" ]\n",
|
| 770 |
+
" }\n",
|
| 771 |
+
"]\n",
|
| 772 |
+
"\n",
|
| 773 |
+
"inputs = processor.apply_chat_template(\n",
|
| 774 |
+
" messages, add_generation_prompt=True, tokenize=True,\n",
|
| 775 |
+
" return_dict=True, return_tensors=\"pt\"\n",
|
| 776 |
+
")\n",
|
| 777 |
+
"\n",
|
| 778 |
+
"# with torch.inference_mode():\n",
|
| 779 |
+
"# inputs = {k:inputs[k].to('cuda') for k in inputs}\n",
|
| 780 |
+
"# generate_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)\n",
|
| 781 |
+
"# generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]\n",
|
| 782 |
+
"# response = processor.batch_decode(\n",
|
| 783 |
+
"# generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
| 784 |
+
"# )[0]\n",
|
| 785 |
+
"# print(response)\n"
|
| 786 |
+
]
|
| 787 |
+
},
|
| 788 |
+
{
|
| 789 |
+
"cell_type": "code",
|
| 790 |
+
"execution_count": null,
|
| 791 |
+
"metadata": {},
|
| 792 |
+
"outputs": [],
|
| 793 |
+
"source": [
|
| 794 |
+
"len(inputs['input_ids'][0]),len(inputs['input_audio_embeds'][0])"
|
| 795 |
+
]
|
| 796 |
+
},
|
| 797 |
+
{
|
| 798 |
+
"cell_type": "code",
|
| 799 |
+
"execution_count": null,
|
| 800 |
+
"metadata": {},
|
| 801 |
+
"outputs": [],
|
| 802 |
+
"source": [
|
| 803 |
+
"model2.model.embed_tokens_extend.audio_embed.encoder(inputs['input_audio_embeds'],None)"
|
| 804 |
+
]
|
| 805 |
+
},
|
| 806 |
+
{
|
| 807 |
+
"cell_type": "code",
|
| 808 |
+
"execution_count": null,
|
| 809 |
+
"metadata": {},
|
| 810 |
+
"outputs": [],
|
| 811 |
+
"source": []
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
"cell_type": "code",
|
| 815 |
+
"execution_count": null,
|
| 816 |
+
"metadata": {},
|
| 817 |
+
"outputs": [],
|
| 818 |
+
"source": [
|
| 819 |
+
"import json\n",
|
| 820 |
+
"with open('/mnt/data-2t/jeff/codes/llm/InCar/gemma-3-4b-it-omni/output_org_nav_2025-07-24 08:34.json') as f:\n",
|
| 821 |
+
" nav_res = json.load(f)\n",
|
| 822 |
+
"remove_sign = lambda x:x.replace('User transcribe is','').replace('GPT output is','').replace('\\n','').\\\n",
|
| 823 |
+
" replace(' ','').replace('?','').replace('?','').replace('!','').replace('。','').\\\n",
|
| 824 |
+
" replace('.','').replace('!','')\n",
|
| 825 |
+
"total_func_call=0\n",
|
| 826 |
+
"func_error=0\n",
|
| 827 |
+
"errors=[]\n",
|
| 828 |
+
"for res in nav_res:\n",
|
| 829 |
+
" if res['cer']!=0:\n",
|
| 830 |
+
" errors.append(res)\n",
|
| 831 |
+
" if 'Action:' in res['label']:\n",
|
| 832 |
+
" func_error+=remove_sign(res['label'])!=remove_sign(res['output'])\n",
|
| 833 |
+
" total_func_call+=1\n",
|
| 834 |
+
"avg_cer = sum(a['cer'] for a in nav_res)/len(nav_res)\n",
|
| 835 |
+
"total_error = sum(a['cer']!=0 for a in nav_res)\n",
|
| 836 |
+
"print('total',len(nav_res))\n",
|
| 837 |
+
"print('total_error & rate',total_error,total_error/len(nav_res))\n",
|
| 838 |
+
"print('avg_cer',avg_cer)\n",
|
| 839 |
+
"print('total_func_call',total_func_call)\n",
|
| 840 |
+
"print('func_error & rate',func_error,',',func_error/total_func_call)"
|
| 841 |
+
]
|
| 842 |
+
},
|
| 843 |
+
{
|
| 844 |
+
"cell_type": "code",
|
| 845 |
+
"execution_count": null,
|
| 846 |
+
"metadata": {},
|
| 847 |
+
"outputs": [],
|
| 848 |
+
"source": [
|
| 849 |
+
"488/11046"
|
| 850 |
+
]
|
| 851 |
+
},
|
| 852 |
+
{
|
| 853 |
+
"cell_type": "code",
|
| 854 |
+
"execution_count": null,
|
| 855 |
+
"metadata": {},
|
| 856 |
+
"outputs": [],
|
| 857 |
+
"source": [
|
| 858 |
+
"aa = []\n",
|
| 859 |
+
"for e in errors:\n",
|
| 860 |
+
" if '}' in e['output'] and remove_sign(e['output'][:e['output'].index('}')+1])==remove_sign(e['label']):continue\n",
|
| 861 |
+
" if remove_sign(e['label']) in remove_sign(e['output']):continue\n",
|
| 862 |
+
" aa.append(e)\n",
|
| 863 |
+
"len(aa)"
|
| 864 |
+
]
|
| 865 |
+
},
|
| 866 |
+
{
|
| 867 |
+
"cell_type": "code",
|
| 868 |
+
"execution_count": null,
|
| 869 |
+
"metadata": {},
|
| 870 |
+
"outputs": [],
|
| 871 |
+
"source": [
|
| 872 |
+
"5/4830"
|
| 873 |
+
]
|
| 874 |
+
},
|
| 875 |
+
{
|
| 876 |
+
"cell_type": "code",
|
| 877 |
+
"execution_count": null,
|
| 878 |
+
"metadata": {},
|
| 879 |
+
"outputs": [],
|
| 880 |
+
"source": [
|
| 881 |
+
"len(aa)"
|
| 882 |
+
]
|
| 883 |
+
},
|
| 884 |
+
{
|
| 885 |
+
"cell_type": "code",
|
| 886 |
+
"execution_count": null,
|
| 887 |
+
"metadata": {},
|
| 888 |
+
"outputs": [],
|
| 889 |
+
"source": [
|
| 890 |
+
"aa"
|
| 891 |
+
]
|
| 892 |
+
},
|
| 893 |
+
{
|
| 894 |
+
"cell_type": "code",
|
| 895 |
+
"execution_count": null,
|
| 896 |
+
"metadata": {},
|
| 897 |
+
"outputs": [],
|
| 898 |
+
"source": []
|
| 899 |
+
},
|
| 900 |
+
{
|
| 901 |
+
"cell_type": "code",
|
| 902 |
+
"execution_count": null,
|
| 903 |
+
"metadata": {},
|
| 904 |
+
"outputs": [],
|
| 905 |
+
"source": [
|
| 906 |
+
"import json\n",
|
| 907 |
+
"with open('/mnt/data-2t/jeff/codes/llm/InCar/gemma-3-4b-it-omni/output_org_ctrl_2025-07-24 05:43.json') as f:\n",
|
| 908 |
+
" nav_res = json.load(f)\n",
|
| 909 |
+
"remove_sign = lambda x:x.replace('User transcribe is','').replace('GPT output is','').replace('\\n','').\\\n",
|
| 910 |
+
" replace(' ','').replace('?','').replace('?','').replace('!','').replace('。','').\\\n",
|
| 911 |
+
" replace('.','').replace('!','')\n",
|
| 912 |
+
"total_func_call=0\n",
|
| 913 |
+
"func_error=0\n",
|
| 914 |
+
"errors=[]\n",
|
| 915 |
+
"for res in nav_res:\n",
|
| 916 |
+
" if res['cer']!=0:\n",
|
| 917 |
+
" errors.append(res)\n",
|
| 918 |
+
" if 'Action:' in res['label']:\n",
|
| 919 |
+
" func_error+=remove_sign(res['label'])!=remove_sign(res['output'])\n",
|
| 920 |
+
" total_func_call+=1\n",
|
| 921 |
+
"avg_cer = sum(a['cer'] for a in nav_res)/len(nav_res)\n",
|
| 922 |
+
"total_error = sum(a['cer']!=0 for a in nav_res)\n",
|
| 923 |
+
"print('total',len(nav_res))\n",
|
| 924 |
+
"print('total_error & rate',total_error,total_error/len(nav_res))\n",
|
| 925 |
+
"print('avg_cer',avg_cer)\n",
|
| 926 |
+
"print('total_func_call',total_func_call)\n",
|
| 927 |
+
"print('func_error & rate',func_error,',',func_error/total_func_call)"
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"cell_type": "code",
|
| 932 |
+
"execution_count": null,
|
| 933 |
+
"metadata": {},
|
| 934 |
+
"outputs": [],
|
| 935 |
+
"source": [
|
| 936 |
+
"errors"
|
| 937 |
+
]
|
| 938 |
+
},
|
| 939 |
+
{
|
| 940 |
+
"cell_type": "code",
|
| 941 |
+
"execution_count": null,
|
| 942 |
+
"metadata": {},
|
| 943 |
+
"outputs": [],
|
| 944 |
+
"source": [
|
| 945 |
+
"print(errors[0]['input'])"
|
| 946 |
+
]
|
| 947 |
+
},
|
| 948 |
+
{
|
| 949 |
+
"cell_type": "code",
|
| 950 |
+
"execution_count": null,
|
| 951 |
+
"metadata": {},
|
| 952 |
+
"outputs": [],
|
| 953 |
+
"source": []
|
| 954 |
+
}
|
| 955 |
+
],
|
| 956 |
+
"metadata": {
|
| 957 |
+
"kernelspec": {
|
| 958 |
+
"display_name": "py10",
|
| 959 |
+
"language": "python",
|
| 960 |
+
"name": "python3"
|
| 961 |
+
},
|
| 962 |
+
"language_info": {
|
| 963 |
+
"codemirror_mode": {
|
| 964 |
+
"name": "ipython",
|
| 965 |
+
"version": 3
|
| 966 |
+
},
|
| 967 |
+
"file_extension": ".py",
|
| 968 |
+
"mimetype": "text/x-python",
|
| 969 |
+
"name": "python",
|
| 970 |
+
"nbconvert_exporter": "python",
|
| 971 |
+
"pygments_lexer": "ipython3",
|
| 972 |
+
"version": "3.10.18"
|
| 973 |
+
}
|
| 974 |
+
},
|
| 975 |
+
"nbformat": 4,
|
| 976 |
+
"nbformat_minor": 2
|
| 977 |
+
}
|
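A note on the notebook's metrics: the evaluation cells above consume a precomputed per-sample 'cer' field from the output JSON. Below is a minimal sketch of how such a value can be produced with the jiwer package; this is an illustration only, and the repo's own eval scripts may compute or normalize it differently.

import jiwer

def char_error_rate(label: str, output: str) -> float:
    # jiwer.cer computes the edit distance at the character level,
    # which suits Chinese transcripts better than word error rate.
    return jiwer.cer(label, output)

# One substituted character out of seven -> CER ~= 0.14
print(char_error_rate("導航到美式胖堡", "導航到美式胖飽"))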
test.jinja
ADDED
@@ -0,0 +1,51 @@
{% set tools_json = tools %}
{{ bos_token }}You have access to the following tools:
{% for tool in tools_json %}
> Tool Name: {{ tool.name }}
Tool Description: {{ tool.description }}
Tool Args:
{% if tool.parameters.properties %}
{% for prop_name, prop_data in tool.parameters.properties.items() %}
- {{ prop_name }} ({{ prop_data.type }}, {% if prop_name in tool.parameters.required %}required{% else %}optional{% endif %}): {{ prop_data.description }}
{% endfor %}
{% else %}
{% endif %}

{% endfor %}, where each item should be object
Use the following format if using a tool:
```
Action: tool name (one of [{{ tools_json | map(attribute='name') | join(', ') }}])
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)
```

{{'\n'}}
{%- for message in messages -%}
    {%- if message['from'] == "human" -%}
        {%- if message['value'] is string -%}
            {{'<start_of_turn>user\n' + message['value'] + '<end_of_turn>\n'}}
        {%- elif message['value'] is iterable -%}
            {{'<start_of_turn>user\n'}}
            {%- for item in message['value'] -%}
                {%- if item['type'] == 'image' -%}
                    {{ '<start_of_image>' }}
                {%- elif item['type'] == 'audio' -%}
                    {{ '<start_of_audio>' }}
                {%- elif item['type'] == 'text' -%}
                    {{ item['text'] | trim }}
                {%- endif -%}
                {{'<end_of_turn>\n'}}
            {%- endfor -%}
        {%- endif -%}
    {%- elif message['from'] == "gpt" -%}
        {{'<start_of_turn>model\n' + message['value'] + '<end_of_turn>\n'}}
    {%- elif message['from'] == "function_call" -%}
        {%- set func_call = message['value'] -%}
        {{'<start_of_turn>model\n' + 'Action:' + func_call.name + '\n' + 'Action Input:' + func_call.arguments + '<end_of_turn>\n'}}
    {%- elif message['from'] == "observation" -%}
        {{'<start_of_turn>tool\n' + message['value'] + '<end_of_turn>\n'}}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{'<start_of_turn>model
'}}
{%- endif -%}
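Not part of the upload: a small sketch showing how test.jinja can be rendered stand-alone with jinja2 to inspect the prompt it produces. The tool definition and messages below are made-up examples; in the repo the template is normally applied through processor.apply_chat_template.

import json
from jinja2 import Template

template = Template(open("test.jinja").read())

# Hypothetical tool schema matching the shape the template expects.
tools = [{
    "name": "nav_start",
    "description": "Start navigation to a destination.",
    "parameters": {
        "properties": {
            "name": {"type": "string", "description": "POI name"},
            "address": {"type": "string", "description": "POI address"},
        },
        "required": ["name"],
    },
}]

# Hypothetical conversation in the dataset's from/value format.
messages = [
    {"from": "human", "value": "Navigate to the nearest burger place."},
    {"from": "function_call",
     "value": {"name": "nav_start", "arguments": json.dumps({"name": "Burger Bank"})}},
    {"from": "observation", "value": '{"name": "nav_start", "status": "success"}'},
]

print(template.render(tools=tools, messages=messages,
                      bos_token="<bos>", add_generation_prompt=True))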
tmp.sh
ADDED
@@ -0,0 +1,3 @@
python eval_multiturn_textonly.py
cd /mnt/data-2t/jeff/codes/LLaMA-Factory
bash train_nav_ctrl.sh
training.py
CHANGED
@@ -34,7 +34,7 @@ _IGNORE_INDEX = -100
 class BaseAudioDataset(Dataset):
     def __init__(self, processor, split, sampling_rate=16000, debug=False):
         self.processor = processor
-        self.training = "train" in split
+        self.training = "train" in split
         self.debug = debug
         self.sampling_rate = sampling_rate
         self.name = ""
@@ -183,7 +183,7 @@ class LibriSpeechDataset(BaseAudioDataset):
 
         return self.prepare_model_inputs(
             data["audio"]["array"],
-
+            random.choice(INSTRUCTION["asr"]),
             answer_text
         )
 
@@ -198,11 +198,7 @@ class CommonVoiceDataset(BaseAudioDataset):
         self.lang=source_lang
 
         # load dataset
-
-            data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
-        else:
-            data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
-        self.data = load_dataset(data_path,
+        self.data = load_dataset("/mnt/jeff/InCar/data/common_voice_16_1",
             source_lang,
             split=split,
             trust_remote_code=True,
@@ -223,63 +219,10 @@ class CommonVoiceDataset(BaseAudioDataset):
             batch["sentence"] = transcription
 
             return batch
-
-
-        import opencc
-        converter = opencc.OpenCC('s2tw.json')
-        def To_zhTW(batch):
-
-            transcription = converter.convert(batch["sentence"])
-            batch["sentence"] = transcription
-
-            return batch
-        self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
-        if source_lang=='zh-CN':
-            self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
-
+        self.data.map(prepare_dataset, desc="preprocess dataset")
 
         # (Optional) Audio length Filtering
         self.data = self.filter_by_audio_length(self.data, "audio")
-
-        if source_lang == "zh-TW" and split=='train':
-            import torchaudio
-            from torchaudio import transforms
-            import copy
-            import pickle
-            import os
-            def subsample(batch):
-                batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
-                batch['audio']['sampling_rate']=16000
-                return batch
-            def TW_data_augment_fast(batch):
-                speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
-                new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
-                batch['audio']['array'] = new_array_fast
-                return batch
-            def TW_data_augment_slow(batch):
-                speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
-                new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
-                batch['audio']['array'] = new_array_slow
-                return batch
-            # data = self.data.map(subsample, num_proc=1, desc="subsample")
-            fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
-            if not os.path.exists(fast_path):
-                data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
-                with open(fast_path,'wb') as f:
-                    pickle.dump(data_fast,f)
-            else:
-                with open(fast_path,'rb') as f:
-                    data_fast=pickle.load(f)
-
-            slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
-            if not os.path.exists(slow_path):
-                data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
-                with open(slow_path,'wb') as f:
-                    pickle.dump(data_slow,f)
-            else:
-                with open(slow_path,'rb') as f:
-                    data_slow=pickle.load(f)
-            self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
 
         # Instruction Setting
         self.instruction = random.choice(INSTRUCTION["asr"])
@@ -293,7 +236,7 @@ class CommonVoiceDataset(BaseAudioDataset):
         answer_text = data["sentence"]
         return self.prepare_model_inputs(
             data["audio"]["array"],
-
+            random.choice(INSTRUCTION["asr"]),
             answer_text
         )
 
@@ -325,15 +268,6 @@ class FleursDataset(BaseAudioDataset):
             trust_remote_code=True,
             cache_dir=Path("/mnt/jeff/InCar/data")
         )
-        import opencc
-        converter = opencc.OpenCC('s2tw.json')
-        def prepare_dataset(batch):
-            transcription = converter.convert(batch["transcription"])
-            batch["transcription"] = transcription
-
-            return batch
-        if (source_lang=="cmn_hans_cn"):
-            self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
 
         # (Optional) Audio length Filtering
         self.data = self.filter_by_audio_length(self.data, "audio")
@@ -353,8 +287,7 @@ class FleursDataset(BaseAudioDataset):
             trust_remote_code=True,
             cache_dir=Path("/mnt/jeff/InCar/data")
         )
-
-        target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
+
         source_dict = {item['id']: item for item in self.data}
         target_dict = {item['id']: item for item in target_data}
 
@@ -368,11 +301,11 @@ class FleursDataset(BaseAudioDataset):
 
             # Instruction Setting - use target language name
             self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
-            self.instruction =
+            self.instruction = INSTRUCTION["ast"]
         else:
             # ASR mode
             self.lang = source_lang
-            self.instruction =
+            self.instruction = INSTRUCTION["asr"]
 
         if self.debug:
             print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
@@ -395,7 +328,7 @@ class FleursDataset(BaseAudioDataset):
 
         return self.prepare_model_inputs(
             audio_array,
-            self.instruction.format(self.target_lang_name),
+            random.choice(self.instruction).format(self.target_lang_name),
             answer_text
         )
 
@@ -552,9 +485,9 @@ def create_model(model_name_or_path, revision="main", use_flash_attention = Fals
     model = AutoModel.from_pretrained(
         model_name_or_path,
         revision=revision,
-        torch_dtype=torch.
+        # torch_dtype=torch.float16,
         device_map="auto",
-        attn_implementation="
+        attn_implementation="eager",
         trust_remote_code=True,
     )
 
@@ -564,13 +497,22 @@ def create_model(model_name_or_path, revision="main", use_flash_attention = Fals
     # Freeze all parameters
     for param in model.parameters():
         param.requires_grad = False
+    from peft import LoraConfig, get_peft_model
+    lora_config = LoraConfig(
+        r=320,
+        lora_alpha=32,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+    )
+    model.language_model.model = get_peft_model(model.language_model.model, lora_config)
 
-    model.
-    model.to(torch.bfloat16)
+    # model.to(torch.float16)
 
     # (Optional) unfreeze audio_tower parameters
-
-
+    for param in model.audio_tower.parameters():
+        param.requires_grad = True
 
     # Only unfreeze audio_projector parameters
     for param in model.audio_projector.parameters():
@@ -654,12 +596,12 @@ _IGNORE_INDEX = -100
 model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
 use_flash_attention = True
 
-output_dir = '../
-batch_size =
+output_dir = '../gemma_tmp'
+batch_size = 16
 batch_size_per_gpu = 16
 learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
 wd = 0.01
-num_train_epochs =
+num_train_epochs = 5
 
 revision = "main" #"v1.0"
 
@@ -677,22 +619,23 @@ model = create_model(
 
 train_datasets = []
 
-#
-
-    processor=processor,
-    source_lang="zh-TW",
-    split="other[:70%]"
-)
-train_datasets.append(commonvoice_speech_tw2)
-
-commonvoice_speech_cn = CommonVoiceDataset(
+# Libri Speech Clean ASR mode (English -> English text)
+libri_speech_clean = LibriSpeechDataset(
     processor=processor,
-
-    split="train
+    subset="clean",
+    split="train.360"
 )
-train_datasets.append(
+train_datasets.append(libri_speech_clean)
 
+# # Libri Speech Other ASR mode (English -> English text)
+# libri_speech_other = LibriSpeechDataset(
+#     processor=processor,
+#     subset="other",
+#     split="train.500"
+# )
+# train_datasets.append(libri_speech_other)
 
+# common voice asr
 commonvoice_speech_tw = CommonVoiceDataset(
     processor=processor,
     source_lang="zh-TW",
@@ -701,17 +644,6 @@ commonvoice_speech_tw = CommonVoiceDataset(
 train_datasets.append(commonvoice_speech_tw)
 
 
-
-
-# Libri Speech Clean ASR mode (English -> English text)
-libri_speech_clean = LibriSpeechDataset(
-    processor=processor,
-    subset="clean",
-    split="train.360[:50%]"
-)
-train_datasets.append(libri_speech_clean)
-
-
 # Fleurs ASR mode (English -> English text)
 en_asr_fleurs = FleursDataset(
     processor=processor,
@@ -722,14 +654,14 @@ en_asr_fleurs = FleursDataset(
 train_datasets.append(en_asr_fleurs)
 
 
-
-
-
-
-
-
-
+en_ch_ast_fleurs = FleursDataset(
+    processor=processor,
+    split="train",
+    source_lang="en_us",
+    target_lang="cmn_hans_cn",
+    mode="ast"
+)
+train_datasets.append(en_ch_ast_fleurs)
 
 
@@ -742,14 +674,14 @@ ch_asr_fleurs = FleursDataset(
 train_datasets.append(ch_asr_fleurs)
 
 
-
-
-
-
-
-
-
+ch_en_ast_fleurs = FleursDataset(
+    processor=processor,
+    split="train",
+    source_lang="cmn_hans_cn",
+    target_lang="en_us",
+    mode="ast"
+)
+train_datasets.append(ch_en_ast_fleurs)
 
 print("Count Num of Datasets", len(train_datasets))
 print([len(dataset) for dataset in train_datasets])
@@ -830,19 +762,19 @@ training_args = TrainingArguments(
     max_grad_norm=1.0,
     lr_scheduler_type='cosine',
     warmup_steps=50,
-    logging_steps=
+    logging_steps=5,
     output_dir=output_dir,
     save_total_limit=10,
     save_only_model=True,
-    bf16=
+    bf16=False,
     fp16=False,
     remove_unused_columns=False,
     report_to='none',
-    deepspeed=dp_config
+    deepspeed=dp_config,
     disable_tqdm=False,
     dataloader_num_workers=4,
     save_strategy='steps',
-    save_steps=
+    save_steps=10000,
     ddp_find_unused_parameters=True,
 
 )
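A quick sanity check (not in the commit) for the freeze/LoRA logic in create_model above: after wrapping the language model with LoRA and unfreezing the audio tower and projector, it is worth printing how many parameters will actually receive gradients. A minimal sketch, assuming a torch.nn.Module as returned by create_model:

def count_trainable(model):
    # Sums parameter counts over everything left trainable by the
    # freeze -> LoRA-wrap -> selective-unfreeze sequence above.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

# count_trainable(model)  # call right after create_model(...)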