from typing import Dict, Any

import numpy as np
import spacy
from PIL import ImageFont
from spacy.tokens import Doc


def get_pil_text_size(text, font_size, font_name):
    """Measure the rendered size of `text` in pixels for the given TrueType font."""
    font = ImageFont.truetype(font_name, font_size)
    # font.getsize() was removed in Pillow 10; getbbox() returns
    # (left, top, right, bottom), from which width and height follow.
    left, top, right, bottom = font.getbbox(text)
    return right - left, bottom - top
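
# Quick sanity check (illustrative only, assuming an "arial.ttf" font file is
# resolvable by Pillow on this machine):
#   get_pil_text_size("flagship ", 16, "arial.ttf")  # -> (width_px, height_px)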


def render_arrow(
    label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render an individual arrow.

    label (str): Dependency label.
    start (int): X-coordinate (px) of the arrow start point.
    end (int): X-coordinate (px) of the arrow end point.
    direction (str): Arrow direction as passed by the caller ('left' or 'rtl').
    i (int): Unique ID, typically the arrow index.
    RETURNS (str): Rendered SVG markup.
    """
    TPL_DEP_ARCS = """
    <g class="displacy-arrow">
        <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="red"/>
        <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
            <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="red" text-anchor="middle">{label}</textPath>
        </text>
        <path class="displacy-arrowhead" d="{head}" fill="red"/>
    </g>
    """
    arc = get_arc(start + 10, 50, 5, end + 10)
    arrowhead = get_arrowhead(direction, start + 10, 50, end + 10)
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=i,  # use the caller's index so element IDs stay unique in the SVG
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )
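
# Example (illustrative, not part of the app): rendering one 'amod' arc
# between x=10 and x=160 returns a <g class="displacy-arrow"> group holding
# the arc path, its label, and the arrowhead:
#   render_arrow("amod", 10, 160, "left", 0)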


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Render an individual arc.

    x_start (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    y_curve (int): Y-coordinate of the cubic Bézier curve point.
    x_end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arc path ('d' attribute).
    """
    template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
    return template.format(x=x_start, y=y, c=y_curve, e=x_end)
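
# For instance, get_arc(10, 50, 5, 110) yields
# "M10,50 C10,5 110,5 110,50": a cubic Bézier that starts at (10, 50),
# rises toward the control height y=5, and comes back down at (110, 50).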


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Render an individual arrow head.

    direction (str): Arrow direction; 'left' places the head at the start
        point, anything else places it at the end point.
    x (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arrow head path ('d' attribute).
    """
    arrow_width = 6
    if direction == "left":
        p1, p2, p3 = (x, x - arrow_width + 2, x + arrow_width - 2)
    else:
        p1, p2, p3 = (end, end + arrow_width - 2, end - arrow_width + 2)
    return f"M{p1},{y + 2} L{p2},{y - arrow_width} {p3},{y - arrow_width}"

# Sample output of parse_deps() kept for reference:
# parsed = [{'words': [{'text': 'The', 'tag': 'DET', 'lemma': None}, {'text': 'OnePlus', 'tag': 'PROPN', 'lemma': None}, {'text': '10', 'tag': 'NUM', 'lemma': None}, {'text': 'Pro', 'tag': 'PROPN', 'lemma': None}, {'text': 'is', 'tag': 'AUX', 'lemma': None}, {'text': 'the', 'tag': 'DET', 'lemma': None}, {'text': 'company', 'tag': 'NOUN', 'lemma': None}, {'text': "'s", 'tag': 'PART', 'lemma': None}, {'text': 'first', 'tag': 'ADJ', 'lemma': None}, {'text': 'flagship', 'tag': 'NOUN', 'lemma': None}, {'text': 'phone.', 'tag': 'NOUN', 'lemma': None}], 'arcs': [{'start': 0, 'end': 3, 'label': 'det', 'dir': 'left'}, {'start': 1, 'end': 3, 'label': 'nmod', 'dir': 'left'}, {'start': 1, 'end': 2, 'label': 'nummod', 'dir': 'right'}, {'start': 3, 'end': 4, 'label': 'nsubj', 'dir': 'left'}, {'start': 5, 'end': 6, 'label': 'det', 'dir': 'left'}, {'start': 6, 'end': 10, 'label': 'poss', 'dir': 'left'}, {'start': 6, 'end': 7, 'label': 'case', 'dir': 'right'}, {'start': 8, 'end': 10, 'label': 'amod', 'dir': 'left'}, {'start': 9, 'end': 10, 'label': 'compound', 'dir': 'left'}, {'start': 4, 'end': 10, 'label': 'attr', 'dir': 'right'}], 'settings': {'lang': 'en', 'direction': 'ltr'}}]


def render_sentence_custom(unmatched_list: Dict):
    TPL_DEP_WORDS = """
    <text class="displacy-token" fill="currentColor" text-anchor="start" y="{y}">
        <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
        <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
    </text>
    """

    TPL_DEP_SVG = """
    <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
    """

    arcs_svg = []
    # NOTE: loading the model on every call is slow; consider loading it once
    # at module level if this function is called repeatedly.
    nlp = spacy.load('en_core_web_lg')
    doc = nlp(unmatched_list["sentence"])

    # Earlier draft kept for reference; it derived the arc endpoints from
    # parse_deps() output instead of the indices passed in:
    # words = {}
    # unmatched_list = [parse_deps(doc)]
    # #print(parsed)
    # for i, p in enumerate(unmatched_list):
    #     arcs = p["arcs"]
    #     words = p["words"]
    #     for i, a in enumerate(arcs):
    #         # CHECK CERTAIN DEPS (ALSO ADD/CHANGE BELOW WHEN CHANGING HERE)
    #         if a["label"] == "amod":
    #             couples = (a["start"], a["end"])
    #         elif a["label"] == "pobj":
    #             couples = (a["start"], a["end"])
    #     # couples = (3,5)
    #
    #     x_value_counter = 10
    #     index_counter = 0
    #     svg_words = []
    #     coords_test = []
    #     for i, word in enumerate(words):
    #         word = word["text"]
    #         word = word + " "
    #         pixel_x_length = get_pil_text_size(word, 16, 'arial.ttf')[0]
    #         svg_words.append(TPL_DEP_WORDS.format(text=word, tag="", x=x_value_counter, y=70))
    #         if index_counter >= couples[0] and index_counter <= couples[1]:
    #             coords_test.append(x_value_counter)
    #             x_value_counter += 50
    #         index_counter += 1
    #         x_value_counter += pixel_x_length + 4
    #     for i, a in enumerate(arcs):
    #         if a["label"] == "amod":
    #             arcs_svg.append(render_arrow(a["label"], coords_test[0], coords_test[-1], a["dir"], i))
    #         elif a["label"] == "pobj":
    #             arcs_svg.append(render_arrow(a["label"], coords_test[0], coords_test[-1], a["dir"], i))
    #
    #     content = "".join(svg_words) + "".join(arcs_svg)
    #
    #     full_svg = TPL_DEP_SVG.format(
    #         id=0,
    #         width=1200,  # 600
    #         height=250,  # 125
    #         color="#00000",
    #         bg="#ffffff",
    #         font="Arial",
    #         content=content,
    #         dir="ltr",
    #         lang="en",
    #     )

    x_value_counter = 10
    index_counter = 0
    svg_words = []
    words = unmatched_list["sentence"].split(" ")  # kept for reference; spaCy tokens are used below
    coords_test = []
    # print(unmatched_list)
    # print(words)
    # print("NOW")
    direction_current = "rtl"
    if unmatched_list["cur_word_index"] < unmatched_list["target_word_index"]:
        min_index = unmatched_list["cur_word_index"]
        max_index = unmatched_list["target_word_index"]
        direction_current = "left"
    else:
        max_index = unmatched_list["cur_word_index"]
        min_index = unmatched_list["target_word_index"]
    for i, token in enumerate(doc):
        word = str(token) + " "
        pixel_x_length = get_pil_text_size(word, 16, 'arial.ttf')[0]
        svg_words.append(TPL_DEP_WORDS.format(text=word, tag="", x=x_value_counter, y=70))
        # Record the x-positions of the words between the two endpoints so the
        # arc can later be drawn from the first to the last of them.
        if min_index <= index_counter <= max_index:
            coords_test.append(x_value_counter)
        if index_counter < max_index - 1:
            x_value_counter += 50
        index_counter += 1
        x_value_counter += pixel_x_length + 4
    # TODO: DYNAMIC DIRECTION MAKING (SHOULD GIVE WITH DICT I THINK)
    # print(coords_test)
    arcs_svg.append(render_arrow(unmatched_list['dep'], coords_test[0], coords_test[-1], direction_current, i))
    content = "".join(svg_words) + "".join(arcs_svg)
    full_svg = TPL_DEP_SVG.format(
        id=0,
        width=1200,  # 600
        height=75,  # 125
        color="#000000",  # was "#00000", an invalid 5-digit hex value
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )
    return full_svg
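
# Input contract, inferred from the reads above: render_sentence_custom
# expects a dict with "sentence" (raw text), "cur_word_index" and
# "target_word_index" (spaCy token indices of the two endpoints), and
# "dep" (the dependency label to draw on the arc).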


def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
    """Generate dependency parse in {'words': [], 'arcs': []} format.

    orig_doc (Doc): Document to parse.
    options (dict): Renderer options ('collapse_phrases', 'collapse_punct',
        'fine_grained', 'add_lemma').
    RETURNS (dict): Generated dependency parse keyed by words and arcs.
    """
    doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data"]))
    if not doc.has_annotation("DEP"):
        print("WARNING: no dependency annotation found on the Doc")
    if options.get("collapse_phrases", False):
        with doc.retokenize() as retokenizer:
            for np in list(doc.noun_chunks):
                attrs = {
                    "tag": np.root.tag_,
                    "lemma": np.root.lemma_,
                    "ent_type": np.root.ent_type_,
                }
                retokenizer.merge(np, attrs=attrs)
    if options.get("collapse_punct", True):
        spans = []
        for word in doc[:-1]:
            if word.is_punct or not word.nbor(1).is_punct:
                continue
            start = word.i
            end = word.i + 1
            while end < len(doc) and doc[end].is_punct:
                end += 1
            span = doc[start:end]
            spans.append((span, word.tag_, word.lemma_, word.ent_type_))
        with doc.retokenize() as retokenizer:
            for span, tag, lemma, ent_type in spans:
                attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
                retokenizer.merge(span, attrs=attrs)
    fine_grained = options.get("fine_grained")
    add_lemma = options.get("add_lemma")
    words = [
        {
            "text": w.text,
            "tag": w.tag_ if fine_grained else w.pos_,
            "lemma": w.lemma_ if add_lemma else None,
        }
        for w in doc
    ]
    arcs = []
    for word in doc:
        if word.i < word.head.i:
            arcs.append(
                {"start": word.i, "end": word.head.i, "label": word.dep_, "dir": "left"}
            )
        elif word.i > word.head.i:
            arcs.append(
                {
                    "start": word.head.i,
                    "end": word.i,
                    "label": word.dep_,
                    "dir": "right",
                }
            )
    return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}


def get_doc_settings(doc: Doc) -> Dict[str, Any]:
    return {
        "lang": doc.lang_,
        "direction": doc.vocab.writing_system.get("direction", "ltr"),
    }
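

# Minimal smoke test - a sketch, assuming the en_core_web_lg model and an
# "arial.ttf" font file are available on this machine. Writes the rendered
# 'amod' arc between "first" and "phone" to arc_demo.svg.
if __name__ == "__main__":
    example = {
        "sentence": "The OnePlus 10 Pro is the company's first flagship phone.",
        "cur_word_index": 8,  # spaCy token index of "first"
        "target_word_index": 10,  # spaCy token index of "phone"
        "dep": "amod",
    }
    with open("arc_demo.svg", "w", encoding="utf-8") as f:
        f.write(render_sentence_custom(example))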