File size: 3,092 Bytes
9777a0a
 
3cd2f54
 
9777a0a
 
3cd2f54
9777a0a
 
3cd2f54
 
 
 
 
9777a0a
3cd2f54
 
9777a0a
 
 
 
3cd2f54
9777a0a
 
 
 
3cd2f54
022199d
3cd2f54
9777a0a
 
 
 
3cd2f54
9777a0a
3cd2f54
 
 
 
 
 
 
 
 
 
 
9777a0a
3cd2f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9777a0a
 
 
 
 
3cd2f54
9777a0a
 
 
 
3cd2f54
 
9777a0a
 
 
 
3cd2f54
 
 
 
9777a0a
 
 
3cd2f54
 
9777a0a
3cd2f54
 
9777a0a
 
 
3cd2f54
 
 
9777a0a
 
3cd2f54
9777a0a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import marimo

# Version marker stored alongside the notebook; presumably the marimo release
# that generated this file (confirm against marimo's file-format docs).
__generated_with = "0.9.14"
# Application object that every `@app.cell` function below registers with.
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def __():
    # Notebook-wide dependencies. Each imported name is listed in the return
    # tuple below so that other cells can receive it as a parameter.
    import altair as alt
    import duckdb
    import marimo as mo
    import numpy
    import pandas
    import plotly.express as px

    # Page title, rendered as markdown.
    mo.md("# 🤗 Hub Model Tree")
    return alt, duckdb, mo, numpy, pandas, px


@app.cell(hide_code=True)
def __(mo):
    # Introductory blurb: where the data comes from, what the "model tree"
    # metric means, and how to use the search box. The adjacent literals are
    # concatenated by the parser into one string.
    _intro = (
        "This is powered by the "
        "[Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) "
        "dataset which you can query via the "
        "[SQL Console](https://huggingface.co/datasets/cfahlgren1/hub-stats?sql_console=true). "
        "The model tree metric is where a model tags a parent model as a `base_model`. "
        "The `hub-stats` dataset gets updated daily. "
        "Try it out by putting an organization or model author in search box and hit enter."
    )
    mo.md(_intro)
    return


@app.cell
def __(duckdb):
    # Expose the Hub Stats models parquet file as a DuckDB view named `models`
    # so that the SQL built elsewhere in this notebook can refer to it by name.
    # OR REPLACE keeps the statement idempotent: without it, running this cell
    # a second time fails because the view already exists.
    duckdb.sql(
        "CREATE OR REPLACE VIEW models AS "
        "SELECT * FROM 'hf://datasets/cfahlgren1/hub-stats/models.parquet'"
    )
    # `models` lives only inside DuckDB; it is not a Python variable, so there
    # is nothing to return (returning it would reference an undefined name).
    return


@app.cell(hide_code=True)
def __(mo):
    # Free-text box in which the user types a Hub author / organization name.
    author_input = mo.ui.text(placeholder="Search...", label="Author")

    # Common-table-expression prefix shared by both queries below. The `{}`
    # placeholder is filled with the author name and sits inside a
    # single-quoted SQL string literal, so the value must be escaped first.
    #   author_models: ids of the models whose author matches
    #   model_tags:    one (id, tag) row per tag of every model
    ctes = """
    WITH author_models AS (
      SELECT id
      FROM models
      WHERE author = '{}'
    ),
    model_tags AS (
      SELECT
        id,
        UNNEST(tags) AS tag
      FROM models
    )
    """

    def get_model_children_counts(author: str) -> str:
        """Return SQL counting direct child models for each model of `author`.

        A child is any model carrying the tag ``base_model:<parent id>``.
        The query yields the columns ``parent_model_id`` and
        ``num_direct_children``, ordered by child count, largest first.
        Parents without children are omitted (inner join).
        """
        # The author comes from a user-editable text box: double any single
        # quote so the value cannot terminate the SQL string literal.
        safe_author = author.replace("'", "''")
        return f"""
        {ctes.format(safe_author)}
        SELECT
          am.id as parent_model_id,
          COUNT(DISTINCT m.id) as num_direct_children
        FROM author_models am
        INNER JOIN model_tags m
          ON m.tag = 'base_model:' || am.id
        GROUP BY am.id
        ORDER BY num_direct_children DESC;
        """

    def get_total_childen_count(author: str) -> str:
        """Return SQL counting all distinct direct children of `author`'s models.

        The query yields a single row with one column,
        ``num_direct_children``. The left join means an author with no
        matching children still produces a row (with a count of 0).
        """
        # Same escaping as above: the value is user input placed in a literal.
        safe_author = author.replace("'", "''")
        return f"""
        {ctes.format(safe_author)}
        SELECT
          COUNT(DISTINCT m.id) as num_direct_children
        FROM author_models am
        LEFT JOIN model_tags m
          ON m.tag = 'base_model:' || am.id
        """

    return (
        author_input,
        ctes,
        get_model_children_counts,
        get_total_childen_count,
    )


@app.cell
def __(mo):
    # Section heading shown above the author search box.
    _heading = "## Search by Author"
    mo.md(_heading)
    return


@app.cell(hide_code=True)
def __(author_input, mo):
    # Stack the search box on top of a short line of example author names.
    _examples = mo.md("_ex: meta-llama, google, mistralai, Qwen_")
    mo.vstack([author_input, _examples])
    return


@app.cell(hide_code=True)
def __(author_input, duckdb, get_total_childen_count, mo):
    # Run the aggregate query for whatever is currently in the search box.
    _query = get_total_childen_count(author_input.value)
    result = duckdb.sql(_query).fetchall()

    # Heading, explanatory caption, and the count itself, stacked vertically.
    # The count is the first column of the first fetched row.
    _heading = mo.md("### Direct Child Models")
    _caption = mo.md(
        f"_The number of models that have tagged a {author_input.value} model as a `base_model`_"
    )
    _count = mo.stat(result[0][0])
    mo.vstack([_heading, _caption, _count])
    return (result,)


@app.cell(hide_code=True)
def __(author_input, duckdb, get_model_children_counts):
    # Per-model child counts for the searched author, fetched as a DataFrame.
    _query = get_model_children_counts(author_input.value)
    df = duckdb.sql(_query).fetchdf()

    # Bare trailing expression: the table itself is what this cell shows.
    df
    return (df,)


@app.cell(hide_code=True)
def __(df, mo, px):
    # Bar chart with one bar per parent model; bar height is the number of
    # direct children, drawn on a logarithmic y-axis.
    _bar_options = dict(
        x="parent_model_id",
        y="num_direct_children",
        log_y=True,
    )
    _fig = px.bar(df, **_bar_options)

    mo.ui.plotly(_fig)
    return


# Run the notebook application when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()