# pylint: disable=no-member
import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from collections import defaultdict
import numpy as np

HF_API = HfApi()


def apply_power_scaling(sizes: list, exponent=0.2) -> list:
    """Apply custom power scaling to the sizes.

    Compresses the huge dynamic range of file sizes so small files remain
    visible in the treemap. ``None`` sizes map to 0.
    """
    return [size**exponent if size is not None else 0 for size in sizes]


def count_chunks(sizes: list | int) -> list | int:
    """Count the number of chunks, which are 64KB each in size; always round up.

    Accepts either a single size (returns an int) or a list of sizes
    (returns a list). ``None`` entries count as 0 chunks.
    NOTE(review): the divisor is 64_000, not 65_536 — confirm which "64KB"
    the Xet backend actually uses.
    """
    if isinstance(sizes, int):
        return int(np.ceil(sizes / 64_000))
    return [int(np.ceil(size / 64_000)) if size is not None else 0 for size in sizes]


def build_hierarchy(siblings: list) -> dict:
    """Builds a hierarchical structure from the list of RepoSibling objects.

    Leaves are file sizes in bytes (LFS size when present, otherwise the
    plain blob size); interior dicts mirror the directory tree.
    """
    hierarchy = defaultdict(dict)

    for sibling in siblings:
        path_parts = sibling.rfilename.split("/")
        size = sibling.lfs.size if sibling.lfs else sibling.size

        current_level = hierarchy
        for part in path_parts[:-1]:
            current_level = current_level.setdefault(part, {})
        current_level[path_parts[-1]] = size

    return hierarchy


def calculate_directory_sizes(hierarchy):
    """Recursively calculates the size of each directory as the sum of its contents.

    Injects a ``"__size__"`` bookkeeping entry into every directory dict.
    The function is idempotent: stale ``"__size__"`` entries from a previous
    pass are skipped during summation and refreshed, so calling it twice
    (visualize_repo_treemap and flatten_hierarchy both call it) no longer
    double-counts directory sizes.
    """
    total_size = 0

    for key, value in list(hierarchy.items()):
        if key == "__size__":
            # Bookkeeping entry from a previous pass — not a child node.
            continue
        if isinstance(value, dict):
            dir_size = calculate_directory_sizes(value)
            # Drop any stale total before re-inserting the fresh one first
            # (first position keeps the original key ordering convention).
            value.pop("__size__", None)
            hierarchy[key] = {
                "__size__": dir_size,
                **value,
            }
            total_size += dir_size
        else:
            total_size += value

    return total_size


def build_full_path(current_parent, key):
    """Join a parent path and a key into a unique node id for the treemap."""
    return f"{current_parent}/{key}" if current_parent else key


def flatten_hierarchy(hierarchy, root_name="Repository"):
    """Flatten a nested dictionary into Plotly-compatible treemap data with a defined root node.

    Returns parallel lists (labels, parents, sizes, ids); ids are full paths
    so that identically-named files in different directories stay distinct.
    """
    labels = []
    parents = []
    sizes = []
    ids = []

    # Recursively process the hierarchy
    def process_level(current_hierarchy, current_parent):
        for key, value in current_hierarchy.items():
            full_path = build_full_path(current_parent, key)
            if isinstance(value, dict) and "__size__" in value:
                # Handle directories: emit the directory node, then recurse.
                dir_size = value.pop("__size__")
                labels.append(key)
                parents.append(current_parent)
                sizes.append(dir_size)
                ids.append(full_path)
                process_level(value, full_path)
            else:
                # Handle files
                labels.append(key)
                parents.append(current_parent)
                sizes.append(value)
                ids.append(full_path)

    # Add the root node
    total_size = calculate_directory_sizes(hierarchy)
    labels.append(root_name)
    parents.append("")
    sizes.append(total_size)
    ids.append(root_name)

    # Process the hierarchy
    process_level(hierarchy, root_name)

    return labels, parents, sizes, ids


def visualize_repo_treemap(r_info, r_id: str):
    """Visualizes the repository as a treemap with directory sizes and human-readable tooltips.

    Args:
        r_info: a ``huggingface_hub`` repo-info object with a ``siblings`` list.
        r_id: repository id, used as the treemap root and title.

    Returns:
        A ``plotly.express`` treemap figure.
    """
    siblings = r_info.siblings
    hierarchy = build_hierarchy(siblings)

    # Calculate directory sizes (idempotent; flatten_hierarchy refreshes them)
    calculate_directory_sizes(hierarchy)

    # Flatten the hierarchy for Plotly
    labels, parents, sizes, ids = flatten_hierarchy(hierarchy, r_id)

    # Scale for viz
    scaled_sizes = apply_power_scaling(sizes)

    # Format the original sizes using the helper function
    formatted_sizes = [
        (format_repo_size(size) if size is not None else None) for size in sizes
    ]
    chunks = count_chunks(sizes)

    colors = scaled_sizes[:]
    colors[0] = -1  # sentinel so the root is excluded from the color scale
    max_value = max(scaled_sizes) if scaled_sizes else 0
    # Guard against an all-zero repo: avoid division by zero.
    normalized_colors = [
        value / max_value if max_value > 0 and value > 0 else 0 for value in colors
    ]

    # Define the colorscale; mimics the plasma scale
    colorscale = [
        [0.0, "#0d0887"],
        [0.5, "#bd3786"],
        [1.0, "#f0f921"],
    ]

    # Create the treemap
    fig = px.treemap(
        names=labels,
        parents=parents,
        values=scaled_sizes,
        color=normalized_colors,
        color_continuous_scale=colorscale,
        title=f"{r_id} by Chunks",
        custom_data=[formatted_sizes, chunks],
        height=1000,
        ids=ids,
    )
    fig.update_traces(marker={"colors": ["lightgrey"] + normalized_colors[1:]})

    # Add subtitle by updating the layout.
    # NOTE(review): the line break was garbled in the original source; "<br>"
    # is Plotly's line-break tag and the presumed original — confirm.
    fig.update_layout(
        title={
            "text": f"{r_id} file and chunk treemap<br>Color represents size in bytes/chunks.",
            "x": 0.5,
            "xanchor": "center",
        },
        coloraxis_showscale=False,
    )

    # Customize the hover template ("<br>" is Plotly's line break)
    fig.update_traces(
        hovertemplate=(
            "%{label}<br>"
            "Size: %{customdata[0]}<br>"
            "# of Chunks: %{customdata[1]}"
        )
    )
    fig.update_traces(root_color="lightgrey")

    return fig


