Spaces:
Runtime error
Runtime error
# HF space creator starting from an sklearn model | |
from __future__ import annotations | |
import base64 | |
import glob | |
import io | |
import os | |
import pickle | |
import re | |
import shutil | |
from pathlib import Path | |
from tempfile import mkdtemp | |
import pandas as pd | |
import sklearn | |
import streamlit as st | |
from huggingface_hub import hf_hub_download | |
from sklearn.base import BaseEstimator | |
import skops.io as sio | |
from skops import card, hub_utils | |
st.set_page_config(layout="wide")
st.title("Skops space creator for sklearn")

PLACEHOLDER = "[More Information Needed]"
PLOT_PREFIX = "__plot__:"

# store session state
if "custom_sections" not in st.session_state:
    st.session_state.custom_sections = {}

# Streamlit re-executes this whole script on every widget interaction, so
# calling mkdtemp unconditionally created (and leaked) two new directories per
# rerun. Create them once per browser session and reuse them afterwards.
if "tmp_path" not in st.session_state:
    # the tmp_path is used to upload the sklearn model (and plot files) to
    st.session_state.tmp_path = Path(mkdtemp(prefix="skops-"))
if "hf_path" not in st.session_state:
    # the hf_path is the actual repo used for init()
    st.session_state.hf_path = Path(mkdtemp(prefix="skops-"))
tmp_path = st.session_state.tmp_path
hf_path = st.session_state.hf_path

# a hacky way to "persist" custom sections
# NOTE(review): not referenced anywhere in the visible code; custom sections
# are currently kept in st.session_state instead.
CUSTOM_SECTIONS_CACHE_FILE = ".custom-sections.json"
def _clear_custom_section_cache():
    """Forget every custom section and plot stored for this session.

    Used as the ``on_click`` callback of the "Remove all" button; the dict in
    ``st.session_state.custom_sections`` is emptied in place.
    """
    sections = st.session_state.custom_sections
    sections.clear()
def _remove_custom_section(key):
    """Delete the custom section *key* and all of its subsections.

    A subsection is any stored name starting with ``key + "/"`` or
    ``key + " /"``. Used as the ``on_click`` callback of the per-section
    "Remove" buttons.
    """
    sections = st.session_state.custom_sections
    child_prefixes = (key + "/", key + " /")
    # collect first, delete afterwards: a dict must not change size while
    # it is being iterated
    doomed = [
        name for name in sections if name == key or name.startswith(child_prefixes)
    ]
    for name in doomed:
        del sections[name]
def _clear_repo(path): | |
for file_path in glob.glob(str(Path(path) / "*")): | |
if os.path.isfile(file_path) or os.path.islink(file_path): | |
os.unlink(file_path) | |
elif os.path.isdir(file_path): | |
shutil.rmtree(file_path) | |
def _write_plot(plot_name, plot_file): | |
with open(plot_name, "wb") as f: | |
f.write(plot_file) | |
def _parse_requirements(text):
    """Turn the requirements text area into a list of requirement strings.

    One requirement per line. Surrounding whitespace and a trailing comma are
    removed, and blank lines are dropped; previously they became ``""``
    entries in the repo's requirements.
    """
    reqs = (line.strip().rstrip(",").strip() for line in text.splitlines())
    return [req for req in reqs if req]


def init_repo():
    """(Re)create the temporary HF repo in ``hf_path`` from the sidebar state.

    Empties ``hf_path``, dumps the module-level ``model`` to
    ``tmp_path / "model.skops"`` with skops, and calls ``hub_utils.init`` with
    the selected ``task``, the uploaded ``data`` and the parsed
    ``requirements``. Errors are reported but not raised (best effort), so a
    widget callback cannot take the app down.
    """
    _clear_repo(hf_path)
    try:
        file_name = tmp_path / "model.skops"
        sio.dump(model, file_name)
        hub_utils.init(
            model=file_name,
            dst=hf_path,
            task=task,
            data=data,
            requirements=_parse_requirements(requirements),
        )
    except Exception as exc:
        # keep the server-side log, but also surface the problem in the UI;
        # otherwise the user only sees a confusing follow-up error later
        print("Uh oh, something went wrong when initializing the repo:", exc)
        st.error(f"Uh oh, something went wrong when initializing the repo: {exc}")
