Piraloco's picture
Update app.py
86f8cd5
# Import necessary libraries
import torch
import torch.nn as neural_network_module
class Net(neural_network_module.Module):
    """Fully connected network mapping 7 input features to a single output.

    Layout: Linear(7 -> 128) -> ReLU -> Linear(128 -> 64) -> ReLU -> Linear(64 -> 1).
    The attribute names fc1/fc2/fc3 define the state_dict keys, so they must
    match the checkpoint this file loads from 'model.pth'.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = neural_network_module.Linear(7, 128)
        self.fc2 = neural_network_module.Linear(128, 64)
        self.fc3 = neural_network_module.Linear(64, 1)

    def forward(self, x):
        """Apply the two ReLU-activated hidden layers, then the linear output layer."""
        out = x
        for hidden_layer in (self.fc1, self.fc2):
            out = torch.relu(hidden_layer(out))
        return self.fc3(out)
# Create an instance of the network and load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host; all inputs built by preprocess_data are CPU tensors anyway.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = Net()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# Inference mode. Net currently has no dropout/batch-norm layers, so this is a
# safeguard in case such layers are added later.
model.eval()
# Selectable agents. preprocess_data encodes each agent as its position in
# this list, so existing entries must keep their order.
# NOTE(review): presumably this order matches the encoding used when the
# model was trained -- confirm before inserting new names anywhere but the end.
agents = [
    'Brimstone', 'Viper', 'Omen', 'Killjoy', 'Cypher', 'Sova', 'Sage',
    'Phoenix', 'Jett', 'Reyna', 'Raze', 'Breach', 'Skye', 'Yoru',
    'Astra', 'KAY/O', 'Chamber', 'Neon', 'Fade', 'Harbor', 'Gekko',
    'Deadlock', 'Iso',
]
# Selectable maps. preprocess_data encodes a map as its position in this
# list, so the order of existing entries must not change.
maps = [
    'Ascent', 'Bind', 'Breeze', 'Fracture', 'Haven',
    'Icebox', 'Lotus', 'Pearl', 'Split', 'Sunset',
]
# Competitive ranks, ordered from 'Iron 1' up to 'Radiant'. preprocess_data
# encodes a rank as its position in this list, so the order matters.
# Every tier below Radiant has divisions 1-3.
ranks = [
    f'{tier} {division}'
    for tier in ('Iron', 'Bronze', 'Silver', 'Gold',
                 'Platinum', 'Diamond', 'Ascendant', 'Immortal')
    for division in (1, 2, 3)
] + ['Radiant']
def preprocess_data(data):
    """Encode a raw input row as the float tensor fed to the model.

    Each categorical value is replaced by its index in the matching lookup
    list: the rank in ``ranks``, the map in ``maps`` and each agent in
    ``agents``. NOTE(review): these indices presumably match the encoding
    used at training time, so the lists' order must not change.

    Args:
        data: Sequence of exactly 7 items ``[rank, map, agent1, ..., agent5]``.
            It is not modified (the previous version overwrote it in place).

    Returns:
        A 1-D ``torch.float32`` tensor holding the 7 encoded values.

    Raises:
        ValueError: if ``data`` does not contain exactly 7 items, or if a value
            is not present in its lookup list (raised by ``list.index``).
        TypeError: if ``data`` has no length (e.g. ``None``).
    """
    # The network's first layer takes exactly 7 features; fail early with a
    # clear message instead of an obscure shape/dtype error later on.
    if len(data) != 7:
        raise ValueError(
            f"expected 7 values (rank, map, 5 agents), got {len(data)}"
        )
    rank, map_name, *agent_names = data
    encoded = [ranks.index(rank), maps.index(map_name)]
    encoded.extend(agents.index(agent) for agent in agent_names)
    return torch.tensor(encoded, dtype=torch.float32)
# Define your prediction function
def make_prediction(rank, map, agent_picks):
    """Predict the win rate of a team composition, formatted as a percentage.

    Args:
        rank: Rank name, expected to be an entry of ``ranks``.
        map: Map name, expected to be an entry of ``maps``. The parameter
            shadows the builtin ``map``; the name is kept for keyword callers.
        agent_picks: Sequence of agent names. Only the first five are used,
            and each is expected to be an entry of ``agents``.

    Returns:
        The win rate as a string such as ``"57%"``. If an input is missing,
        has fewer than five agents, or contains an unrecognised value, it
        returns ``'Error, you probably forgot to fill out a component'``
        instead.
    """
    try:
        data = [rank, map, agent_picks[0], agent_picks[1], agent_picks[2],
                agent_picks[3], agent_picks[4]]
        processed_data = preprocess_data(data)
    except (TypeError, IndexError, ValueError):
        # These are the expected bad-input cases:
        # - an unset component (presumably None/empty): TypeError when
        #   indexed, or ValueError from list.index;
        # - fewer than five agents: IndexError;
        # - an unknown name: ValueError.
        # Other errors (e.g. a model failure) are real bugs and now
        # propagate instead of being hidden by a bare ``except:``.
        return 'Error, you probably forgot to fill out a component'

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(processed_data).item()

    # Soften outputs above 1 by removing half of the excess, then floor at 0.
    # NOTE(review): this does not clamp, so an output above 1 still yields a
    # winrate above 100% (e.g. 1.5 -> "125%"); confirm this is intended.
    if prediction > 1:
        prediction -= (prediction - 1) / 2
    prediction = 0 if prediction < 0 else prediction

    winrate = f"{round(prediction * 100)}%"
    print(f"Calculated Winrate: {winrate}")
    return winrate
# Example usage (calling the predictor directly, without the Gradio UI):
# winrate = make_prediction('Gold 1', 'Ascent', ['Jett', 'Sova', 'Omen', 'Killjoy', 'Breach'])
# print(f"Prediction: {winrate}")
import gradio as gr
import gradio_client as gc
# Gradio UI: the three dropdowns feed make_prediction(rank, map, agent_picks)
# in this order, and the returned string is displayed as text.
rank_dropdown = gr.Dropdown(label="Rank", choices=ranks)
map_dropdown = gr.Dropdown(label="Map", choices=maps)
# multiselect=True lets the user pick several agents, delivered as a list.
agent_dropdown = gr.Dropdown(label="Agent Picks (1-5)", choices=agents, multiselect=True)

interface = gr.Interface(
    fn=make_prediction,
    inputs=[rank_dropdown, map_dropdown, agent_dropdown],
    outputs="text",
)
interface.launch()