Update app.py
app.py CHANGED
@@ -25,21 +25,20 @@ all_datasets = get_datasets()
 
 
 #def get_split(dataset_name):
-'''
-    if dataset_name == "Communication Networks: unseen questions":
-        split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
-    if dataset_name == "Communication Networks: unseen answers":
-        split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
-    if dataset_name == "Micro Job: unseen questions":
-        split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
-    if dataset_name == "Micro Job: unseen answers":
-        split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
-    if dataset_name == "Legal Domain: unseen questions":
-        split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
-    if dataset_name == "Legal Domain: unseen answers":
-        split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
-    return split
-'''
+#     if dataset_name == "Communication Networks: unseen questions":
+#         split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
+#     if dataset_name == "Communication Networks: unseen answers":
+#         split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
+#     if dataset_name == "Micro Job: unseen questions":
+#         split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
+#     if dataset_name == "Micro Job: unseen answers":
+#         split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
+#     if dataset_name == "Legal Domain: unseen questions":
+#         split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
+#     if dataset_name == "Legal Domain: unseen answers":
+#         split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
+#     return split
+
 
 def get_model(datasetname):
     if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
@@ -50,195 +49,195 @@ def get_model(datasetname):
         model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
     return model
 
-
-def get_tokenizer(datasetname):
-    if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
-        tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
-    if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
-        tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
-    if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
-        tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
-    return tokenizer
-
-sacrebleu = load_metric('sacrebleu')
-rouge = load_metric('rouge')
-meteor = load_metric('meteor')
-bertscore = load_metric('bertscore')
-
-# use gpu if it's available
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-MAX_INPUT_LENGTH = 256
-MAX_TARGET_LENGTH = 128
-
-def preprocess_function(examples, **kwargs):
-    """
-    Preprocess entries of the given dataset
-
-    Params:
-        examples (Dataset): dataset to be preprocessed
-    Returns:
-        model_inputs (BatchEncoding): tokenized dataset entries
-    """
-
-    inputs, targets = [], []
-    for i in range(len(examples['question'])):
-        inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
-        targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
-
-    # apply tokenization to inputs and labels
-    tokenizer = kwargs["tokenizer"]
-    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
-    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
-
-    model_inputs['labels'] = labels['input_ids']
-
-    return model_inputs
-
-
-
-def flatten_list(l):
-    """
-    Utility function to convert a list of lists into a flattened list
-    Params:
-        l (list of lists): list to be flattened
-    Returns:
-        A flattened list with the elements of the original list
-    """
-    return [item for sublist in l for item in sublist]
-
-
-def extract_feedback(predictions):
-    """
-    Utility function to extract the feedback from the predictions of the model
-    Params:
-        predictions (list): complete model predictions
-    Returns:
-        feedback (list): extracted feedback from the model's predictions
-    """
-    feedback = []
-    # iterate through predictions and try to extract predicted feedback
-    for pred in predictions:
-        try:
-            fb = pred.split(':', 1)[1]
-        except IndexError:
-            try:
-                if pred.lower().startswith('partially correct'):
-                    fb = pred.split(' ', 1)[2]
-                else:
-                    fb = pred.split(' ', 1)[1]
-            except IndexError:
-                fb = pred
-        feedback.append(fb.strip())
+
+# def get_tokenizer(datasetname):
+#     if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
+#         tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
+#     if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
+#         tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
+#     if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
+#         tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
+#     return tokenizer
+
+# sacrebleu = load_metric('sacrebleu')
+# rouge = load_metric('rouge')
+# meteor = load_metric('meteor')
+# bertscore = load_metric('bertscore')
+
+# # use gpu if it's available
+# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# MAX_INPUT_LENGTH = 256
+# MAX_TARGET_LENGTH = 128
+
+# def preprocess_function(examples, **kwargs):
+#     """
+#     Preprocess entries of the given dataset
+
+#     Params:
+#         examples (Dataset): dataset to be preprocessed
+#     Returns:
+#         model_inputs (BatchEncoding): tokenized dataset entries
+#     """
+
+#     inputs, targets = [], []
+#     for i in range(len(examples['question'])):
+#         inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
+#         targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
+
+#     # apply tokenization to inputs and labels
+#     tokenizer = kwargs["tokenizer"]
+#     model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
+#     labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
+
+#     model_inputs['labels'] = labels['input_ids']
+
+#     return model_inputs
+
+
+
+# def flatten_list(l):
+#     """
+#     Utility function to convert a list of lists into a flattened list
+#     Params:
+#         l (list of lists): list to be flattened
+#     Returns:
+#         A flattened list with the elements of the original list
+#     """
+#     return [item for sublist in l for item in sublist]
+
+
+# def extract_feedback(predictions):
+#     """
+#     Utility function to extract the feedback from the predictions of the model
+#     Params:
+#         predictions (list): complete model predictions
+#     Returns:
+#         feedback (list): extracted feedback from the model's predictions
+#     """
+#     feedback = []
+#     # iterate through predictions and try to extract predicted feedback
+#     for pred in predictions:
+#         try:
+#             fb = pred.split(':', 1)[1]
+#         except IndexError:
+#             try:
+#                 if pred.lower().startswith('partially correct'):
+#                     fb = pred.split(' ', 1)[2]
+#                 else:
+#                     fb = pred.split(' ', 1)[1]
+#             except IndexError:
+#                 fb = pred
+#         feedback.append(fb.strip())
 
