import streamlit as st

# Nominal bits per parameter for each supported data type / quantization format
quantization_bit_sizes = {
    'float32': 32,
    'float16': 16,
    'Q2_K': 2,
    'Q3_K_L': 3,
    'Q3_K_M': 3,
    'Q3_K_S': 3,
    'Q4_0': 4,
    'Q4_1': 4,
    'Q4_K_M': 4,
    'Q4_K_S': 4,
    'Q5_0': 5,
    'Q5_1': 5,
    'Q5_K_M': 5,
    'Q5_K_S': 5,
    'Q6_K': 6,
    'Q8_0': 8
}
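# Quick sanity check (illustrative figures only, using the nominal bit widths above):
# a hypothetical 7B-parameter model stored as Q4_0 needs roughly
# 7e9 params * 4 bits / 8 = 3.5e9 bytes, or about 3.26 GB, for the weights alone.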
# Bytes used per activation value at each precision setting
# (full = fp32, half = fp16; 'mixed' is assumed here to keep both an fp16 and an fp32 copy, hence 6 bytes)
precision_options = {
    'full': 4,
    'mixed': 6,
    'half': 2
}
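# Example (assuming the byte counts above): 1e9 activation values held at
# 'half' precision take about 1e9 * 2 bytes, i.e. roughly 1.86 GB.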
# Streamlit app
st.title("Memory Usage Calculator for Large Language Models")
# Taken from "Reducing Activation Recomputation in Large Transformer Models" https://arxiv.org/abs/2205.05198
def calculate_memory_usage(parameter_count, context_length, data_type, batch_size, vocab_size, precision):
    # Convert bit size to byte size
    byte_size = quantization_bit_sizes[data_type] / 8
    # Memory usage for the model parameters (weights)
    memory_params = parameter_count * byte_size
    # Memory usage for the context (activations)
    activations = calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision)
    # Total memory usage in bytes
    total_memory_usage = memory_params + activations
    # Convert bytes to gigabytes
    total_memory_usage_gb = total_memory_usage / (1024 ** 3)
    return total_memory_usage_gb
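# Rough example of the weights term above (illustrative figures): a 1B-parameter
# model at float16 needs 1e9 * 16 / 8 = 2e9 bytes, about 1.86 GB, before activations.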
def calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision):
    # Note: vocab_size is accepted but not used in this estimate
    # Rough heuristic: approximate the hidden dimension as the square root of the parameter count
    hidden_dimensions = int(parameter_count ** 0.5)
    # Per-layer activation memory from the paper, in bytes at 16-bit precision;
    # `layers` and `attention_heads` are read from the Streamlit inputs defined below
    activations_per_layer = context_length * batch_size * hidden_dimensions * (34 + ((5 * attention_heads * context_length) / hidden_dimensions))
    # Divide by 2 (bytes per 16-bit value) to turn the byte estimate into a count of activation values
    activations = layers * activations_per_layer / 2
    # Convert the activation count to bytes for the chosen precision
    bytes_per_value = precision_options[precision]
    total_activations = bytes_per_value * activations
    return total_activations
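# For reference, the per-layer estimate above follows the activation-memory formula
# from arXiv:2205.05198 (no tensor parallelism, no recomputation):
#     activation bytes per layer ≈ s * b * h * (34 + 5 * a * s / h)
# where s = sequence (context) length, b = batch size, h = hidden size and
# a = number of attention heads, assuming 16-bit activations.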
# User inputs
parameter_count = st.number_input("Parameter Count (in billions)", value=1, step=1) * 1e9
layers = st.number_input("Number of Layers", value=32, step=1)
attention_heads = st.number_input("Number of Attention Heads", value=32, step=1)
context_length = st.number_input("Context Length (number of tokens)", value=512, step=1)
data_type = st.selectbox("Data Type", options=list(quantization_bit_sizes.keys()))
batch_size = st.number_input("Batch Size", value=1, step=1)
vocab_size = st.number_input("Vocabulary Size", value=30000, step=1000)
precision = st.selectbox("Precision", options=list(precision_options.keys()))
# Calculate memory usage
if st.button("Calculate Memory Usage"):
    memory_usage = calculate_memory_usage(parameter_count, context_length, data_type, batch_size, vocab_size, precision)
    st.write(f"Estimated Memory Usage for Inference: {memory_usage:.2f} GB")
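# To try the calculator locally (assuming this script is saved as app.py):
#     streamlit run app.py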