TymaaHammouda committed on
Commit
2b51d25
·
1 Parent(s): 170771d

Add SinaTools and update app file

Browse files
Files changed (2) hide show
  1. app.py +64 -27
  2. requirements.txt +2 -1
app.py CHANGED
@@ -11,6 +11,7 @@ from Nested.utils.data import get_dataloaders, text2segments
11
  import json
12
  from pydantic import BaseModel
13
  from fastapi.responses import JSONResponse
 
14
 
15
  app = FastAPI()
16
  print("Version 2...")
@@ -53,50 +54,86 @@ with open("Nested/utils/tag_vocab.pkl", "rb") as f:
53
  label_vocab = label_vocab[0] # the list loaded from pickle
54
  id2label = {i: s for i, s in enumerate(label_vocab.itos)}
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  class NERRequest(BaseModel):
58
  text: str
59
-
60
 
61
  @app.post("/predict")
62
  def predict(request: NERRequest):
63
-
64
- sentence = request.text # 👈 user input
65
-
66
  # Load tagger
67
  tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
68
 
69
- dataset, token_vocab = text2segments(sentence)
 
70
 
71
- vocabs = namedtuple("Vocab", ["tags", "tokens"])
72
- vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
 
73
 
74
- dataloader = get_dataloaders(
75
- (dataset,),
76
- vocab,
77
- args_data,
78
- batch_size=32,
79
- shuffle=(False,),
80
- )[0]
81
 
82
- segments = tagger.infer(dataloader)
 
83
 
84
- lists = []
 
 
 
 
 
 
85
 
86
- for segment in segments:
87
- for token in segment:
88
- item = {}
89
- item["token"] = token.text
90
 
91
- list_of_tags = [t["tag"] for t in token.pred_tag]
92
- list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
93
 
94
- if not list_of_tags:
95
- item["tags"] = ["O"]
96
- else:
97
- item["tags"] = list_of_tags
 
 
 
 
 
 
 
 
98
 
99
- lists.append(item)
100
 
101
  content = {
102
  "resp": lists,
 
11
  import json
12
  from pydantic import BaseModel
13
  from fastapi.responses import JSONResponse
14
+ from sinatools.utils.tokenizer import sentence_tokenizer
15
 
16
  app = FastAPI()
17
  print("Version 2...")
 
54
  label_vocab = label_vocab[0] # the list loaded from pickle
55
  id2label = {i: s for i, s in enumerate(label_vocab.itos)}
56
 
57
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence, tokenize=None):
    """Split *sentence* into chunks of at most *max_words_per_sentence* words.

    Words are produced by *tokenize* and re-joined with single spaces, so
    every returned chunk holds exactly ``max_words_per_sentence`` words
    except possibly the last one. An empty sentence yields an empty list.

    Args:
        sentence: Input text to split.
        max_words_per_sentence: Maximum number of words per returned chunk.
        tokenize: Optional callable mapping a string to a list of words.
            Defaults to SinaTools' ``simple_word_tokenize`` for backward
            compatibility with the original implementation.

    Returns:
        list[str]: Space-joined word chunks, in input order.
    """
    if tokenize is None:
        # Keep the original tokenizer as the default behavior.
        tokenize = simple_word_tokenize
    words = tokenize(sentence)
    n = max_words_per_sentence
    # Chunk by slicing instead of the original manual accumulator loop:
    # output is identical (groups of n words, last group may be shorter).
    return [" ".join(words[i:i + n]) for i in range(0, len(words), n)]
85
 
86
class NERRequest(BaseModel):
    """Request body for the /predict endpoint."""
    text: str  # raw input text to run NER over
    mode: str  # mode flag supplied by the caller; its use is not visible in this hunk — TODO confirm
89
 
90
  @app.post("/predict")
91
  def predict(request: NERRequest):
 
 
 
92
  # Load tagger
93
  tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
94
 
95
+ text = request.text
96
+ mode = request.mode
97
 
98
+ sentences = sentence_tokenizer(
99
+ text, dot=False, new_line=True, question_mark=False, exclamation_mark=False
100
+ )
101
 
102
+ lists = []
103
+ for sentence in sentences:
104
+ se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300)
105
+ for s in se:
106
+ dataset, token_vocab = text2segments(sentence)
 
 
107
 
108
+ vocabs = namedtuple("Vocab", ["tags", "tokens"])
109
+ vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
110
 
111
+ dataloader = get_dataloaders(
112
+ (dataset,),
113
+ vocab,
114
+ args_data,
115
+ batch_size=32,
116
+ shuffle=(False,),
117
+ )[0]
118
 
119
+ segments = tagger.infer(dataloader)
 
 
 
120
 
121
+ # lists = []
 
122
 
123
+ for segment in segments:
124
+ for token in segment:
125
+ item = {}
126
+ item["token"] = token.text
127
+
128
+ list_of_tags = [t["tag"] for t in token.pred_tag]
129
+ list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
130
+
131
+ if not list_of_tags:
132
+ item["tags"] = ["O"]
133
+ else:
134
+ item["tags"] = list_of_tags
135
 
136
+ lists.append(item)
137
 
138
  content = {
139
  "resp": lists,
requirements.txt CHANGED
@@ -5,4 +5,5 @@ numpy
5
  huggingface_hub
6
  transformers
7
  natsort
8
- seqeval
 
 
5
  huggingface_hub
6
  transformers
7
  natsort
8
+ seqeval
9
+ sinatools