# wojood-api / app.py
from fastapi import FastAPI
import torch
import pickle
from huggingface_hub import hf_hub_download, snapshot_download
from Nested.nn.BertSeqTagger import BertSeqTagger
from transformers import AutoTokenizer, AutoModel
import inspect
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
import json
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from IBO_to_XML import IBO_to_XML
from XML_to_HTML import NER_XML_to_HTML
from NER_Distiller import distill_entities
app = FastAPI()
pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()
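# Note: `torch`, `BertSeqTagger`, and `encoder` are not referenced directly below;
# presumably (assumption) they are imported/loaded so that load_checkpoint() can rebuild
# the fine-tuned tagger on top of the same AraBERT architecture.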
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json"
)
with open(args_path, 'r') as f:
    args_data = json.load(f)
# Load the label vocabulary used during training
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)
label_vocab = label_vocab[0]  # the pickle stores a list; take its first element
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
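# Illustrative only (actual labels depend on tag_vocab.pkl): id2label maps indices to
# IOB2 tag strings, e.g. {0: "O", 1: "B-PERS", 2: "I-PERS", ...}.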
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    # Split the text into words
    words = sentence.split()
    # Initialize variables
    groups = []
    current_group = ""
    group_size = 0
    # Iterate through the words, closing a group once it reaches max_words_per_sentence words
    for word in words:
        if group_size < max_words_per_sentence - 1:
            if len(current_group) == 0:
                current_group = word
            else:
                current_group += " " + word
            group_size += 1
        else:
            current_group += " " + word
            groups.append(current_group)
            current_group = ""
            group_size = 0
    # Add the last group if it contains fewer than max_words_per_sentence words
    if current_group:
        groups.append(current_group)
    return groups
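# For reference, a worked example of the grouping above:
#   split_text_into_groups_of_Ns("a b c d e", 2) -> ["a b", "c d", "e"]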
def remove_empty_values(sentences):
    return [value for value in sentences if value != '']
def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    separators = []
    split_text = [text]
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.append('?')
        separators.append('؟')
    if exclamation_mark:
        separators.append('!')
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            tokens_with_separator = [token + sep for token in tokens[:-1]]
            tokens_with_separator.append(tokens[-1].strip())
            new_split_text.extend(tokens_with_separator)
        split_text = new_split_text
    split_text = remove_empty_values(split_text)
    return split_text
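# For reference, a worked example with the default separators:
#   sentence_tokenizer("One. Two? Three") -> ["One.", "Two?", "Three"]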
def jsons_to_list_of_lists(json_list):
    return [[d['token'], d['tags']] for d in json_list]
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
def extract(sentence):
    dataset, token_vocab = text2segments(sentence)
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        args_data,
        batch_size=32,
        shuffle=(False,),
    )[0]
    segments = tagger.infer(dataloader)
    lists = []
    for segment in segments:
        for token in segment:
            item = {}
            item["token"] = token.text
            list_of_tags = [t["tag"] for t in token.pred_tag]
            list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
            if not list_of_tags:
                item["tags"] = "O"
            else:
                item["tags"] = " ".join(list_of_tags)
            lists.append(item)
    return lists
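# extract() returns one dict per token; entity labels below are illustrative:
#   [{"token": "محمد", "tags": "B-PERS"}, {"token": "في", "tags": "O"}, ...]
# Nested entities produce several space-separated tags in "tags".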
def NER(sentence, mode):
    mode = mode.strip()
    if mode == "1":  # token/tag pairs
        output_list = jsons_to_list_of_lists(extract(sentence))
        return output_list
    elif mode == "2":  # XML with inline entity tags
        output_list = jsons_to_list_of_lists(extract(sentence))
        xml = IBO_to_XML(output_list)
        return xml
    elif mode == "3":  # HTML rendering of the XML
        output_list = jsons_to_list_of_lists(extract(sentence))
        xml = IBO_to_XML(output_list)
        html = NER_XML_to_HTML(xml)
        return html
    elif mode == "4":  # json short
        output_list = jsons_to_list_of_lists(extract(sentence))
        json_short = distill_entities(output_list)
        return json_short
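# For reference (sample text and labels are illustrative):
#   NER("ولد محمد في القدس", "1") -> [["ولد", "O"], ["محمد", "B-PERS"], ...]
#   Mode "2" serializes the same output as XML, "3" as HTML, and "4" as a distilled JSON summary.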
class NERRequest(BaseModel):
    text: str
    mode: str
@app.post("/predict")
def predict(request: NERRequest):
    text = request.text
    mode = request.mode
    # Split the input on newlines only, then cap each piece at 300 words per segment
    sentences = sentence_tokenizer(
        text, dot=False, new_line=True, question_mark=False, exclamation_mark=False
    )
    lists = []
    for sentence in sentences:
        se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300)
        for s in se:
            output_list = NER(s, mode)
            lists.append(output_list)
    content = {
        "resp": lists,
        "statusText": "OK",
        "statusCode": 0,
    }
    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )
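
# Local run and request example (hedged: on Hugging Face Spaces the server is normally
# started by the platform; the host/port below are assumptions for local testing only):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "ولد محمد في القدس", "mode": "1"}'
#
# The response body has the shape {"resp": [...], "statusText": "OK", "statusCode": 0}.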