mlabonne commited on
Commit
26cdd43
β€’
1 Parent(s): 71e899e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from huggingface_hub import ModelCard, HfApi
4
+ import requests
5
+ import networkx as nx
6
+ import matplotlib.pyplot as plt
7
+ from collections import defaultdict
8
+ from networkx.drawing.nx_pydot import graphviz_layout
9
+ from io import BytesIO
10
+
11
+
12
+ def get_model_names_from_yaml(url):
13
+ """Get a list of parent model names from the yaml file."""
14
+ model_tags = []
15
+ response = requests.get(url)
16
+ if response.status_code == 200:
17
+ model_tags.extend([item for item in response.content if '/' in str(item)])
18
+ return model_tags
19
+
20
+
21
+ def get_license_color(model):
22
+ """Get the color of the model based on its license."""
23
+ try:
24
+ card = ModelCard.load(model)
25
+ license = card.data.to_dict()['license'].lower()
26
+ # Define permissive licenses
27
+ permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail'] # Add more as needed
28
+ # Check license type
29
+ if any(perm_license in license for perm_license in permissive_licenses):
30
+ return 'lightgreen' # Permissive licenses
31
+ else:
32
+ return 'lightcoral' # Noncommercial or other licenses
33
+ except Exception as e:
34
+ print(f"Error retrieving license for {model}: {e}")
35
+ return 'lightgray'
36
+
37
+
38
+ def get_model_names(model, genealogy, found_models=None):
39
+ """Get a list of parent model names from the model id."""
40
+ model_tags = []
41
+
42
+ if found_models is None:
43
+ found_models = []
44
+
45
+ try:
46
+ card = ModelCard.load(model)
47
+ card_dict = card.data.to_dict() # Convert the ModelCard object to a dictionary
48
+ license = card_dict['license']
49
+
50
+ # Check the base_model in metadata
51
+ if 'base_model' in card_dict:
52
+ model_tags = card_dict['base_model']
53
+
54
+ # Check the tags in metadata
55
+ if 'tags' in card_dict and not model_tags:
56
+ tags = card_dict['tags']
57
+ model_tags = [model_name for model_name in tags if '/' in model_name]
58
+
59
+ # Check for merge.yml and mergekit_config.yml if no model_tags found in the tags
60
+ if not model_tags:
61
+ model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/merge.yml"))
62
+ if not model_tags:
63
+ model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/mergekit_config.yml"))
64
+
65
+ # Convert to a list if tags is not None or empty, else set to an empty list
66
+ if not isinstance(model_tags, list):
67
+ model_tags = [model_tags] if model_tags else []
68
+
69
+ # Add found model names to the list
70
+ found_models.extend(model_tags)
71
+
72
+ # Record the genealogy
73
+ for model_tag in model_tags:
74
+ genealogy[model_tag].append(model)
75
+
76
+ # Recursively check for more models
77
+ for model_tag in model_tags:
78
+ get_model_names(model_tag, genealogy, found_models)
79
+
80
+ except Exception as e:
81
+ print(f"Could not find model names for {model}: {e}")
82
+
83
+ return found_models
84
+
85
+
86
+ def find_root_nodes(G):
87
+ """ Find all nodes in the graph with no predecessors """
88
+ return [n for n, d in G.in_degree() if d == 0]
89
+
90
+
91
+ def max_width_of_tree(G):
92
+ """ Calculate the maximum width of the tree """
93
+ max_width = 0
94
+ for root in find_root_nodes(G):
95
+ width_at_depth = calculate_width_at_depth(G, root)
96
+ local_max_width = max(width_at_depth.values())
97
+ max_width = max(max_width, local_max_width)
98
+ return max_width
99
+
100
+
101
+ def calculate_width_at_depth(G, root):
102
+ """ Calculate width at each depth starting from a given root """
103
+ depth_count = defaultdict(int)
104
+ queue = [(root, 0)]
105
+ while queue:
106
+ node, depth = queue.pop(0)
107
+ depth_count[depth] += 1
108
+ for child in G.successors(node):
109
+ queue.append((child, depth + 1))
110
+ return depth_count
111
+
112
+
113
+ def create_family_tree(start_model):
114
+ genealogy = defaultdict(list)
115
+ get_model_names(start_model, genealogy) # Assuming this populates the genealogy
116
+
117
+ # Create a directed graph
118
+ G = nx.DiGraph()
119
+
120
+ # Add nodes and edges to the graph
121
+ for parent, children in genealogy.items():
122
+ for child in children:
123
+ G.add_edge(parent, child)
124
+
125
+ # Get max depth
126
+ max_depth = nx.dag_longest_path_length(G) + 1
127
+
128
+ # Get max width
129
+ max_width = max_width_of_tree(G) + 1
130
+
131
+ # Estimate plot size
132
+ height = max(8, 1.6 * max_depth)
133
+ width = max(8, 6 * max_width)
134
+
135
+ # Set Graphviz layout attributes for a bottom-up tree
136
+ plt.figure(figsize=(width, height))
137
+ pos = graphviz_layout(G, prog="dot")
138
+
139
+ # Determine node colors based on license
140
+ node_colors = [get_license_color(node) for node in G.nodes()]
141
+ clear_output()
142
+
143
+ # Create a label mapping with line breaks
144
+ labels = {node: node.replace("/", "\n") for node in G.nodes()}
145
+
146
+ # Draw the graph
147
+ nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')
148
+
149
+ # Create a legend for the colors
150
+ legend_elements = [
151
+ Patch(facecolor='lightgreen', label='Permissive'),
152
+ Patch(facecolor='lightcoral', label='Noncommercial'),
153
+ Patch(facecolor='lightgray', label='Unknown')
154
+ ]
155
+ plt.legend(handles=legend_elements, loc='upper left')
156
+
157
+ plt.title(f"{start_model}'s Family Tree", fontsize=20)
158
+
159
+ # Instead of plt.show(), capture the plot as an image in memory
160
+ img_buffer = BytesIO()
161
+ plt.savefig(img_buffer, format='png')
162
+ plt.close()
163
+ img_buffer.seek(0) # Rewind the buffer to the beginning
164
+
165
+ return img_buffer.getvalue()
166
+
167
+ with gr.Blocks() as demo:
168
+ gr.Markdown("# 🌳 Merge Family Tree")
169
+ model_id = gr.Textbox(label="Model ID", value="mlabonne/NeuralBeagle-7B")
170
+ btn = gr.Button("Create tree")
171
+ btn.click(fn=create_family_tree, inputs=model_id, outputs=out)
172
+ out = gr.Image()
173
+
174
+ demo.launch()