-    return feedback
-
-
-def extract_labels(predictions):
-    """
-    Utility function to extract the labels from the predictions of the model
-    Params:
-        predictions (list): complete model predictions
-    Returns:
-        feedback (list): extracted labels from the model's predictions
-    """
-    labels = []
-    for pred in predictions:
-        if pred.lower().startswith('correct'):
-            label = 'Correct'
-        elif pred.lower().startswith('partially correct'):
-            label = 'Partially correct'
-        elif pred.lower().startswith('incorrect'):
-            label = 'Incorrect'
-        else:
-            label = 'Unknown label'
-        labels.append(label)
+#     return feedback
+
+
+# def extract_labels(predictions):
+#     """
+#     Utility function to extract the labels from the predictions of the model
+#     Params:
+#         predictions (list): complete model predictions
+#     Returns:
+#         feedback (list): extracted labels from the model's predictions
+#     """
+#     labels = []
+#     for pred in predictions:
+#         if pred.lower().startswith('correct'):
+#             label = 'Correct'
+#         elif pred.lower().startswith('partially correct'):
+#             label = 'Partially correct'
+#         elif pred.lower().startswith('incorrect'):
+#             label = 'Incorrect'
+#         else:
+#             label = 'Unknown label'
+#         labels.append(label)
 
-    return labels
-
-
-def get_predictions_labels(model, dataloader, tokenizer):
-    """
-    Evaluate model on the given dataset
-
-    Params:
-        model (PreTrainedModel): seq2seq model
-        dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
-    Returns:
-        results (dict): dictionary with the computed evaluation metrics
-        predictions (list): list of the decoded predictions of the model
-    """
-    decoded_preds, decoded_labels = [], []
-
-    model.eval()
-    # iterate through batchs in the dataloader
-    for batch in tqdm(dataloader):
-        with torch.no_grad():
-            batch = {k: v.to(device) for k, v in batch.items()}
-            # generate tokens from batch
-            generated_tokens = model.generate(
-                batch['input_ids'],
-                attention_mask=batch['attention_mask'],
-                max_length=MAX_TARGET_LENGTH
-            )
-            # get golden labels from batch
-            labels_batch = batch['labels']
+#     return labels
+
+
+# def get_predictions_labels(model, dataloader, tokenizer):
+#     """
+#     Evaluate model on the given dataset
+
+#     Params:
+#         model (PreTrainedModel): seq2seq model
+#         dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
+#     Returns:
+#         results (dict): dictionary with the computed evaluation metrics
+#         predictions (list): list of the decoded predictions of the model
+#     """
+#     decoded_preds, decoded_labels = [], []
+
+#     model.eval()
+#     # iterate through batchs in the dataloader
+#     for batch in tqdm(dataloader):
+#         with torch.no_grad():
+#             batch = {k: v.to(device) for k, v in batch.items()}
+#             # generate tokens from batch
+#             generated_tokens = model.generate(
+#                 batch['input_ids'],
+#                 attention_mask=batch['attention_mask'],
+#                 max_length=MAX_TARGET_LENGTH
+#             )
+#             # get golden labels from batch
+#             labels_batch = batch['labels']
 
-            # decode model predictions and golden labels
-            decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-            decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
+#             # decode model predictions and golden labels
+#             decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+#             decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
 
-            decoded_preds.append(decoded_preds_batch)
-            decoded_labels.append(decoded_labels_batch)
+#             decoded_preds.append(decoded_preds_batch)
+#             decoded_labels.append(decoded_labels_batch)
 
-    # convert predictions and golden labels into flattened lists
-    predictions = flatten_list(decoded_preds)
-    labels = flatten_list(decoded_labels)
+#     # convert predictions and golden labels into flattened lists
+#     predictions = flatten_list(decoded_preds)
+#     labels = flatten_list(decoded_labels)
 
-    return predictions, labels
+#     return predictions, labels
 
 
-def load_data():
-    df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
-    for ds in all_datasets:
-        split = get_split(ds)
-        model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
-        tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))
+# def load_data():
+#     df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
+#     for ds in all_datasets:
+#         split = get_split(ds)
+#         model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
+#         tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))
 
-        processed_dataset = split.map(
-            preprocess_function,
-            batched=True,
-            remove_columns=split.column_names,
-            fn_kwargs={"tokenizer": tokenizer}
-        )
-        processed_dataset.set_format('torch')
+#         processed_dataset = split.map(
+#             preprocess_function,
+#             batched=True,
+#             remove_columns=split.column_names,
+#             fn_kwargs={"tokenizer": tokenizer}
+#         )
+#         processed_dataset.set_format('torch')
 
-        dataloader = DataLoader(processed_dataset, batch_size=4)
+#         dataloader = DataLoader(processed_dataset, batch_size=4)
 
-        predictions, labels = get_predictions_labels(model, dataloader, tokenizer)
+#         predictions, labels = get_predictions_labels(model, dataloader, tokenizer)
 
-        predicted_feedback = extract_feedback(predictions)
-        predicted_labels = extract_labels(predictions)
+#         predicted_feedback = extract_feedback(predictions)
+#         predicted_labels = extract_labels(predictions)
 
-        reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
-        reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
+#         reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
+#         reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
 
-        rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
-        bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
-        meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
-        bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)
+#         rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
+#         bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
+#         meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
+#         bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)
 
-        reference_labels_np = np.array(reference_labels)
-        accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
-        f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
-        f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])
+#         reference_labels_np = np.array(reference_labels)
+#         accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
+#         f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
+#         f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])
 
-        new_row_data = {"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}
-        new_row = pd.DataFrame(new_row_data)
+#         new_row_data = {"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}
+#         new_row = pd.DataFrame(new_row_data)
 
-        df = pd.concat([df, new_row])
-    return df
-
+#         df = pd.concat([df, new_row])
+#     return df
+
 
 def get_rows(datasetname):
     if datasetname == "Communication Networks: unseen questions":