def format_repo_size(r_size: int) -> str:
    """
    Convert a repository size in bytes to a human-readable string with appropriate units.

    Args:
        r_size (int): The size of the repository in bytes.

    Returns:
        str: The formatted size string with appropriate units (B, KB, MB, GB, TB, PB).
    """
    units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB", 5: "PB"}
    order = 0
    while r_size >= 1024 and order < len(units) - 1:
        r_size /= 1024
        order += 1
    return f"{r_size:.2f} {units[order]}"


def repo_files(r_type: str, r_id: str) -> tuple:
    """Aggregate per-extension size/chunk/count stats and build the treemap.

    Returns:
        (files, fig): a dict keyed by file extension with ``size``, ``chunks``
        and ``count`` entries, and the treemap figure for the repo.
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    fig = visualize_repo_treemap(r_info, r_id)
    files = {}
    for sibling in r_info.siblings:
        # Files without a dot keep the whole filename as their "extension".
        ext = sibling.rfilename.split(".")[-1]
        if ext in files:
            files[ext]["size"] += sibling.size
            files[ext]["chunks"] += count_chunks(sibling.size)
            files[ext]["count"] += 1
        else:
            files[ext] = {}
            files[ext]["size"] = sibling.size
            files[ext]["chunks"] = count_chunks(sibling.size)
            files[ext]["count"] = 1
    return files, fig


def repo_size(r_type, r_id):
    """Return per-branch sizes (bytes and chunks) for a repository.

    Returns an empty dict when the repo is gated/restricted or branch
    information is otherwise unavailable (a gr.Warning is shown).
    """
    try:
        r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except RepositoryNotFoundError:
        gr.Warning(f"Repository is gated, branch information for {r_id} not available.")
        return {}
    repo_sizes = {}
    for branch in r_refs.branches:
        try:
            response = requests.get(
                f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{branch.name}",
                timeout=1000,
            )
            response = response.json()
        # Network errors and non-JSON bodies are best-effort: treat as "no data".
        except (requests.RequestException, ValueError):
            response = {}

        if response.get("error") and (
            "restricted" in response.get("error") or "gated" in response.get("error")
        ):
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}
        size = response.get("size")
        if size is not None:
            repo_sizes[branch.name] = {
                "size_in_bytes": size,
                "size_in_chunks": count_chunks(size),
            }
    return repo_sizes


def get_repo_info(r_type, r_id):
    """Gradio callback: fetch branch sizes, file stats, and the treemap.

    Returns the component updates (rows, dataframes, plot) consumed by
    ``search_button.click``; hides everything when the repo is not found.
    """
    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info, treemap_fig = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
        )

    # check if repo_sizes is just {}
    if not repo_sizes:
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)
    else:
        r_sizes_df = pd.DataFrame(repo_sizes).T.reset_index(names="branch")
        r_sizes_df["formatted_size"] = r_sizes_df["size_in_bytes"].apply(
            format_repo_size
        )
        r_sizes_df.columns = ["Branch", "size_in_bytes", "Chunks", "Size"]
        r_sizes_component = gr.Dataframe(
            value=r_sizes_df[["Branch", "Size", "Chunks"]], visible=True
        )
        b_block = gr.Row(visible=True)

    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    rf_sizes_df.columns = ["Extension", "bytes", "Chunks", "Count", "Size"]

    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size", "Chunks"]],
            visible=True,
        ),
        # gr.Plot(rf_sizes_plot, visible=True),
        gr.Plot(treemap_fig, visible=True),
        b_block,
        r_sizes_component,
    )


with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Chunking Repos")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's contents including the [number of chunks each file might be split into with Xet backed storage](https://huggingface.co/blog/from-files-to-chunks)."
    )

    with gr.Blocks():
        # repo_id = gr.Textbox(label="Repository ID", placeholder="123456")
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## Repo Info")
                gr.Markdown(
                    "Hover over any file or directory to see its size in bytes and total number of chunks required to store it in Xet storage."
                )
                file_info_plot = gr.Plot(visible=False)
                with gr.Row(visible=False) as branch_block:
                    with gr.Column():
                        gr.Markdown("### Branch Sizes")
                        gr.Markdown(
                            "The size of each branch in the repository and how many chunks it might need (assuming no dedupe)."
                        )
                        branch_sizes = gr.Dataframe(visible=False)
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### File Sizes")
                        gr.Markdown(
                            "The cumulative size of each filetype in the repository (in the `main` branch) and how many chunks they might need (assuming no dedupe)."
                        )
                        file_info = gr.Dataframe(visible=False)
                        # file_info_plot = gr.Plot(visible=False)

    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

demo.launch()