def load_model():
    """Unpickle the uploaded model file and check it is an sklearn estimator.

    Reads the module-level ``model_file`` upload (from the sidebar's
    ``st.file_uploader``).

    Returns
    -------
    The unpickled estimator, or ``None`` when nothing has been uploaded.

    Raises
    ------
    TypeError
        If the unpickled object is not a ``sklearn.base.BaseEstimator``.
    """
    if model_file is None:
        return None
    # SECURITY(review): pickle.loads executes arbitrary code embedded in the
    # payload, and this payload is a file uploaded by whoever uses the app.
    # Only deploy this where uploaders are trusted, or switch the upload format
    # to one with a safe loader.
    model = pickle.loads(model_file.getvalue())
    # explicit raise rather than assert: asserts are stripped under python -O
    if not isinstance(model, BaseEstimator):
        raise TypeError(
            f"model must be an sklearn model, got {type(model).__name__!r}"
        )
    return model
def load_data():
    """Parse the uploaded CSV (module-level ``data_file``) into a DataFrame.

    Returns ``None`` when nothing has been uploaded.
    """
    if data_file is not None:
        buffer = io.BytesIO(data_file.getvalue())
        return pd.read_csv(buffer)
    return None
def _parse_metrics(metrics): | |
metrics_table = {} | |
for line in metrics.splitlines(): | |
line = line.strip() | |
name, _, val = line.partition("=") | |
try: | |
# try to coerce to float but don't error if it fails | |
val = float(val.strip()) | |
except ValueError: | |
pass | |
metrics_table[name.strip()] = val | |
return metrics_table | |
def _load_model_card_from_repo(repo_id: str) -> card.Card:
    """Download ``README.md`` from the hub repo *repo_id* and parse it.

    Parameters
    ----------
    repo_id : str
        Hugging Face hub repo id, e.g. ``"gpt2"``.

    Returns
    -------
    card.Card
        The model card parsed by ``card.parse_modelcard``.
    """
    # The return annotation used to be the bare name ``Card``, which is never
    # imported in this module. It only avoided a NameError because annotations
    # are lazy strings here (``from __future__ import annotations``).
    path = hf_hub_download(repo_id, "README.md")
    return card.parse_modelcard(path)
def _create_model_card():
    """Build the model card from the sidebar inputs and the custom sections.

    Re-initializes the temporary repo, then either loads an existing card from
    ``model_card_repo`` or creates a new ``card.Card`` for ``model`` with
    metadata read from ``hf_path``. Non-empty sidebar fields are added as
    sections (metrics through ``add_metrics``). Finally every entry of
    ``st.session_state.custom_sections`` is added: names prefixed with
    ``PLOT_PREFIX`` as plots, all others as text sections.

    Returns
    -------
    The populated model card.
    """
    init_repo()

    if model_card_repo:
        # load existing model card
        model_card = _load_model_card_from_repo(model_card_repo)
    else:
        # create new model card
        model_card = card.Card(
            model=model, metadata=card.metadata_from_config(hf_path)
        )

    leading_sections = (
        ("Model description", model_description),
        ("Model description/Intended uses & limitations", intended_uses),
    )
    for title, content in leading_sections:
        if content:
            model_card.add(**{title: content})

    if metrics:
        model_card.add_metrics(**_parse_metrics(metrics))

    trailing_sections = (
        ("Model Card Authors", authors),
        ("Model Card Contact", contact),
        ("Citation", citation),
    )
    for title, content in trailing_sections:
        if content:
            model_card.add(**{title: content})

    for name, content in st.session_state.custom_sections.items():
        if not name:
            continue
        if name.startswith(PLOT_PREFIX):
            model_card.add_plot(**{name[len(PLOT_PREFIX) :]: content})  # noqa
        else:
            model_card.add(**{name: content})

    return model_card
