gaverfraxz commited on
Commit
ced6e66
1 Parent(s): 43c47ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import os
7
+ import io
8
+
9
+
10
+ def calculate_weight_diff(base_weight, chat_weight):
11
+ return torch.abs(base_weight - chat_weight).mean().item()
12
+
13
+
14
+ def calculate_layer_diffs(base_model, chat_model):
15
+ layer_diffs = []
16
+ for base_layer, chat_layer in zip(base_model.model.layers, chat_model.model.layers):
17
+ layer_diff = {
18
+ 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
19
+ 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
20
+ 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
21
+ 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
22
+ 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
23
+ 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
24
+ }
25
+ layer_diffs.append(layer_diff)
26
+ return layer_diffs
27
+
28
+
29
+ def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
30
+ num_layers = len(layer_diffs)
31
+ num_components = len(layer_diffs[0])
32
+
33
+ fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
34
+ fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
35
+
36
+ for i, component in enumerate(layer_diffs[0].keys()):
37
+ component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
38
+ sns.heatmap(component_diffs, annot=True, fmt=".6f", cmap="YlGnBu", ax=axs[i], cbar_kws={"shrink": 0.8})
39
+ axs[i].set_title(component)
40
+ axs[i].set_xlabel("Layer")
41
+ axs[i].set_ylabel("Difference")
42
+ axs[i].set_xticks([])
43
+ axs[i].set_yticks(range(num_layers))
44
+ axs[i].set_yticklabels(range(num_layers))
45
+ axs[i].invert_yaxis()
46
+
47
+ plt.tight_layout()
48
+ return fig
49
+
50
+
51
+ def main():
52
+ st.set_page_config(
53
+ page_title="Model Weight Comparator",
54
+ layout="wide",
55
+ initial_sidebar_state="expanded"
56
+ )
57
+
58
+ st.title("Language Model Weight Comparator")
59
+
60
+ # Config sidebar for input parameters
61
+ with st.sidebar:
62
+ st.header("Configuration")
63
+
64
+ base_model_name = st.text_input(
65
+ "Base Model Name",
66
+ value="meta-llama/Meta-Llama-3-70B-Instruct",
67
+ help="Enter the name of the base model from Hugging Face"
68
+ )
69
+
70
+ chat_model_name = st.text_input(
71
+ "Chat Model Name",
72
+ value="mattshumer/Reflection-Llama-3.1-70B",
73
+ help="Enter the name of the chat model from Hugging Face"
74
+ )
75
+
76
+ if st.button("Compare Models"):
77
+
78
+ if not base_model_name or not chat_model_name:
79
+ st.error("Please enter both model names")
80
+ return
81
+
82
+ try:
83
+ st.info("Loading models... This might take some time.")
84
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
85
+ chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
86
+
87
+ st.info("Calculating weight differences...")
88
+ layer_diffs = calculate_layer_diffs(base_model, chat_model)
89
+
90
+ st.info("Generating visualization...")
91
+ fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
92
+ st.pyplot(fig)
93
+
94
+ # visualization
95
+ buf = io.BytesIO()
96
+ fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
97
+ buf.seek(0)
98
+ st.download_button(
99
+ label="Download Visualization",
100
+ data=buf,
101
+ file_name="model_comparison.png",
102
+ mime="image/png"
103
+ )
104
+
105
+ except Exception as e:
106
+ st.error(f"An error occurred: {str(e)}")
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()