Az-r-ow committed on
Commit
04fc5d3
·
1 Parent(s): dcf66c8

feat(data_processing): Sentence processing function to extract logits and labels from a sentence

Browse files
app/travel_resolver/libs/nlp/data_processing.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk, re
2
+
3
+
4
+ def get_tagged_content(sentence: str, tag: str) -> str:
5
+ """
6
+ Extract the content between two tags in a sentence given the tag.
7
+
8
+ Args:
9
+ sentence (str): The sentence to extract the content from.
10
+ tag (str): The tag to extract the content between.
11
+
12
+ Returns:
13
+ str: The content between the tags.
14
+
15
+ Example:
16
+ >>> get_tagged_content("Je voudrais voyager de <Dep>Nice<Dep> à <Arr>Clermont Ferrand<Arr>.", "<Dep>")
17
+ "Nice"
18
+ """
19
+ tag_match = re.search(rf"{tag}(.*?){tag}", sentence)
20
+ if tag_match:
21
+ return tag_match.group(1)
22
+ return None
23
+
24
+
25
def process_sentence(sentence: str, dep_token="<Dep>", arr_token="<Arr>") -> tuple:
    """
    Extract the departure and arrival locations from a tagged sentence and
    build character-level (logits, labels) training pairs.

    The sentence is stripped of its tags, tokenized with nltk, and every
    character of every token is emitted as its Unicode code point (logit)
    together with an integer label:
        1 - regular character
        2 - character belonging to the departure location
        3 - character belonging to the arrival location

    Args:
        sentence (str): The sentence to process.
        dep_token (str): The token marking the departure location.
        arr_token (str): The token marking the arrival location.

    Returns:
        tuple: A tuple (logits, labels) of two equal-length lists of ints.
    """
    bare_sentence = sentence.replace(dep_token, "").replace(arr_token, "")
    # Coalesce a missing tag pair (get_tagged_content -> None) to "" so the
    # membership checks below do not raise TypeError; such tokens simply get
    # the default label.
    departure = get_tagged_content(sentence, dep_token) or ""
    arrival = get_tagged_content(sentence, arr_token) or ""
    tokenized_sentence = nltk.word_tokenize(bare_sentence)
    labels = []
    logits = []
    for token in tokenized_sentence:
        # NOTE(review): substring membership — a short token that merely
        # occurs inside the location string (e.g. "a" in "Paris") is also
        # labelled as part of it; confirm this is the intended behavior.
        if token in departure:
            labels.extend([2] * len(token))
        elif token in arrival:
            labels.extend([3] * len(token))
        else:
            labels.extend([1] * len(token))
        logits.extend(ord(char) for char in token)

    return (logits, labels)
requirements.txt CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ nltk==3.9.1
2
+ numpy==2.1.0