FaYo
committed on
Commit
·
0d29d74
1
Parent(s):
64a76db
model
Browse files
dataset/gen_dataset/gen_dataset.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from http import HTTPStatus
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import dashscope
|
| 10 |
+
import requests
|
| 11 |
+
import yaml
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_api_key(api_type, api_yaml_path):
    """Load the API key for the selected backend from a YAML config file.

    Args:
        api_type (str): Backend name, either "qwen" or "ernie".
        api_yaml_path (str): Path to the YAML file that stores the keys.

    Returns:
        The value stored under "ali_qwen_api_key" (qwen) or
        "baidu_ernie_api_key" (ernie).

    Raises:
        ValueError: If api_type is neither "qwen" nor "ernie".
    """
    # Read the YAML config first; the type check happens afterwards.
    with open(api_yaml_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    if api_type not in ("qwen", "ernie"):
        raise ValueError("api_type must be qwen or ernie")

    if api_type == "ernie":
        return config["baidu_ernie_api_key"]

    # qwen: the key is additionally installed on the dashscope module.
    key = config["ali_qwen_api_key"]
    dashscope.api_key = key
    return key
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def call_qwen_message(content_str, model_type=dashscope.Generation.Models.qwen_turbo):
    """Send a prompt to a Qwen model through dashscope and return the reply text.

    The call is retried exactly once if the first attempt raises; an
    exception from the second attempt propagates to the caller.

    Args:
        content_str (str): Prompt text.
        model_type: dashscope model identifier (defaults to qwen_turbo).

    Returns:
        str: The generated text, or the literal string "Error" when the
        response status is not HTTP 200.
    """
    try:
        reply = dashscope.Generation.call(model_type, prompt=content_str)
    except Exception as err:
        print(f"Maybe connect error , try again : {err}")
        reply = dashscope.Generation.call(model_type, prompt=content_str)

    if reply.status_code != HTTPStatus.OK:
        # Report the failure details and hand back the error marker.
        print(
            "Request id: %s, Status code: %s, error code: %s, error message: %s"
            % (reply.request_id, reply.status_code, reply.code, reply.message)
        )
        return "Error"

    print("Used token: ", reply.usage)
    return reply.output.text
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def call_ernie_message(content_str, access_token):
    """Send a single-turn prompt to the Baidu ERNIE chat endpoint.

    Args:
        content_str (str): Prompt text, sent as one "user" message.
        access_token (str): Access token appended to the request URL.

    Returns:
        str: The "result" field of the JSON reply, or the literal string
        "Error" when the HTTP status is not 200 or the reply has no
        "result" field.

    Raises:
        requests.exceptions.RequestException: On connection failures or
            timeouts.
        ValueError: If a 200 response body is not valid JSON.
    """
    url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token={access_token}"

    payload = json.dumps(
        {
            "messages": [
                {"role": "user", "content": content_str},
            ],
            "disable_search": False,
            "enable_citation": False,
        }
    )
    headers = {"Content-Type": "application/json"}

    # (connect, read) timeouts: without them a stalled connection blocks the
    # whole generation run forever. The read timeout is generous because
    # text generation can be slow.
    response = requests.request("POST", url, headers=headers, data=payload, timeout=(10, 300))

    if response.status_code != HTTPStatus.OK:
        print(f"ERNIE request failed, status code: {response.status_code}")
        return "Error"

    # Parse the response body.
    response_json = response.json()

    # NOTE(review): the service appears to be able to answer 200 with an
    # error payload instead of a result - confirm against the API docs.
    if not isinstance(response_json, dict) or "result" not in response_json:
        print(f"ERNIE response has no result: {response_json}")
        return "Error"

    print("Used token: ", response_json.get("usage"))
    return response_json["result"]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def format_json_from_response(func, content_str, func_args, model_name):
    """Call a model wrapper and decode its textual reply as JSON.

    Args:
        func: Callable invoked as ``func(content_str, func_args)``; must
            return the reply as a string.
        content_str (str): Prompt forwarded to ``func``.
        func_args: Second positional argument forwarded to ``func``.
        model_name (str): When equal to "qwen", double quotes inside the
            first "output" value are rewritten to single quotes before
            decoding.

    Returns:
        tuple: ``(decoded_json, cleaned_text)`` where ``cleaned_text`` is the
        string that was actually passed to the JSON decoder.

    Raises:
        IndexError: If the reply contains "```json" but no closing fence.
        json.JSONDecodeError: If the cleaned text is not valid JSON.
    """
    raw = func(content_str, func_args)

    # Keep only what sits between the ```json fence and the last ``` fence.
    if "```json" in raw:
        raw = re.findall(r"```json(.*)```", raw, flags=re.DOTALL)[0]

    # Replace characters that would make JSON decoding fail.
    cleaned = raw.replace("\\", "\\\\")
    cleaned = cleaned.replace("\n\n", "\n")
    cleaned = cleaned.replace("”", '"')
    cleaned = cleaned.replace("“", '"')

    if model_name == "qwen":
        # qwen may emit bare double quotes inside the text; turn them into
        # single quotes within the first "output" value.
        marker = '"output": "'
        marker_pos = cleaned.find(marker)
        brace_pos = cleaned.find("}", marker_pos + 1) if marker_pos != -1 else -1
        if brace_pos != -1:
            chars = list(cleaned)
            first = marker_pos + len(marker)
            # The scan stops 10 characters before the first "}" after the marker.
            span = len(chars[first : brace_pos - 10])
            for pos in range(first, first + span):
                if chars[pos] == '"':
                    chars[pos] = "'"
            cleaned = "".join(chars)

    # strict=False tolerates control characters inside JSON strings.
    return json.loads(cleaned, strict=False), cleaned
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def process_request(func, content_str, func_args, model_name):
    """Request JSON from a model, retrying once before giving up.

    Args:
        func: Model wrapper forwarded to ``format_json_from_response``.
        content_str (str): Prompt text.
        func_args: Extra argument for ``func``:
            qwen: model_type
            ernie: api_key
        model_name (str): Backend name; also used in the error log file name
            ``error-{model_name}.log``.

    Returns:
        The decoded JSON on success, or ``{"Error": "Error"}`` when both
        attempts fail.
    """
    request_args = (func, content_str, func_args, model_name)
    decode_error = json.decoder.JSONDecodeError

    try:
        parsed, _raw = format_json_from_response(*request_args)
        return parsed
    except Exception as first_error:
        try:
            # Second and last attempt.
            print(f"\n Got error, try again <== {first_error} \n")
            if isinstance(first_error, decode_error):
                print(f"JSONDecodeError doc 1: {str(first_error.doc)} \n")
            parsed, _raw = format_json_from_response(*request_args)
            return parsed
        except Exception as second_error:
            is_decode_failure = isinstance(second_error, decode_error)

            print(f"\n Got error <== {second_error} \n")
            if is_decode_failure:
                print(f"JSONDecodeError doc 2: {str(second_error.doc)} \n")

            # Append the failure to the per-model error log.
            with open(f"error-{model_name}.log", "a+", encoding="utf-8") as log_file:
                if is_decode_failure:
                    log_file.write(f"JSONDecodeError doc: {str(second_error.doc)} \n")
                log_file.write(str(second_error))
                log_file.flush()

            return {"Error": "Error"}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def gen_product_highlights(dastset_yaml_path, api_yaml_path):
    """Ask Qwen for highlights of every product class and store them in the dataset YAML.

    For each entry under ``product_list`` the model is asked to return a
    Python dict literal; the parsed dict replaces the entry, and the YAML
    file is rewritten in place once all entries have been processed.

    Args:
        dastset_yaml_path (str): Path to the dataset YAML file (read, then
            overwritten).
        api_yaml_path (str): Path to the API-key YAML file.

    Raises:
        ValueError, SyntaxError: If a model reply cannot be parsed as a
            Python literal (this includes the "Error" reply of a failed call).
    """
    # Local import: only this optional helper needs it.
    import ast

    # Load the dataset YAML.
    with open(dastset_yaml_path, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    set_api_key("qwen", api_yaml_path)

    system_str = "现在你精通医院里的各种事物,你帮我举例每个科室中五个细分专业治疗方法中每个细分治疗方法的六个优势或者特点,然后用python-dic的形式输出:{类名:[特点1,特点2,...]},去掉1,2的字样,除python字典外的其他都不要输出,不要有任何的警告信息"

    for products in dataset_yaml["product_list"].values():
        for product_class, product in products.items():
            product_str = str(product).replace("'", "")
            print(f"Process: {product_str}")

            # call_qwen_message only accepts a prompt and a model type, so the
            # instruction text is prepended to the prompt instead of being
            # passed as an unsupported keyword argument.
            product_highlights = call_qwen_message(
                content_str=f"{system_str}\n{product_str}",
                model_type=dashscope.Generation.Models.qwen_turbo,
            )

            # Prefer the fenced ```python block; fall back to the whole reply
            # when the model answered with a bare dict.
            fenced = re.findall(r"```python(.*)```", product_highlights, flags=re.DOTALL)
            code_block = fenced[0] if fenced else product_highlights

            # Drop a leading "name = " assignment so only the literal remains.
            if " = " in code_block[:20]:
                code_block = code_block.split(" = ", 1)[1]

            # literal_eval only accepts Python literals, so model output can
            # never execute arbitrary code here.
            products[product_class] = ast.literal_eval(code_block.strip())

    # Write the updated YAML back.
    with open(f"{dastset_yaml_path}", "w", encoding="utf-8") as f:
        yaml.dump(dataset_yaml, f, allow_unicode=True)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def gen_dataset(dastset_yaml_path: str, api_yaml_path: str, save_json_root: Path, model_name: str, specific_name=""):
    """Generate conversation training data for every role and product.

    For each role in the dataset YAML, one JSON file
    ``{model_name}_{role_type}_train.json`` is written under
    ``save_json_root``. Products already present in an existing file are
    skipped, so an interrupted run can be resumed.

    Args:
        dastset_yaml_path (str): Path to the dataset YAML configuration.
        api_yaml_path (str): Path to the API-key YAML file.
        save_json_root (Path): Output directory (created if missing).
        model_name (str): Backend to use, "qwen" or "ernie".
        specific_name (str): When non-empty, only this role is generated.

    Raises:
        AssertionError: If specific_name is not a role in the dataset YAML.
        ValueError: If model_name is not supported.
    """
    # Make sure the output directory exists.
    save_json_root.mkdir(parents=True, exist_ok=True)

    # Load the dataset YAML.
    with open(dastset_yaml_path, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    if specific_name != "":
        assert (
            specific_name in dataset_yaml["role_type"]
        ), f"{specific_name} not in dataset_yaml['role_type'] ({dataset_yaml['role_type']}), pls check dataset yaml!"

    # Configure the API key.
    api_key = set_api_key(model_name, api_yaml_path)

    data_gen_setting = dataset_yaml["data_generation_setting"]
    gen_num = data_gen_setting["each_product_gen"]
    each_pick_hightlight = data_gen_setting["each_pick_hightlight"]
    each_pick_question = data_gen_setting["each_pick_question"]

    # qwen: model used for each of the gen_num requests (currently always qwen_max).
    qwen_model_type = [dashscope.Generation.Models.qwen_max] * gen_num

    conversation_setting = dataset_yaml["conversation_setting"]

    for role_type, role_character in dataset_yaml["role_type"].items():

        if specific_name != "" and role_type != specific_name:
            # Only the requested role is generated.
            print(f"specific_name = {specific_name}, skipping for {role_type}")
            continue

        gen_json = dict()

        save_json_path = save_json_root.joinpath(f"{model_name}_{role_type}_train.json")
        bk_json_path = save_json_root.joinpath(f"{model_name}_{role_type}_train.json.bk")

        # Load results from a previous run, if any.
        if save_json_path.exists():
            with open(save_json_path, "r", encoding="utf-8") as f:
                gen_json = json.load(f)

        # Loading succeeded, so a leftover backup is no longer needed.
        if bk_json_path.exists():
            bk_json_path.unlink()

        # Flatten all product names so the progress bar knows the total.
        list_product = [
            product_name
            for products in dataset_yaml["product_list"].values()
            for product_name_list in products.values()
            for product_name in product_name_list
        ]

        # Role personality, shared by every prompt of this role.
        character = "、".join(role_character)
        system_str = conversation_setting["system"].replace("{role_type}", role_type).replace("{character}", character)

        pbar = tqdm(total=len(list_product))

        # Walk over every product.
        for products in dataset_yaml["product_list"].values():
            for product_name_list in products.values():
                for product, hightlights in product_name_list.items():
                    pbar.set_description(product)

                    if product in gen_json:
                        # Already generated in a previous run.
                        pbar.update(1)
                        continue

                    gen_json.update({product: []})

                    for idx in range(gen_num):

                        # Randomly pick product highlights and question angles.
                        # random.shuffle() works in place and returns None, so
                        # a sampled copy is used instead; this also leaves the
                        # loaded configuration untouched.
                        hightlight_str = "、".join(_pick_random(hightlights, each_pick_hightlight))
                        customer_question_str = "、".join(
                            _pick_random(dataset_yaml["customer_question_type"], each_pick_question)
                        )

                        # Product information.
                        product_info_str = dataset_yaml["product_info_struct"][0].replace("{name}", product)
                        product_info_str += dataset_yaml["product_info_struct"][1].replace("{highlights}", hightlight_str)

                        content_str = (
                            data_gen_setting["dataset_gen_prompt"]
                            .replace("{role_type}", role_type)
                            .replace("{character}", character)
                            .replace("{product_info}", product_info_str)
                            .replace("{customer_question}", customer_question_str)
                            .replace("{each_conversation_qa}", str(data_gen_setting["each_conversation_qa"]))
                            .replace(
                                "{dataset_json_format}",
                                data_gen_setting["dataset_json_format"].replace("{product_info}", product_info_str),
                            )
                        )

                        print(f"\n Resquest [ {model_name} ] {idx + 1}/{gen_num} ==> {content_str} \n")
                        if model_name == "qwen":
                            format_json = process_request(call_qwen_message, content_str, qwen_model_type[idx], model_name)
                        elif model_name == "ernie":
                            format_json = process_request(call_ernie_message, content_str, api_key, model_name)
                        else:
                            raise ValueError(f"model_name {model_name} not support")

                        if _has_usable_first_turn(format_json):
                            # To save tokens the model is not asked to echo the
                            # system prompt and first input; put them back here.
                            input_str = conversation_setting["first_input"].replace("{product_info}", product_info_str)
                            format_json["conversation"][0] = {
                                "system": system_str,
                                "input": input_str,
                                "output": format_json["conversation"][0]["output"],
                            }
                        else:
                            format_json = {"Error": "Error"}

                        print(f"\n Response [ {model_name} ] {idx + 1}/{gen_num} <== {format_json} \n")
                        gen_json[product].append(format_json)

                    pbar.update(1)

                    # Persist after every product so an interrupted run can resume.
                    _save_json_with_backup(gen_json, save_json_path, bk_json_path)

        pbar.close()


def _pick_random(items, count):
    """Return up to `count` elements of `items` in random order without modifying `items`."""
    return random.sample(items, min(count, len(items)))


def _has_usable_first_turn(format_json):
    """Tell whether a model reply is a dict whose first conversation turn carries an "output"."""
    if not isinstance(format_json, dict):
        return False
    conversation = format_json.get("conversation")
    if not isinstance(conversation, list) or not conversation:
        return False
    first_turn = conversation[0]
    return isinstance(first_turn, dict) and "output" in first_turn


def _save_json_with_backup(data, save_path, backup_path):
    """Write `data` to `save_path`, keeping the previous file as `backup_path` until the write succeeds."""
    # Back up the old file.
    if save_path.exists():
        save_path.rename(backup_path)

    # Save the JSON.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

    # The write succeeded, so drop the backup.
    if backup_path.exists():
        backup_path.unlink()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if __name__ == "__main__":

    # Example: generate data for every role with the Qwen API
    #   cd /path/to/Streamer-Sales/dataset/gen_dataset
    #   python gen_dataset.py qwen

    # Command-line arguments.
    parser = argparse.ArgumentParser(description="Gen Dataset")
    parser.add_argument("model_name", type=str, choices=["qwen", "ernie"], help="Model name for data generation")
    for flag, default_value, help_text in (
        ("--data_yaml", "../../configs/conversation_cfg.yaml", "data setting file path"),
        ("--api_yaml", "../../configs/api_cfg.yaml", "api setting file path"),
        ("--output_dir", "./train_dataset/response", "generation json output dir"),
        ("--specific_name", "", "Character name for data generation"),
    ):
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    args = parser.parse_args()

    # Optional: generate the product highlights first.
    # gen_product_highlights(args.data_yaml, args.api_yaml)

    # Generate the conversation dataset.
    output_root = Path(args.output_dir)
    gen_dataset(
        args.data_yaml,
        args.api_yaml,
        output_root,
        model_name=args.model_name,
        specific_name=args.specific_name,
    )
|
dataset/gen_dataset/merge_dataset.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def gen_self_self_aware_dataset():
    """Build the self-introduction samples appended to the training set.

    Returns:
        list: One ``{"conversation": [{"input": ..., "output": ...}]}`` item
        per canned answer, each paired with a randomly chosen greeting or
        identity question.
    """
    # Self-awareness prompts a user might open with.
    questions = [
        "你好",
        "你是谁",
        "你叫什么名字",
        "请做一下自我介绍",
        "介绍下你自己",
    ]

    # Canned self-introductions of the assistant.
    answers = [
        "您好,我是智能医导,随时准备解答您的医疗疑问。",
        "您好,我是智能医导,助您轻松就医。",
        "您好,我是智能医导,提供专业医疗指导。",
        "您好,我是智能医导,解答您的健康疑惑。",
        "您好,我是智能医导,帮助您了解医疗服务。",
        "您好,我是智能医导,您的医疗问题助手。",
        "您好,我是智能医导,助您快速获取医疗信息。",
        "您好,我是智能医导,为您提供医疗解答。",
        "您好,我是智能医导,帮助您理解医疗流程。",
        "您好,我是智能医导,解答您的医疗咨询。",
        "您好,我是智能医导,助您掌握健康知识。",
        "您好,我是智能医导,提供医疗信息查询。",
        "您好,我是智能医导,助您解决就医难题。",
        "您好,我是智能医导,您的私人医疗顾问。",
        "您好,我是智能医导,随时为您提供帮助。",
    ]

    # One single-turn conversation per answer, in answer order.
    return [{"conversation": [{"input": random.choice(questions), "output": reply}]} for reply in answers]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def merge_dataset(save_json_root: Path, final_save_json_path: Path):
    """Merge generated response files into one cleaned training file.

    Every ``*.json`` file in ``save_json_root`` is read. Conversations are
    filtered, the first turn of each kept conversation gets a fixed system
    prompt, and the self-awareness samples are appended. The result is
    written next to ``final_save_json_path`` as ``{count}_{name}``; rejected
    records, if any, go to ``error_{name}`` in the same directory.

    Args:
        save_json_root (Path): Directory holding the generated JSON files.
        final_save_json_path (Path): Target path; its parent directory and
            file name are used to build the output file names.
    """
    accept_keys = ("input", "output", "system")
    system_prompt = "现在你是一位医院大厅里的智能医导小助手,你的名字叫智能医导小助手,你的说话方式是严肃端庄。你能够根据病人的需求提供专业的医疗咨询,并且结合医疗知识解答用户提出的各种健康相关疑问。"

    filter_json_list = []
    dirty_conversion = []

    # sorted() keeps the merge order reproducible; glob order depends on the filesystem.
    for json_path in sorted(save_json_root.glob("*.json")):
        with open(json_path, "r", encoding="utf-8") as f:
            model_data = json.load(f)

        if not isinstance(model_data, dict):
            # Not a {product: [conversations]} file (e.g. an earlier merged output).
            print(f"Skip {json_path}: not a product -> conversations mapping")
            continue

        for product_name, gen_data_list in model_data.items():
            for gen_data in gen_data_list:
                # Failed generations and malformed records are set aside.
                if (
                    not isinstance(gen_data, dict)
                    or "Error" in gen_data
                    or not isinstance(gen_data.get("conversation"), list)
                ):
                    print(f"Got error data in {product_name}")
                    dirty_conversion.append(gen_data)
                    continue

                turns = []
                for turn in gen_data["conversation"]:
                    if not isinstance(turn, dict):
                        dirty_conversion.append(turn)
                        continue

                    # Drop keys that do not belong to the training format.
                    turn = {key: value for key, value in turn.items() if key in accept_keys}

                    # A usable turn needs both an input and an output.
                    if "input" not in turn or "output" not in turn:
                        dirty_conversion.append(turn)
                        continue

                    turns.append(turn)

                if turns:
                    # Every conversation starts with the same system prompt.
                    turns[0]["system"] = system_prompt
                    filter_json_list.append({"conversation": turns})

    # Append the self-awareness samples.
    filter_json_list += gen_self_self_aware_dataset()

    # Save.
    output_dir = final_save_json_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir.joinpath(f"{len(filter_json_list)}_{final_save_json_path.name}"), "w", encoding="utf-8") as f:
        json.dump(filter_json_list, f, ensure_ascii=False, indent=4)

    if dirty_conversion:
        # Keep the rejected records so the user can fix them manually.
        with open(output_dir.joinpath(f"error_{final_save_json_path.name}"), "w", encoding="utf-8") as f:
            json.dump(dirty_conversion, f, ensure_ascii=False, indent=4)

    sum_input_output_count = sum(len(conversion["conversation"]) for conversion in filter_json_list)
    print(
        f"总生成有效 conversion 数据 {len(filter_json_list)} 组,内含 {sum_input_output_count} 条对话,剔除脏对话 {len(dirty_conversion)} 条,保存到 error_{final_save_json_path.name} 中。"
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
    # Command-line arguments.
    # TODO: only the 乐乐喵 role is supported for now.
    parser = argparse.ArgumentParser(description="Merge Dataset")
    for positional in ("data_root", "output_path"):
        parser.add_argument(positional, type=str, help="path to response dir")
    args = parser.parse_args()

    save_json_root, final_save_json_path = Path(args.data_root), Path(args.output_path)
    merge_dataset(save_json_root, final_save_json_path)
|
dataset/gen_dataset/train_dataset/90_train.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/gen_dataset/train_dataset/response/qwen_智能医导小助手_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|