| import json | |
| from tqdm import tqdm | |
| import argparse | |
| import numpy as np | |
def save_data(data, file_path):
    """Write ``data`` to ``file_path`` as a single line of JSON.

    The file is opened in write mode with UTF-8 encoding, so any existing
    content is replaced. Non-ASCII characters are written as-is instead of
    being escaped, and the JSON text is followed by one newline.

    Args:
        data: Any object ``json.dumps`` can serialize.
        file_path: Destination path.
    """
    with open(file_path, mode='w', encoding='utf8') as out_file:
        # Serialization happens after the file is opened, as in the
        # original statement order.
        serialized = json.dumps(data, ensure_ascii=False)
        # print() appends the trailing newline for us.
        print(serialized, file=out_file)
def load_data(file_path, is_training=False):
    """Read a file holding one JSON document per line and return the records.

    Args:
        file_path: Path to a UTF-8 encoded text file; every non-blank line
            must be a complete JSON document.
        is_training: Accepted for call compatibility; this function does not
            use it.

    Returns:
        A list with the decoded object of every non-blank line, in file
        order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf8') as f:
        lines = f.readlines()

    result = []
    # Hand tqdm the list itself (not an enumerate wrapper) so it can read the
    # length and display a total / percentage.
    for line in tqdm(lines):
        # Blank lines (typically a trailing empty line) carry no record;
        # skipping them avoids a JSONDecodeError on otherwise valid files.
        if not line.strip():
            continue
        result.append(json.loads(line))
    return result
def recls(line):
    """Greedily assign a distinct label to each item based on its scores.

    Every item's ``'score'`` mapping contributes one row of a score matrix
    (values taken in the mapping's iteration order, so a column index is a
    position in that order). The largest remaining score is selected
    repeatedly: the item of that row receives the column index as its
    ``'label'``, and the row and column are then excluded from later picks,
    so no two items end up with the same column.

    Args:
        line: List of dicts, each holding a ``'score'`` dict. All score dicts
            must have the same number of entries so the rows form a matrix.

    Returns:
        The same list object; items are updated in place. When there are more
        items than score columns, the items that could not be matched are
        left untouched (any existing ``'label'`` is kept).
    """
    if not line:
        return line

    # Work on a float matrix so that -inf can be used as the "taken" marker.
    scores = np.array(
        [list(item['score'].values()) for item in line], dtype=float
    )
    if scores.ndim != 2 or scores.size == 0:
        # No score columns to choose from.
        return line

    # At most min(rows, cols) distinct (row, col) pairs can be assigned.
    for _ in range(min(scores.shape)):
        row, col = np.unravel_index(np.argmax(scores), scores.shape)
        if scores[row, col] == -np.inf:
            # Only excluded cells remain.
            break
        line[int(row)]['label'] = int(col)
        # Mask with -inf rather than 0: a zero would outrank genuine
        # negative scores (e.g. logits) and cause an already-assigned row or
        # column to be picked again.
        scores[row, :] = -np.inf
        scores[:, col] = -np.inf
    return line
def chid_m(data):
    """Re-label records group by group, where a group shares a ``'line_id'``.

    Records are bucketed by their ``'line_id'`` value (buckets keep the order
    in which each id first appears, and records keep their input order inside
    a bucket). Each bucket is passed to ``recls`` and the returned items are
    concatenated into one flat list.

    Args:
        data: Iterable of dicts, each with a ``'line_id'`` key plus whatever
            ``recls`` reads from them.

    Returns:
        A flat list of the items returned by ``recls`` for every bucket.
    """
    grouped = {}
    for record in data:
        grouped.setdefault(record['line_id'], []).append(record)

    merged = []
    for group in grouped.values():
        merged += recls(group)
    return merged
def submit(file_path):
    """Build an ``id -> label`` mapping from the records in ``file_path``.

    The file is read with ``load_data``, the records are re-labelled with
    ``chid_m``, and each resulting record contributes its ``'id'`` as key and
    its ``'label'`` as value. If an id occurs more than once, the last record
    wins.

    Args:
        file_path: Path handed to ``load_data``.

    Returns:
        Dict mapping every record's ``'id'`` to its ``'label'``.
    """
    records = chid_m(load_data(file_path))
    return {record['id']: record['label'] for record in tqdm(records)}
if __name__ == "__main__":
    # Script entry point: read the records at --data_path, compute the
    # id -> label mapping, and write it as JSON to --save_path.
    cli = argparse.ArgumentParser(description="train")
    # Both options are plain strings that default to the empty string.
    for flag in ("--data_path", "--save_path"):
        cli.add_argument(flag, type=str, default="")
    options = cli.parse_args()

    predictions = submit(options.data_path)
    save_data(predictions, options.save_path)