def _process_card_for_rendering(rendered: str) -> tuple[str, str]: | |
idx = rendered[1:].index("\n---") + 1 | |
metadata = rendered[3:idx] | |
rendered = rendered[idx + 4 :] # noqa: E203 | |
# below is a hack to display the images in streamlit | |
# https://discuss.streamlit.io/t/image-in-markdown/13274/10 The problem is | |
# that streamlit does not display images in markdown, so we need to replace | |
# them with html. However, we only want that in the rendered markdown, not | |
# in the card that is produced for the hub | |
def markdown_images(markdown): | |
# example image markdown: | |
# ![Test image](images/test.png "Alternate text") | |
images = re.findall( | |
r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))', | |
markdown, | |
) | |
return images | |
def img_to_bytes(img_path): | |
img_bytes = Path(img_path).read_bytes() | |
encoded = base64.b64encode(img_bytes).decode() | |
return encoded | |
def img_to_html(img_path, img_alt): | |
img_format = img_path.split(".")[-1] | |
img_html = ( | |
f'<img src="data:image/{img_format.lower()};' | |
f'base64,{img_to_bytes(img_path)}" ' | |
f'alt="{img_alt}" ' | |
'style="max-width: 100%;">' | |
) | |
return img_html | |
def markdown_insert_images(markdown): | |
images = markdown_images(markdown) | |
for image in images: | |
image_markdown = image[0] | |
image_alt = image[1] | |
image_path = image[2] | |
markdown = markdown.replace( | |
image_markdown, img_to_html(image_path, image_alt) | |
) | |
return markdown | |
rendered_with_img = markdown_insert_images(rendered) | |
return metadata, rendered_with_img | |
def display_model_card(model_card):
    """Show *model_card* in the main area: metadata in an expander, then the body.

    Does nothing when *model_card* is falsy (e.g. ``None`` before a model and
    data have been uploaded).
    """
    if model_card:
        metadata, body = _process_card_for_rendering(model_card.render())
        # the YAML metadata is kept out of the main body, behind an expander
        with st.expander("show metadata"):
            st.text(metadata)
        st.markdown(body, unsafe_allow_html=True)
def download_model_card(model_card):
    """Return *model_card* rendered as markdown, or ``""`` when there is no card."""
    return "" if model_card is None else model_card.render()
def add_custom_section():
    """Form callback: store the submitted section among the custom sections.

    Copies the form values into the module-level ``section_name`` and
    ``section_content`` (this is required to "refresh" these variables). When
    both are non-empty, records ``section_name -> section_content`` in
    ``st.session_state.custom_sections``.
    """
    global section_name, section_content
    state = st.session_state
    section_name, section_content = state.key_section_name, state.key_section_content
    if section_name and section_content:
        state.custom_sections[section_name] = section_content
def add_custom_plot():
    """Form callback: save the uploaded figure and register it as a plot section.

    Reads ``key_plot_name`` / ``key_plot_file`` from the form's session state.
    When both are set, writes the upload to a fresh directory under
    ``tmp_path`` and stores its path under ``PLOT_PREFIX + plot_name`` in
    ``st.session_state.custom_sections``.
    """
    plot_name = st.session_state.key_plot_name
    plot_file = st.session_state.key_plot_file
    if not plot_name or not plot_file:
        return

    # One directory per upload. Previously every file went to
    # tmp_path/<upload name>, so adding two figures with the same file name
    # silently replaced the first plot's image.
    plot_dir = Path(mkdtemp(prefix="plot-", dir=str(tmp_path)))
    file_path = plot_dir / plot_file.name.replace(" ", "_")
    file_path.write_bytes(plot_file.getvalue())
    st.session_state.custom_sections[PLOT_PREFIX + plot_name] = str(file_path)
