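"""Streamlit app that estimates the memory needed to run inference with a
large language model: weight storage at a chosen quantization level plus a
rough estimate of activation memory."""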
import streamlit as st

# Nominal bits per weight for each storage format. The GGUF K-quant types
# (Q2_K, Q3_K_M, Q4_K_M, ...) actually spend slightly more bits per weight
# on block scales, so these values are a simplification.
quantization_bit_sizes = {
    'float32': 32,
    'float16': 16,
    'Q2_K': 2,
    'Q3_K_L': 3,
    'Q3_K_M': 3,
    'Q3_K_S': 3,
    'Q4_0': 4,
    'Q4_1': 4,
    'Q4_K_M': 4,
    'Q4_K_S': 4,
    'Q5_0': 5,
    'Q5_1': 5,
    'Q5_K_M': 5,
    'Q5_K_S': 5,
    'Q6_K': 6,
    'Q8_0': 8,
}

# Bytes per activation value for each precision mode. 'mixed' budgets for
# keeping both an FP32 master copy and an FP16 working copy (4 + 2 bytes).
precision_options = {
    'full': 4,   # FP32
    'mixed': 6,  # FP32 + FP16
    'half': 2,   # FP16
}

st.title("Memory Usage Calculator for Large Language Models")

def calculate_memory_usage(parameter_count, context_length, data_type, batch_size,
                           vocab_size, precision, layers, attention_heads):
    """Estimate total inference memory (weights + activations) in GiB."""
    # Bytes used to store a single weight at the selected quantization level.
    byte_size = quantization_bit_sizes[data_type] / 8

    # Memory occupied by the model weights themselves.
    memory_params = parameter_count * byte_size

    # Memory occupied by intermediate activations.
    activations = calculate_activations(parameter_count, context_length, batch_size,
                                        vocab_size, precision, layers, attention_heads)

    total_memory_usage = memory_params + activations

    # Convert bytes to GiB.
    return total_memory_usage / (1024 ** 3)

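# Example (weights only): a 7B-parameter model stored as Q4_0 (4 bits, i.e.
# 0.5 bytes per weight) needs about 7e9 * 0.5 / 1024**3 ≈ 3.26 GiB before
# activations are counted.
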
def calculate_activations(parameter_count, context_length, batch_size,
                          vocab_size, precision, layers, attention_heads):
    """Estimate activation memory in bytes.

    The per-layer term follows the activation-memory estimate
    s*b*h*(34 + 5*a*s/h) bytes per layer at 16-bit precision
    (cf. Korthikanti et al., "Reducing Activation Recomputation in Large
    Transformer Models"), with s = context_length, b = batch_size,
    h = hidden size, a = attention_heads. vocab_size is currently unused
    in this estimate.
    """
    # Rough heuristic: infer the hidden dimension from the parameter count.
    # (Real models derive it from the architecture; for a standard
    # transformer, parameter_count is roughly 12 * layers * hidden**2.)
    hidden_dimensions = int(parameter_count ** 0.5)

    activations_per_layer = context_length * batch_size * hidden_dimensions * (
        34 + (5 * attention_heads * context_length) / hidden_dimensions
    )

    # The formula counts bytes at 2-byte precision, so dividing by 2 gives
    # the number of activation values across all layers.
    activation_values = layers * activations_per_layer / 2

    # precision_options already stores bytes per value, so multiply directly.
    bytes_per_value = precision_options[precision]
    return bytes_per_value * activation_values

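# Worked example with hypothetical values: if the hidden-size heuristic
# yielded 4096, then with context_length=512, batch_size=1,
# attention_heads=32, and layers=32, activations_per_layer =
# 512 * 4096 * (34 + 20) ≈ 1.13e8, and 'half' precision gives
# 32 * 1.13e8 / 2 * 2 bytes ≈ 3.4 GiB of activations.
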
# Model configuration supplied by the user.
parameter_count = st.number_input("Parameter Count (in billions)", value=1, step=1) * 1e9
layers = st.number_input("Number of Layers", value=32, step=1)
attention_heads = st.number_input("Number of Attention Heads", value=32, step=1)
context_length = st.number_input("Context Length (number of tokens)", value=512, step=1)
data_type = st.selectbox("Data Type", options=list(quantization_bit_sizes.keys()))
batch_size = st.number_input("Batch Size", value=1, step=1)
vocab_size = st.number_input("Vocabulary Size", value=30000, step=1000)
precision = st.selectbox("Precision", options=list(precision_options.keys()))

if st.button("Calculate Memory Usage"):
    memory_usage = calculate_memory_usage(parameter_count, context_length, data_type,
                                          batch_size, vocab_size, precision,
                                          layers, attention_heads)
    st.write(f"Estimated Memory Usage for Inference: {memory_usage:.2f} GiB")
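
# To try the calculator locally (assuming this file is saved as
# memory_calculator.py):
#   streamlit run memory_calculator.py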