# NOTE: copied from a (Hugging Face) Spaces page whose status read "Runtime error".
import streamlit as st | |
import torch | |
import numpy as np | |
from transformers import AutoTokenizer | |
from transformers import BertForSequenceClassification | |
# Page layout: wide content area, sidebar initially expanded, and two
# side-by-side columns (col1: intro + sample table, col2: prediction UI).
st.set_page_config(initial_sidebar_state='expanded', layout='wide')
col1, col2 = st.columns(2)
# Rows of the sample table: (category, "Latitude, Longitude, Brightness, FRP").
SAMPLE_ROWS = [
    ("Likely", "-26.76123, 147.15512, 393.02, 203.63"),
    ("Likely", "-26.7598, 147.14514, 361.54, 79.4"),
    ("Unlikely", "-25.70059, 149.48932, 313.9, 5.15"),
    ("Unlikely", "-24.4318, 151.83102, 307.98, 8.79"),
    ("Unlikely", "-23.21878, 148.91298, 314.08, 7.4"),
    ("Likely", "7.87518, 19.9241, 316.32, 39.63"),
    ("Unlikely", "-20.10942, 148.14326, 314.39, 8.8"),
    ("Unlikely", "7.87772, 19.9048, 304.14, 13.43"),
    ("Likely", "-20.79866, 124.46834, 366.74, 89.06"),
]

# Inline border style shared by every row and cell of the sample table.
_TABLE_BORDER = "border: 1px solid orange;"


def _sample_table_html(rows):
    """Return the sample table as HTML: a bold header row, then one row per (category, values) pair."""
    lines = [
        '<table style="border-collapse: collapse; width: 100%;">',
        f'<tr style="{_TABLE_BORDER}">',
        f'<th style="{_TABLE_BORDER} font-weight: bold;">Category</th>',
        f'<th style="{_TABLE_BORDER} font-weight: bold;">Latitude, Longitude, Brightness, FRP</th>',
        "</tr>",
    ]
    for category, values in rows:
        lines += [
            f'<tr style="{_TABLE_BORDER}">',
            f'<td style="{_TABLE_BORDER}">{category}</td>',
            f'<td style="{_TABLE_BORDER}">{values}</td>',
            "</tr>",
        ]
    lines.append("</table>")
    # Flush-left tags and no blank lines, so Markdown keeps this as one raw-HTML block.
    return "\n".join(lines)


with col1:
    st.title("FireWatch")
    st.markdown("PREDICT WHETHER HEAT SIGNATURES AROUND THE GLOBE ARE LIKELY TO BE FIRES!")
    st.markdown("Training Code at:")
    st.markdown("https://colab.research.google.com/drive/1-IfOMJ-X8MKzwm3UjbJbK6RmhT7tk_ye?usp=sharing")
    st.markdown("Try the Model Yourself at:")
    st.markdown("https://colab.research.google.com/drive/1GmweeQrkzs0OXQ_KNZsWd1PQVRLCWDKi?usp=sharing")
    st.markdown("## Sample Table")
    table_html = _sample_table_html(SAMPLE_ROWS)
    st.markdown(table_html, unsafe_allow_html=True)
# Decorative pine tree drawn purely with CSS: two stacked zero-size boxes whose
# green bottom borders form triangles, over brown trunk boxes. Rendered later
# with st.markdown(..., unsafe_allow_html=True). The markup contains no blank
# lines so Markdown passes it through as a single raw-HTML block. Shared
# declarations are grouped; the width of .pine-tree comes from the inline style.
tree = """
<div class="pine-tree" style="width: 50%; margin: 0 auto;">
<div class="tree-top"></div>
<div class="tree-top2"></div>
<div class="tree-bottom">
<div class="trunk"></div>
</div>
</div>
<style>
.pine-tree {
height: 20vw;
position: relative;
display: flex;
justify-content: center;
align-items: center;
}
.tree-top,
.tree-top2 {
width: 0;
height: 0;
border-left: 8vw solid transparent;
border-right: 8vw solid transparent;
border-bottom: 13vw solid green;
position: absolute;
top: 0;
left: 0;
right: 0;
margin: auto;
}
.tree-top2 {
top: 3vw;
}
.tree-bottom,
.trunk {
height: 10vw;
background-color: brown;
position: absolute;
bottom: 0;
left: 0;
right: 0;
margin: auto;
}
.tree-bottom {
width: 8vw;
top: 21vw;
}
.trunk {
width: 3vw;
}
</style>
"""
def load_model(show_spinner=True):
    """Return the fine-tuned FireWatch ``BertForSequenceClassification`` model.

    Args:
        show_spinner: unused by the body; kept so existing calls keep working.

    Returns:
        The model loaded with ``from_pretrained("NimaKL/FireWatch_tiny_75k")``
        (Hugging Face Hub repo id or local cache).
    """
    MODEL_PATH = "NimaKL/FireWatch_tiny_75k"
    return BertForSequenceClassification.from_pretrained(MODEL_PATH)