with st.sidebar:
    # This contains every element required to edit the model card
    model = None
    data = None
    section_name = None
    section_content = None

    st.title("Model Card Editor")
    # No on_change callbacks on these widgets: Streamlit reruns the script
    # after every widget change, and the code below loads the uploads and
    # re-inits the repo on that rerun. The callbacks only duplicated that
    # work; load_model/load_data results were discarded, and the model upload
    # was unpickled a second time.
    model_file = st.file_uploader("Upload a model*")
    data_file = st.file_uploader("Upload X data (csv)*", type=["csv"])
    task = st.selectbox(
        label="Choose the task type*",
        options=[
            "tabular-classification",
            "tabular-regression",
            "text-classification",
            "text-regression",
        ],
    )
    requirements = st.text_area(
        label="Requirements*",
        value=f"scikit-learn=={sklearn.__version__}\n",
    )

    if model_file is not None:
        model = load_model()
    if data_file is not None:
        data = load_data()

    if model is not None and data is not None:
        init_repo()

    model_card_repo = st.text_input(
        "Optional: HF repo to load model card from (e.g. 'gpt2'), "
        "leave empty to use default skops template",
        value="",
    )

    # DEFAULT SKOPS SECTIONS
    # st.text_area's ``height`` is in pixels. The former height=2 / height=5
    # were far too small, and recent Streamlit versions raise for text areas
    # shorter than 68px, so the default height is used instead.
    if not model_card_repo:
        model_description = st.text_input("Model description", value=PLACEHOLDER)
        intended_uses = st.text_area("Intended uses & limitations", value=PLACEHOLDER)
        metrics = st.text_area("Metrics (e.g. 'accuracy = 0.95'), one metric per line")
        authors = st.text_area(
            "Authors",
            value="This model card is written by following authors:\n\n" + PLACEHOLDER,
        )
        contact = st.text_area(
            "Contact",
            value="You can contact the model card authors through following channels:\n\n"
            + PLACEHOLDER,
        )
        citation = st.text_area(
            "Citation",
            value="Below you can find information related to citation.\n\nBibTex:\n\n```\n"
            + PLACEHOLDER
            + "\n```",
        )
    else:
        # the card comes from the given repo; don't add the template sections
        model_description = None
        intended_uses = None
        metrics = None
        authors = None
        contact = None
        citation = None

    # ADD A CUSTOM SECTIONS
    with st.form("custom-section", clear_on_submit=True):
        section_name = st.text_input(
            "Section name (use '/' for subsections, e.g. 'Model description/My new"
            " section')",
            key="key_section_name",
        )
        section_content = st.text_area(
            "Content of the new section", key="key_section_content"
        )
        st.form_submit_button("Create new section", on_click=add_custom_section)

    # ADD A PLOT
    with st.form("custom-plots", clear_on_submit=True):
        plot_name = st.text_input(
            "Section name (use '/' for subsections, e.g. 'Model description/My new"
            " plot')",
            key="key_plot_name",
        )
        plot_file = st.file_uploader("Upload a figure*", key="key_plot_file")
        st.form_submit_button("Add plot", on_click=add_custom_plot)

    # one "Remove" button per stored custom section / plot
    for key in st.session_state.custom_sections:
        if not key:
            continue
        if key.startswith(PLOT_PREFIX):
            st.button(
                f"Remove plot '{key[len(PLOT_PREFIX):]}'",
                on_click=_remove_custom_section,
                args=(key,),
            )
        else:
            st.button(
                f"Remove section '{key}'", on_click=_remove_custom_section, args=(key,)
            )

    if st.session_state.custom_sections:
        st.button(
            f"Remove all ({len(st.session_state.custom_sections)}) custom elements",
            on_click=_clear_custom_section_cache,
        )
model_card = None
# st.text shows its argument verbatim (the asterisks appeared literally), and
# the data hint was missing its closing asterisk; st.markdown renders italics.
if model is None:
    st.markdown("*add a model to render the model card*")
if data is None:
    st.markdown("*add data to render the model card*")
if (model is not None) and (data is not None):
    model_card = _create_model_card()

# this contains the rendered model card
rendered = download_model_card(model_card)
if rendered:
    # give the download a meaningful name and type instead of a generic one
    st.download_button(
        label="Download model card (markdown format)",
        data=rendered,
        file_name="README.md",
        mime="text/markdown",
    )
display_model_card(model_card)