test3 / utils.py
basilboy's picture
Update utils.py
23fcfc1 verified
import streamlit as st
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
def validate_sequence(sequence, max_length=200):
    """Return True if *sequence* is a usable protein sequence.

    A sequence is accepted when it is non-empty, at most *max_length*
    characters long, and made up only of the 20 standard amino-acid
    one-letter codes (uppercase).

    Args:
        sequence: The candidate sequence string. ``None`` or ``""`` is
            rejected rather than raising or passing vacuously.
        max_length: Maximum accepted length (inclusive). Defaults to 200.

    Returns:
        bool: True if the sequence is valid, False otherwise.
    """
    valid_amino_acids = frozenset("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids
    # Reject empty/None up front: all() over an empty iterable is True, which
    # would otherwise let an empty sequence through validation.
    if not sequence or len(sequence) > max_length:
        return False
    return all(aa in valid_amino_acids for aa in sequence)
def load_model(model_name):
    """Load a pickled PyTorch model from ``{model_name}_model.pth`` in eval mode.

    Args:
        model_name: Path prefix of the checkpoint file; ``_model.pth`` is
            appended to it.

    Returns:
        The unpickled model, mapped onto the CPU, with ``eval()`` already
        called so dropout / batch-norm behave as at inference time.

    Note:
        The checkpoint stores a whole pickled module (the loaded object is
        used directly as a model), not a ``state_dict``. Such files must be
        loaded with ``weights_only=False``; PyTorch 2.6 switched the default
        to ``weights_only=True``, which rejects them. Requires
        PyTorch >= 1.13 for the ``weights_only`` keyword.

        Unpickling can execute arbitrary code: only load trusted checkpoint
        files, and do not build ``model_name`` from raw user input.
    """
    # SECURITY NOTE(review): full pickle load; the file must be trusted.
    model = torch.load(
        f'{model_name}_model.pth',
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()
    return model
def predict(model, sequence, confidence_scale=0.85):
    """Classify one protein sequence and return ``(label, confidence)``.

    The sequence is tokenized with the ``facebook/esm2_t6_8M_UR50D``
    tokenizer and fed to *model*. Softmax is taken over the last axis of
    ``output.logits``, and the arg-max index is the predicted label.

    Args:
        model: Callable accepting the tokenizer's keyword outputs and
            returning an object with a ``.logits`` tensor. Presumably a
            Hugging Face sequence-classification model (TODO confirm).
        sequence: A single sequence string. ``.item()`` below requires
            exactly one prediction, so batches are not supported.
        confidence_scale: Multiplier applied to the top softmax probability
            before it is returned. Defaults to 0.85. This is a heuristic
            discount kept for compatibility; its rationale is not recorded
            here. Pass 1.0 to get the raw probability.

    Returns:
        tuple[int, float]: The predicted class index and the scaled
        confidence.
    """
    # NOTE: the tokenizer is (re)loaded per call. It is not cached globally on
    # purpose, because sharing one fast tokenizer across Streamlit session
    # threads can trigger concurrent-use errors.
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip building the autograd graph (saves memory/time;
    # outputs are numerically identical).
    with torch.no_grad():
        output = model(**tokenized_input)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    confidence = probabilities.max().item() * confidence_scale
    return predicted_label.item(), confidence
def _plot_confidence_panel(ax, data, model_name, prediction_val, color_dict):
    """Draw one bar chart of confidences for entries *model_name* labelled *prediction_val*."""
    # Keep only entries whose predicted label matches this panel, ordered by
    # confidence, highest first. sorted() is stable, so ties keep data's order.
    matches = sorted(
        (
            (name, values[model_name])
            for name, values in data.items()
            if values[model_name][0] == prediction_val
        ),
        key=lambda item: item[1][1],
        reverse=True,
    )
    names = [name for name, _ in matches]
    conf_values = [result[1] for _, result in matches]
    # With nothing to draw, skip the call rather than handing seaborn empty
    # data and an empty palette; the panel still gets its title and labels.
    if names:
        colors = [color_dict[name] for name in names]
        sns.barplot(x=names, y=conf_values, palette=colors, ax=ax)
    ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
    ax.set_xlabel('Names')
    ax.set_ylabel('Confidence')
    ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility


def plot_prediction_graphs(data, model_keys):
    """Render per-model confidence bar charts into the Streamlit page.

    For every key in *model_keys*, one figure is drawn with two panels
    sharing the y-axis. The left panel shows entries predicted as label 0 and
    the right panel those predicted as label 1.

    Args:
        data: Mapping ``name -> {model_key: (predicted_label, confidence)}``.
            The shape is inferred from the indexing here: ``[0]`` is compared
            with the label and ``[1]`` is plotted.
        model_keys: Iterable of model-name strings used to index each entry
            of *data*.
    """
    # Give every name a fixed colour so it looks the same in every chart.
    unique_names = sorted(data.keys())
    palette = sns.color_palette("hsv", len(unique_names))
    color_dict = dict(zip(unique_names, palette))

    for model_name in model_keys:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        for prediction_val, ax in zip((0, 1), axes):
            _plot_confidence_panel(ax, data, model_name, prediction_val, color_dict)
        st.pyplot(fig)
        # pyplot keeps every figure it creates alive until it is closed. Close
        # it after handing it to Streamlit so figures don't pile up across
        # reruns (memory growth / "too many open figures" warning).
        plt.close(fig)