joint-demo / app.py
Shaltiel's picture
Reversed arrow direction
2cd4845
raw
history blame
4.48 kB
import base64
import html
import json
from io import BytesIO

import streamlit as st
from graphviz import Digraph
from transformers import AutoModel, AutoTokenizer
def display_tree(output):
    """Render a dependency tree as an inline SVG inside a scrollable container.

    Args:
        output: Words of one sentence, in order. Each item is a dict with
            ``word`` (node label), ``dep_head_idx`` (0-based index of the head
            word within *output*, or -1 for the root) and ``dep_func`` (the
            label drawn on the dependency arc).

    Nothing is drawn when *output* is empty.
    """
    if not output:
        # Nothing to draw for an empty parse.
        return

    image_format = 'svg'
    # Graphviz "size" is the maximum drawing size in inches ("width,height");
    # the allowed width grows with the number of words.
    size = f'{len(output)},5'

    dot = Digraph(engine='dot', format=image_format)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi='300')

    for i, word_info in enumerate(output):
        node_id = str(i)
        dot.node(node_id, word_info['word'])
        if i > 0:
            # Invisible ordering edge from this word back to the previous one.
            # With rankdir=LR this places each word to the left of its
            # predecessor, presumably to get right-to-left (Hebrew) order.
            dot.edge(node_id, str(i - 1), style='invis')
        head_idx = word_info['dep_head_idx']
        if head_idx != -1:
            # constraint='False': the dependency arc is not used for ranking,
            # so only the invisible edges above fix the horizontal word order.
            dot.edge(str(head_idx), node_id, label=word_info['dep_func'], constraint='False')

    # Render in memory and inline the SVG as a base64 data URI; no file is
    # written to disk.
    svg_b64 = base64.b64encode(dot.pipe(format=image_format)).decode()
    # Kept free of leading indentation: indented lines inside Markdown can be
    # interpreted as code blocks instead of HTML.
    container_html = (
        '<div style="height:250px; width:75vw; overflow:auto; '
        'border:1px solid #ccc; margin-left:-15vw">'
        f'<img src="data:image/svg+xml;base64,{svg_b64}" '
        'style="display: block; margin: auto; max-height: 240px;">'
        '</div>'
    )
    st.markdown(container_html, unsafe_allow_html=True)
def display_download(disp_string):
    """Show a button that lets the user save *disp_string* as ``parsed_output.txt``.

    The text is UTF-8 encoded and handed to Streamlit as an in-memory file.
    """
    button_kwargs = dict(
        label="⬇️ Download text file",
        file_name="parsed_output.txt",
        mime="text/plain",
    )
    payload = BytesIO(disp_string.encode())
    st.download_button(data=payload, **button_kwargs)
# Streamlit app title
st.title('DictaBERT-Joint Visualizer')


@st.cache_resource
def _load_model(token):
    """Load the DictaBERT-Joint tokenizer and model once and reuse them.

    Streamlit re-executes this script top to bottom on every widget
    interaction; caching keeps the model from being re-loaded on each rerun.
    NOTE(review): cached resources are shared by all sessions, so
    ``model.predict`` is assumed safe to call concurrently -- confirm.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token)
    mdl = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token, trust_remote_code=True)
    mdl.eval()
    return tok, mdl


# Hugging Face token; assumes it has been configured in Streamlit secrets.
hf_token = st.secrets["HF_TOKEN"]
tokenizer, model = _load_model(hf_token)
# Controls: MST toggle, output format (lower-cased because model.predict is
# called with it as output_style), and the sentence input box.
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)
_style_options = ('JSON', 'UD', 'IAHLT_UD')
output_style = st.selectbox('Output Style: ', _style_options, index=1).lower()
sentence = st.text_input('Enter a sentence to analyze:')
# Column names of a CoNLL-U token row, used as the header of the UD table.
_UD_COLUMNS = ["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]


def _ud_to_tree(token_lines):
    """Convert CoNLL-U token rows into the word list expected by ``display_tree``.

    Blank lines, multi-word token ranges (IDs such as ``3-4``) and empty nodes
    (IDs such as ``3.1``) are skipped, so the remaining rows are the syntactic
    words. Their 1-based HEAD therefore maps to a 0-based index in the returned
    list; HEAD 0 (the root) becomes -1.
    """
    tree = []
    for row in token_lines:
        if not row.strip():
            continue
        fields = row.split('\t')
        token_id = fields[0]
        if '-' in token_id or '.' in token_id:
            continue
        tree.append(dict(word=fields[1], dep_head_idx=int(fields[6]) - 1, dep_func=fields[7]))
    return tree


def _escape_ud_cell(cell):
    """Escape one CoNLL-U field for a Markdown table shown with unsafe_allow_html=True."""
    # Escape HTML first so user-typed "<", ">" and "&" are shown literally
    # instead of being parsed as markup; the entities added below must not be
    # escaped again.
    cell = html.escape(cell, quote=False)
    return cell.replace('_', '\\_').replace('|', '&#124;').replace(':', '&colon;')


def _ud_to_markdown(ud_lines):
    """Build the right-to-left Markdown/HTML table for a UD parse.

    The first two lines are emitted with a ``##`` prefix (Markdown heading
    markers; presumably they are CoNLL-U ``#`` comment lines -- confirm).
    Every following non-blank line becomes one table row.
    """
    parts = ["<div dir='rtl' style='text-align: right;'>\n\n"]
    for header in ud_lines[:2]:
        parts.append("##" + html.escape(header, quote=False) + "\n")
    parts.append("| " + " | ".join(_UD_COLUMNS) + " |\n")
    parts.append("| " + " | ".join(["---"] * len(_UD_COLUMNS)) + " |\n")
    for row in ud_lines[2:]:
        if not row.strip():
            continue
        cells = [_escape_ud_cell(cell) for cell in row.split('\t')]
        parts.append("| " + " | ".join(cells) + " |\n")
    parts.append("</div>")
    return "".join(parts)


if sentence.strip():
    # Echo the input sentence.
    st.text(sentence)
    # predict takes a batch of sentences; send one and unwrap its result.
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst,
                           output_style=output_style)[0]

    if output_style in ('ud', 'iahlt_ud'):
        # Treated here as a list of lines: two header lines followed by one
        # tab-separated row per token.
        ud_output = output
        display_tree(_ud_to_tree(ud_output[2:]))
        display_download('\n'.join(ud_output))
        st.markdown(_ud_to_markdown(ud_output), unsafe_allow_html=True)
    else:
        # JSON style: each token carries its syntax info (word, head, label).
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)
        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)
        # st.code, unlike a hand-built ``` fence, cannot be broken by
        # backticks contained in the user's sentence.
        st.code(json_output, language='json')