# Streamlit re-executes the whole script on every widget interaction. Caching
# keeps one loaded model in memory instead of reloading it on each rerun.
# Guarded because st.cache_resource is missing from older Streamlit releases.
if hasattr(st, "cache_resource"):
    load_model = st.cache_resource(load_model)
def preprocessing(input_text, tokenizer, max_length=16):
    """Tokenize one input string for the classifier.

    Args:
        input_text: the raw text to encode.
        tokenizer: a Hugging Face tokenizer exposing ``encode_plus``.
        max_length: number of tokens every encoding is padded or truncated to
            (default 16, the value this app has always used).

    Returns:
        A ``transformers`` ``BatchEncoding`` with ``input_ids``,
        ``token_type_ids`` (for BERT-style tokenizers) and ``attention_mask``
        (1 for real tokens, 0 for padding). ``return_tensors='pt'`` makes each
        field a PyTorch tensor with a leading batch axis of size 1.
    """
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",  # pad short inputs up to max_length
        truncation=True,  # cut long inputs down to max_length
        return_attention_mask=True,
        return_tensors="pt",
    )
def predict(new_sentence):
    """Classify one detection string with the fine-tuned model.

    Uses the module-level ``model`` and ``tokenizer`` and the ``preprocessing``
    helper defined in this script.

    Args:
        new_sentence: user input, expected in the format
            "Latitude, Longitude, Brightness, FRP".

    Returns:
        ``'Predicted Class: Likely'`` when the highest logit is at index 1,
        otherwise ``'Predicted Class: Unlikely'``.
    """
    # Inputs must live on the same device as the model. Choosing CUDA
    # independently of where the model was loaded raises a device-mismatch
    # error on GPU hosts.
    device = model.device
    encoding = preprocessing(new_sentence, tokenizer)
    # return_tensors='pt' already yields a batch of one, so no stacking is needed.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        output = model(input_ids, token_type_ids=None, attention_mask=attention_mask)
    predicted_index = int(torch.argmax(output.logits, dim=-1)[0])
    prediction = 'Likely' if predicted_index == 1 else 'Unlikely'
    return 'Predicted Class: ' + prediction
def load_tokenizer():
    """Return the ``bert-base-uncased`` tokenizer that ``predict`` passes to ``preprocessing``."""
    return AutoTokenizer.from_pretrained("bert-base-uncased")


# Cache across Streamlit reruns when st.cache_resource is available, so the
# tokenizer is not reloaded on every widget interaction.
if hasattr(st, "cache_resource"):
    load_tokenizer = st.cache_resource(load_tokenizer)

model = load_model()
tokenizer = load_tokenizer()
with col2:
    st.markdown('## Enter Prediction Data in Correct Format "Latitude, Longitude, Brightness, FRP"')
    # Pre-fill with a value in exactly the documented format. A prefix such as
    # "Example: " would be fed to the model and use up its 16-token budget.
    text = st.text_input('Prediction Data: ', '8.81064, -65.07661, 328.04, 18.76')
    aButton = st.button('Predict')
    if text.strip():
        with st.spinner('Wait for it...'):
            st.success(predict(text))
    elif aButton:
        # Do not classify an empty string; tell the user what is expected instead.
        st.warning('Please enter "Latitude, Longitude, Brightness, FRP" before predicting.')
    st.markdown(tree, unsafe_allow_html=True)