Spaces:
Build error
Build error
import functools
from typing import Dict

from PIL import ImageFont
# SVG markup for one token. The word sits at (x, y) and its tag is placed on the
# same x, 2em lower (dy="2em"). Placeholders: {x}, {y}, {text}, {tag}.
TPL_DEP_WORDS = """
<text class="displacy-token" fill="currentColor" text-anchor="start" y="{y}">
<tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""
# Outer <svg> wrapper. {content} receives the concatenated word and arc markup.
# Size, colours, font and text direction are all supplied by the caller.
TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""
# One labelled arrow. The arc <path> gets id "arrow-{id}-{i}", and the label's
# <textPath> references that id via xlink:href to run along the arc. The stroke,
# label and arrowhead are all drawn in red.
TPL_DEP_ARCS = """
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="red"/>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="red" text-anchor="middle">{label}</textPath>
</text>
<path class="displacy-arrowhead" d="{head}" fill="red"/>
</g>
"""
@functools.lru_cache(maxsize=32)
def _load_font(font_name, font_size):
    """Load a TrueType font once per (name, size) and reuse it on later calls."""
    return ImageFont.truetype(font_name, font_size)


def get_pil_text_size(text, font_size, font_name):
    """Measure text rendered in a TrueType font with Pillow.

    text (str): Text to measure. Trailing spaces count toward the width.
    font_size (int): Font size passed to ImageFont.truetype.
    font_name (str): Font file name or path passed to ImageFont.truetype.
    RETURNS (tuple): (width, height) in pixels.
    """
    font = _load_font(font_name, font_size)
    if hasattr(font, "getsize"):
        # Pillow < 10 still has getsize(); keep it so results are unchanged there.
        return font.getsize(text)
    # Pillow 10 removed getsize(). getlength() is the advance width, which keeps
    # trailing spaces that callers rely on for spacing. getbbox() gives the height.
    _left, top, _right, bottom = font.getbbox(text)
    return round(font.getlength(text)), bottom - top
def render_arrow(
    label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render one labelled dependency arrow as an SVG fragment.

    label (str): Dependency label drawn along the arc.
    start (int): X-coordinate of the left word the arc spans from.
    end (int): X-coordinate of the right word the arc spans to.
    direction (str): "left" puts the arrowhead at the start end. Any other value
        puts it at the end end. "rtl" also sets the label's textPath side to "right".
    i (int): Unique ID, typically the arrow index. It becomes part of the arc's
        element id, which the label's textPath references.
    RETURNS (str): Rendered SVG markup.
    """
    # Both endpoints are shifted 10px right of the given word x-coordinates.
    x_start = start + 10
    x_end = end + 10
    # The arc ends sit at y=50 and the Bezier control points at y=5.
    arc = get_arc(x_start, 50, 5, x_end)
    arrowhead = get_arrowhead(direction, x_start, 50, x_end)
    # NOTE(review): "rtl" is used by the caller as an arrow direction, and here it
    # switches the textPath to side="right". Confirm the flipped label is intended.
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=i,  # previously hard-coded to 0, which ignored the documented parameter
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )
def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Build the path data for a dependency arc.

    The arc is one cubic Bezier segment from (x_start, y) to (x_end, y). Its
    two control points are (x_start, y_curve) and (x_end, y_curve).

    x_start (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate shared by the start and end points.
    y_curve (int): Y-coordinate of both Bezier control points.
    x_end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the path's 'd' attribute.
    """
    move_to = f"M{x_start},{y}"
    curve_to = f"C{x_start},{y_curve} {x_end},{y_curve} {x_end},{y}"
    return f"{move_to} {curve_to}"
def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Build the path data for an arrowhead at one end of an arc.

    direction (str): "left" puts the tip at x. Any other value puts it at end.
    x (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate of the arc's start and end points.
    end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the arrowhead path's 'd' attribute.
    """
    arrow_width = 6
    points_left = direction == "left"
    tip = x if points_left else end
    if points_left:
        barbs = (tip - arrow_width + 2, tip + arrow_width - 2)
    else:
        barbs = (tip + arrow_width - 2, tip - arrow_width + 2)
    barb_y = y - arrow_width
    return "M{},{} L{},{} {},{}".format(tip, y + 2, barbs[0], barb_y, barbs[1], barb_y)
def render_sentence_custom(unmatched_list: Dict, nlp):
    """Render a sentence as SVG with one labelled arc between two of its tokens.

    unmatched_list (Dict): Mapping with these keys:
        "sentence" is the text passed to `nlp`.
        "cur_word_index" and "target_word_index" are the token positions of
        the arc's endpoints.
        "dep" is the label drawn on the arc.
    nlp: Callable that returns an iterable of tokens for the sentence.
        Presumably a spaCy pipeline; confirm against callers.
    RETURNS (str): The complete SVG markup. If no token position lies between
        the two indices (e.g. an empty sentence), the words are rendered
        without an arc instead of raising IndexError/NameError.
    """
    doc = nlp(unmatched_list["sentence"])
    cur_index = unmatched_list["cur_word_index"]
    target_index = unmatched_list["target_word_index"]
    if cur_index < target_index:
        min_index, max_index = cur_index, target_index
        direction_current = "left"  # get_arrowhead puts the head at the arc's left end
    else:
        min_index, max_index = target_index, cur_index
        # Any value other than "left" puts the head at the right end. render_arrow
        # also uses "rtl" to choose the label side.
        direction_current = "rtl"

    x_value_counter = 10  # x of the first word
    svg_words = []
    words_under_arc = []  # x-coordinates of the tokens the arc spans
    for index, token in enumerate(doc):
        word = str(token) + " "
        pixel_x_length = get_pil_text_size(word, 16, "arial.ttf")[0]
        svg_words.append(
            TPL_DEP_WORDS.format(text=word, tag="", x=x_value_counter, y=70)
        )
        if min_index <= index <= max_index:
            words_under_arc.append(x_value_counter)
            # Extra spacing between tokens inside the arc, presumably to leave
            # room for the label.
            # NOTE(review): the gap just before max_index is not widened. The
            # source this came from had lost its indentation, so nesting this
            # check inside the span test is a reconstruction. Confirm both.
            if index < max_index - 1:
                x_value_counter += 50
        x_value_counter += pixel_x_length + 4  # advance past the word plus a 4px gap

    arcs_svg = []
    if words_under_arc:
        # This SVG holds a single arrow, so its index is 0.
        arcs_svg.append(
            render_arrow(
                unmatched_list["dep"],
                words_under_arc[0],
                words_under_arc[-1],
                direction_current,
                0,
            )
        )
    content = "".join(svg_words) + "".join(arcs_svg)
    return TPL_DEP_SVG.format(
        id=0,
        # Grow past the default 1200px when the words need more room,
        # instead of clipping them.
        width=max(1200, x_value_counter),
        height=75,
        color="#000000",  # was "#00000", which is not a valid CSS colour
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )