# pylint: disable=no-member
import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from collections import defaultdict
import numpy as np
# Shared Hugging Face Hub API client reused by every repo query in this app.
HF_API = HfApi()
def apply_power_scaling(sizes: list, exponent=0.2) -> list:
    """Compress the dynamic range of sizes with a power transform.

    Keeps small files visible next to huge ones in the treemap.
    ``None`` entries (missing metadata) are mapped to 0.
    """
    scaled = []
    for size in sizes:
        scaled.append(0 if size is None else size**exponent)
    return scaled
def count_chunks(sizes: list | int) -> list:
"""Count the number of chunks, which are 64KB each in size; always roundup"""
if isinstance(sizes, int):
return int(np.ceil(sizes / 64_000))
return [int(np.ceil(size / 64_000)) if size is not None else 0 for size in sizes]
def build_hierarchy(siblings: list) -> dict:
    """Build a nested dict mirroring the repo's directory tree.

    Inner dicts are directories; leaves map filename -> size in bytes.
    For LFS-tracked files the LFS pointer's size is used instead of the
    sibling's own size.
    """
    tree = defaultdict(dict)
    for sib in siblings:
        *directories, filename = sib.rfilename.split("/")
        node = tree
        for directory in directories:
            node = node.setdefault(directory, {})
        node[filename] = sib.lfs.size if sib.lfs else sib.size
    return tree
def calculate_directory_sizes(hierarchy):
    """Recursively calculate each directory's size as the sum of its contents.

    Tags every directory dict with a ``"__size__"`` entry and returns the
    total size of *hierarchy*.

    This pass is idempotent: ``"__size__"`` markers left by a previous call
    are skipped, so calling it twice (as the treemap pipeline does) no longer
    double-counts directory contents. ``None`` leaf sizes (missing metadata)
    are treated as 0 bytes, matching apply_power_scaling/count_chunks.
    """
    total_size = 0
    for key, value in hierarchy.items():
        if key == "__size__":
            # Marker injected by an earlier pass — not a real child.
            continue
        if isinstance(value, dict):
            dir_size = calculate_directory_sizes(value)
            value["__size__"] = dir_size
            total_size += dir_size
        else:
            total_size += value or 0
    return total_size
def build_full_path(current_parent, key):
    """Join a parent path and a child name with '/'; an empty parent yields just the name."""
    if not current_parent:
        return key
    return f"{current_parent}/{key}"
def flatten_hierarchy(hierarchy, root_name="Repository"):
    """Flatten a nested size dict into parallel lists for a Plotly treemap.

    Returns ``(labels, parents, sizes, ids)``. A synthetic root node named
    *root_name* is emitted first; every other node's parent chain hangs off it.
    Directory sizes are taken from the ``"__size__"`` markers produced by
    calculate_directory_sizes (the marker is popped as it is consumed).
    """
    total = calculate_directory_sizes(hierarchy)

    # Seed the four parallel lists with the root node.
    labels = [root_name]
    parents = [""]
    sizes = [total]
    ids = [root_name]

    def walk(level, parent_id):
        for name, node in level.items():
            node_id = build_full_path(parent_id, name)
            if isinstance(node, dict) and "__size__" in node:
                # Directory: record its aggregated size, then descend.
                dir_size = node.pop("__size__")
                labels.append(name)
                parents.append(parent_id)
                sizes.append(dir_size)
                ids.append(node_id)
                walk(node, node_id)
            else:
                # File leaf: node is its size in bytes.
                labels.append(name)
                parents.append(parent_id)
                sizes.append(node)
                ids.append(node_id)

    walk(hierarchy, root_name)
    return labels, parents, sizes, ids
def visualize_repo_treemap(r_info: dict, r_id: str) -> px.treemap:
    """Visualize the repository as a treemap with directory sizes and tooltips.

    Args:
        r_info: repo info object exposing a ``siblings`` list (as returned by
            ``HfApi.repo_info``).
        r_id: repository id; used as the root node label and in the title.

    Returns:
        A Plotly treemap figure.
    """
    siblings = r_info.siblings
    hierarchy = build_hierarchy(siblings)
    # NOTE: flatten_hierarchy runs calculate_directory_sizes itself; calling it
    # here as well made a second pass over already-tagged directories and
    # inflated the root total (double-counting), so the extra call is removed.
    labels, parents, sizes, ids = flatten_hierarchy(hierarchy, r_id)
    # Power-scale byte counts so small files remain visible in the viz.
    scaled_sizes = apply_power_scaling(sizes)
    # Human-readable sizes for the hover tooltip.
    formatted_sizes = [
        (format_repo_size(size) if size is not None else None) for size in sizes
    ]
    chunks = count_chunks(sizes)
    colors = scaled_sizes[:]
    colors[0] = -1  # force the root node to the bottom of the color range
    max_value = max(scaled_sizes)
    normalized_colors = [value / max_value if value > 0 else 0 for value in colors]
    # Define the colorscale; mimics the plasma scale.
    colorscale = [
        [0.0, "#0d0887"],
        [0.5, "#bd3786"],
        [1.0, "#f0f921"],
    ]
    # Create the treemap.
    fig = px.treemap(
        names=labels,
        parents=parents,
        values=scaled_sizes,
        color=normalized_colors,
        color_continuous_scale=colorscale,
        title=f"{r_id} by Chunks",
        custom_data=[formatted_sizes, chunks],
        height=1000,
        ids=ids,
    )
    fig.update_traces(marker={"colors": ["lightgrey"] + normalized_colors[1:]})
    # Plotly titles/hover text use "<br>" for line breaks (the original string
    # literals were broken across physical lines — a syntax error).
    fig.update_layout(
        title={
            "text": f"{r_id} file and chunk treemap<br>Color represents size in bytes/chunks.",
            "x": 0.5,
            "xanchor": "center",
        },
        coloraxis_showscale=False,
    )
    # Customize the hover template.
    fig.update_traces(
        hovertemplate=(
            "%{label}<br>"
            "Size: %{customdata[0]}<br>"
            "# of Chunks: %{customdata[1]}"
        )
    )
    fig.update_traces(root_color="lightgrey")
    return fig
def format_repo_size(r_size: int) -> str:
    """Convert a size in bytes to a human-readable string.

    Args:
        r_size (int): The size in bytes.

    Returns:
        str: The size formatted with two decimals and a unit from B to PB.
    """
    units = ("B", "KB", "MB", "GB", "TB", "PB")
    size = r_size
    unit_idx = 0
    # Divide down by 1024 until the value fits the current unit (PB caps out).
    while size >= 1024 and unit_idx < len(units) - 1:
        size /= 1024
        unit_idx += 1
    return f"{size:.2f} {units[unit_idx]}"
def repo_files(r_type: str, r_id: str) -> tuple:
    """Aggregate per-extension stats for a repo and build its treemap figure.

    Args:
        r_type: repository type ("model" or "dataset").
        r_id: repository id.

    Returns:
        ``(files, fig)`` where ``files`` maps file extension ->
        ``{"size": bytes, "chunks": chunk count, "count": file count}`` and
        ``fig`` is the Plotly treemap for the whole repo.
        (The original annotation claimed ``-> dict`` but a tuple is returned.)
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    fig = visualize_repo_treemap(r_info, r_id)
    files = {}
    for sibling in r_info.siblings:
        # Files with no "." fall back to their full name as the "extension".
        ext = sibling.rfilename.split(".")[-1]
        stats = files.setdefault(ext, {"size": 0, "chunks": 0, "count": 0})
        stats["size"] += sibling.size
        stats["chunks"] += count_chunks(sibling.size)
        stats["count"] += 1
    return files, fig
def repo_size(r_type, r_id):
    """Fetch per-branch tree sizes for a repository via the Hub HTTP API.

    Returns a dict mapping branch name ->
    ``{"size_in_bytes": ..., "size_in_chunks": ...}``; returns ``{}`` (after
    surfacing a Gradio warning) when the repo or its branches are gated.
    """
    try:
        refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except RepositoryNotFoundError:
        gr.Warning(f"Repository is gated, branch information for {r_id} not available.")
        return {}

    branch_sizes = {}
    for branch in refs.branches:
        url = f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{branch.name}"
        try:
            payload = requests.get(url, timeout=1000).json()
        except Exception:
            # Best-effort: a failed request or bad JSON just skips this branch.
            payload = {}
        error = payload.get("error")
        if error and ("restricted" in error or "gated" in error):
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}
        size = payload.get("size")
        if size is not None:
            branch_sizes[branch.name] = {
                "size_in_bytes": size,
                "size_in_chunks": count_chunks(size),
            }
    return branch_sizes
def get_repo_info(r_type, r_id):
    """Search-button callback: gather branch and file stats for a repository.

    Returns updates for five components, in output order:
    (results row, file-extension dataframe, treemap plot,
     branch row, branch-sizes dataframe).
    """
    try:
        branch_info = repo_size(r_type, r_id)
        files_info, treemap_fig = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        # Hide everything when the repo can't be resolved.
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
        )

    if branch_info:
        branch_df = pd.DataFrame(branch_info).T.reset_index(names="branch")
        branch_df["formatted_size"] = branch_df["size_in_bytes"].apply(
            format_repo_size
        )
        # Column order after transpose: branch, size_in_bytes, size_in_chunks,
        # formatted_size — renamed positionally below.
        branch_df.columns = ["Branch", "size_in_bytes", "Chunks", "Size"]
        branch_table = gr.Dataframe(
            value=branch_df[["Branch", "Size", "Chunks"]], visible=True
        )
        branch_row = gr.Row(visible=True)
    else:
        # No branch data (gated or unavailable) — keep that section hidden.
        branch_table = gr.Dataframe(visible=False)
        branch_row = gr.Row(visible=False)

    ext_df = (
        pd.DataFrame(files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    ext_df["formatted_size"] = ext_df["size"].apply(format_repo_size)
    ext_df.columns = ["Extension", "bytes", "Chunks", "Count", "Size"]
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=ext_df[["Extension", "Count", "Size", "Chunks"]],
            visible=True,
        ),
        gr.Plot(treemap_fig, visible=True),
        branch_row,
        branch_table,
    )
# --- Gradio UI wiring ---
with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Chunking Repos")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's contents including the [number of chunks each file might be split into with Xet backed storage](https://huggingface.co/blog/from-files-to-chunks)."
    )
    with gr.Blocks():
        # Hub autocomplete search input for model/dataset repositories.
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        # Result sections start hidden; get_repo_info toggles their visibility.
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## Repo Info")
                gr.Markdown(
                    # Typo fixed: "it's size" -> "its size".
                    "Hover over any file or directory to see its size in bytes and total number of chunks required to store it in Xet storage."
                )
                file_info_plot = gr.Plot(visible=False)
        with gr.Row(visible=False) as branch_block:
            with gr.Column():
                gr.Markdown("### Branch Sizes")
                gr.Markdown(
                    "The size of each branch in the repository and how many chunks it might need (assuming no dedupe)."
                )
                branch_sizes = gr.Dataframe(visible=False)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### File Sizes")
                gr.Markdown(
                    "The cumulative size of each filetype in the repository (in the `main` branch) and how many chunks they might need (assuming no dedupe)."
                )
                file_info = gr.Dataframe(visible=False)
    # Output order must match get_repo_info's return tuple.
    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )
demo.launch()