# DictaBERT-Joint Visualizer — Streamlit app (Hugging Face Space)
from io import BytesIO | |
import streamlit as st | |
import base64 | |
from transformers import AutoModel, AutoTokenizer | |
from graphviz import Digraph | |
import json | |
def display_tree(output):
    """Render a dependency tree as an inline SVG in a scrollable container.

    output: list of dicts with keys 'word', 'dep_head_idx' (0-based index of
    the head word, -1 for the sentence root) and 'dep_func' (dependency label).
    """
    # Width grows with the number of words; height fixed at 5 (graphviz inches).
    size = f'{len(output)},5'
    dot = Digraph(engine='dot', format='svg')
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi='300')

    # Add one node per word plus its dependency edge.
    for i, word_info in enumerate(output):
        word = word_info['word']
        head_idx = word_info['dep_head_idx']
        dep_func = word_info['dep_func']
        dot.node(str(i), word)
        # Invisible edge from each word to its predecessor pins the nodes in
        # sentence order (rendered right-to-left for the RTL text).
        if i > 0:
            dot.edge(str(i), str(i - 1), style='invis')
        # -1 marks the root, which has no incoming dependency edge.
        if head_idx != -1:
            # constraint='false' keeps dependency edges from affecting ranking.
            dot.edge(str(head_idx), str(i), label=dep_func, constraint='false')

    # Pipe the SVG straight into a base64 data URI — no need to also
    # dot.render() a file on disk, since nothing reads it.
    svg_b64 = base64.b64encode(dot.pipe(format='svg')).decode()
    # Scrollable div lets the user pan wide trees horizontally.
    st.markdown(
        f"""
        <div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
        <img src="data:image/svg+xml;base64,{svg_b64}"
        style="display: block; margin: auto; max-height: 240px;">
        </div>
        """, unsafe_allow_html=True)
def display_download(disp_string):
    """Offer *disp_string* for download as a plain-text file."""
    payload = BytesIO(disp_string.encode())
    st.download_button(
        label="⬇️ Download text file",
        data=payload,
        file_name="parsed_output.txt",
        mime="text/plain",
    )
# Streamlit app title
st.title('DictaBERT-Joint Visualizer')

# Load Hugging Face token — assumes HF_TOKEN is configured in the app's
# Streamlit secrets (secrets.toml / Spaces settings).
hf_token = st.secrets["HF_TOKEN"]

# Authenticate and load the joint morphology+syntax model once at startup.
# NOTE(review): `use_auth_token` is deprecated in recent transformers in
# favor of `token=` — confirm the pinned transformers version before changing.
tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token)
model = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token, trust_remote_code=True)
model.eval()  # inference only

# UI controls: MST decoding toggle and output-format selector
# (selectbox values are lowercased to match model.predict's output_style).
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)
output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input
sentence = st.text_input('Enter a sentence to analyze:')
if sentence:
    # Echo the input sentence back to the user.
    st.text(sentence)

    # Model prediction: batch of one sentence, take the single result.
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]

    if output_style in ('ud', 'iahlt_ud'):
        ud_output = output
        # Convert the CoNLL-U lines (after the two leading comment lines)
        # to the [dict(word, dep_head_idx, dep_func)] format display_tree expects.
        tree = []
        for conllu_line in ud_output[2:]:
            parts = conllu_line.split('\t')
            # Skip multi-word-token range lines (IDs like "1-2"): they are not
            # syntactic nodes, so list indices stay aligned with CoNLL-U IDs.
            if '-' in parts[0]:
                continue
            # HEAD is 1-based with 0 = root; shift to 0-based with -1 = root.
            tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
        display_tree(tree)
        display_download('\n'.join(ud_output))

        # Construct the CoNLL-U table as a single Markdown string in an RTL div.
        table_md = "<div dir='rtl' style='text-align: right;'>\n\n"
        # The two leading CoNLL-U comment lines become headings.
        table_md += "##" + ud_output[0] + "\n"
        table_md += "##" + ud_output[1] + "\n"
        # Table header
        table_md += "| " + " | ".join(["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]) + " |\n"
        # Table alignment
        table_md += "| " + " | ".join(["---"]*10) + " |\n"
        for line in ud_output[2:]:
            # Escape characters that would break the rendered Markdown:
            # '_' triggers emphasis; literal '|' (common in FEATS/DEPS)
            # splits table cells, so use the HTML entity; ':' is escaped as
            # an entity too (presumably to block :emoji: shortcodes).
            cells = line.replace('_', '\\_').replace('|', '&#124;').replace(':', '&#58;').split('\t')
            table_md += "| " + " | ".join(cells) + " |\n"
        table_md += "</div>"  # Close the RTL div
        # Display the table using a single markdown call.
        st.markdown(table_md, unsafe_allow_html=True)
    else:
        # JSON output style: syntax info lives under each token's 'syntax' key.
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)
        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)
        # And show the full JSON in a code block.
        st.markdown("```json\n" + json_output + "\n```")