# pylint: disable=no-member
from urllib.parse import quote

import gradio as gr
import pandas as pd
import plotly.express as px
import requests
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi
from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError

HF_API = HfApi()

# Unit labels used by format_repo_size, indexed by how many times the value
# was divided by 1024.
SIZE_UNITS = ("B", "KB", "MB", "GB", "TB", "PB")

# Timeout in seconds for each per-branch treesize request.
# NOTE(review): 1000 s is very long for one HTTP call; value kept unchanged,
# confirm whether a shorter limit was intended.
TREESIZE_TIMEOUT_S = 1000


def format_repo_size(r_size: float) -> str:
    """Return a byte count as a human-readable string with 1024-based units.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (PB) is reached, then formatted with two decimals, e.g.
    ``format_repo_size(1536) == "1.50 KB"``.
    """
    order = 0
    while r_size >= 1024 and order < len(SIZE_UNITS) - 1:
        r_size /= 1024
        order += 1
    return f"{r_size:.2f} {SIZE_UNITS[order]}"


def repo_files(r_type: str, r_id: str) -> dict:
    """Aggregate total size and file count per file extension for a repo.

    Args:
        r_type: Repository type passed to the Hub API ("model" or "dataset").
        r_id: Repository id, e.g. "user/repo".

    Returns:
        A dict mapping extension -> {"size": total_bytes, "count": n_files}.
        The extension is the text after the last "." of the file's basename.
        If the basename has no dot, the whole basename is used.

    Raises:
        RepositoryNotFoundError: propagated from ``HfApi.repo_info``.
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    files = {}
    for sibling in r_info.siblings or []:
        # Use only the basename so dots in directory names (e.g. "v1.0/model")
        # do not leak path fragments into the extension key.
        basename = sibling.rfilename.rsplit("/", 1)[-1]
        ext = basename.split(".")[-1]
        entry = files.setdefault(ext, {"size": 0, "count": 0})
        entry["size"] += sibling.size or 0
        entry["count"] += 1
    return files


def repo_size(r_type, r_id):
    """Return a mapping of branch name -> total size in bytes for a repo.

    Sizes come from the Hub ``treesize`` endpoint, one request per branch.
    Branches whose size cannot be determined are omitted.

    Returns an empty dict (after showing a Gradio warning) when the repo is
    gated, or when a branch lookup reports a restricted/gated error.

    Raises:
        RepositoryNotFoundError: when the repository does not exist (or is
            not accessible). Only the gated subclass is handled here, so the
            caller can report "not found" exactly once.
    """
    try:
        r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except GatedRepoError:
        gr.Warning(f"Repository is gated, branch information for {r_id} not available.")
        return {}

    repo_sizes = {}
    for branch in r_refs.branches:
        # Branch names may contain "/", which must be percent-encoded so the
        # name stays a single path segment.
        url = (
            f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/"
            f"{quote(branch.name, safe='')}"
        )
        try:
            payload = requests.get(url, timeout=TREESIZE_TIMEOUT_S).json()
        except (requests.RequestException, ValueError):
            # Network failure or non-JSON body: treat the size as unknown.
            payload = {}
        if not isinstance(payload, dict):
            payload = {}

        error = payload.get("error")
        if error and ("restricted" in str(error) or "gated" in str(error)):
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}

        size = payload.get("size")
        if size is not None:
            repo_sizes[branch.name] = size
    return repo_sizes


def _hidden_outputs():
    """Return the five output updates with every result component hidden."""
    return (
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
        gr.Plot(visible=False),
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
    )


def _branch_outputs(repo_sizes):
    """Build the (branch row, branch table) updates from a branch -> bytes map.

    Both components are hidden when ``repo_sizes`` is empty.
    """
    if not repo_sizes:
        return gr.Row(visible=False), gr.Dataframe(visible=False)

    r_sizes_df = pd.DataFrame(repo_sizes, index=["size"]).T.reset_index(names="branch")
    r_sizes_df["formatted_size"] = r_sizes_df["size"].apply(format_repo_size)
    r_sizes_df = r_sizes_df.rename(
        columns={"branch": "Branch", "size": "bytes", "formatted_size": "Size"}
    )
    return (
        gr.Row(visible=True),
        gr.Dataframe(value=r_sizes_df[["Branch", "Size"]], visible=True),
    )


def get_repo_info(r_type, r_id):
    """Gradio callback: collect file and branch information for a repo.

    Args:
        r_type: "model" or "dataset" (from the radio button).
        r_id: Repository id selected in the Hub search box.

    Returns:
        A 5-tuple of component updates, in this order:
        (results row, file table, file pie chart, branch row, branch table).
    """
    if not r_id or not r_id.strip():
        gr.Warning("Please select a repository before searching.")
        return _hidden_outputs()

    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        return _hidden_outputs()

    b_block, r_sizes_component = _branch_outputs(repo_sizes)

    # An empty repository has no siblings; the DataFrame below would have no
    # "size" column and sort_values would raise KeyError.
    if not repo_files_info:
        gr.Warning(f"No files found in {r_id}.")
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            b_block,
            r_sizes_component,
        )

    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    # Rename by name rather than position so column order cannot silently
    # mislabel the data.
    rf_sizes_df = rf_sizes_df.rename(
        columns={
            "ext": "Extension",
            "size": "bytes",
            "count": "Count",
            "formatted_size": "Size",
        }
    )

    rf_sizes_plot = px.pie(
        rf_sizes_df,
        values="bytes",
        names="Extension",
        hover_data=["Size"],
        title=f"File Distribution in {r_id}",
        hole=0.3,
    )

    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size"]],
            visible=True,
        ),
        gr.Plot(rf_sizes_plot, visible=True),
        b_block,
        r_sizes_component,
    )


with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Repository Information")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's files and branches."
    )
    with gr.Blocks():
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## File Information")
                with gr.Row():
                    file_info = gr.Dataframe(visible=False)
                    file_info_plot = gr.Plot(visible=False)
        with gr.Row(visible=False) as branch_block:
            with gr.Column():
                gr.Markdown("## Branch Sizes")
                branch_sizes = gr.Dataframe(visible=False)

    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

demo.launch()