|
|
|
import torch |
|
import torch.nn as neural_network_module |
|
|
|
class Net(neural_network_module.Module):
    """Fully connected regressor: 7 input features -> 128 -> 64 -> 1 output.

    The attribute names ``fc1``, ``fc2`` and ``fc3`` become the state-dict
    keys (``fc1.weight``, ``fc1.bias``, ...), so they must match the names
    stored in any checkpoint loaded into this model.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order; keep it so random initialisation
        # under a fixed seed stays reproducible.
        self.fc1 = neural_network_module.Linear(7, 128)
        self.fc2 = neural_network_module.Linear(128, 64)
        self.fc3 = neural_network_module.Linear(64, 1)

    def forward(self, x):
        """Run the network on ``x`` (last dimension of size 7).

        Returns the raw output of the final linear layer; no activation is
        applied to it.
        """
        out = x
        for hidden_layer in (self.fc1, self.fc2):
            out = torch.relu(hidden_layer(out))
        return self.fc3(out)
|
|
|
|
|
|
|
|
|
# Build the network and load the trained weights once at import time, so
# every prediction request reuses the same model instance.
model = Net()
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only host; on a CPU checkpoint it is a no-op.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# Inference mode. Net currently has no dropout/batch-norm, but this keeps
# predictions correct if such layers are added later.
model.eval()
|
|
|
# Lookup tables for the categorical inputs. preprocess_data encodes each
# choice as its position in these lists (via list.index), so their order is
# part of the model's input encoding. Presumably it matches the order used in
# training, so do not reorder entries or insert into the middle.
agents = [
    'Brimstone', 'Viper', 'Omen', 'Killjoy', 'Cypher', 'Sova', 'Sage',
    'Phoenix', 'Jett', 'Reyna', 'Raze', 'Breach', 'Skye', 'Yoru', 'Astra',
    'KAY/O', 'Chamber', 'Neon', 'Fade', 'Harbor', 'Gekko', 'Deadlock', 'Iso',
]

maps = [
    'Ascent', 'Bind', 'Breeze', 'Fracture', 'Haven',
    'Icebox', 'Lotus', 'Pearl', 'Split', 'Sunset',
]

# Ordered lowest to highest: three divisions for each tier from Iron through
# Immortal, followed by the single Radiant rank.
ranks = [
    f'{tier} {division}'
    for tier in ('Iron', 'Bronze', 'Silver', 'Gold', 'Platinum',
                 'Diamond', 'Ascendant', 'Immortal')
    for division in (1, 2, 3)
] + ['Radiant']
|
|
|
|
|
def preprocess_data(data):
    """Encode one raw input row as a float tensor for the model.

    Args:
        data: A 7-item sequence ``[rank, map, agent1, ..., agent5]`` whose
            entries are names from ``ranks``, ``maps`` and ``agents``
            respectively. The sequence is not modified.

    Returns:
        A 1-D ``torch.float32`` tensor of length 7 holding the list index of
        each name.

    Raises:
        ValueError: if ``data`` does not contain exactly 7 items, or if a
            name is not present in its lookup list.
    """
    # Fail early with a clear message instead of producing a wrong-size
    # tensor that only errors later inside the model.
    if len(data) != 7:
        raise ValueError(
            f'expected 7 values (rank, map, 5 agents), got {len(data)}'
        )

    # Build a fresh list rather than overwriting the caller's entries.
    rank, map_name, *agent_names = data
    encoded = [ranks.index(rank), maps.index(map_name)]
    encoded.extend(agents.index(agent) for agent in agent_names)

    return torch.tensor(encoded, dtype=torch.float32)
|
|
|
|
|
def make_prediction(rank, map, agent_picks):
    """Predict a team composition's win rate and format it for display.

    Args:
        rank: Rank name, expected to be one of ``ranks``.
        map: Map name, expected to be one of ``maps``. This shadows the
            builtin ``map`` but is kept so keyword callers keep working.
        agent_picks: Sequence of agent names from ``agents``. Exactly five
            are required because the model input has five agent slots.

    Returns:
        The predicted win rate as a percentage string such as ``'57%'``, or
        a human-readable error message when an input is missing or invalid.
    """
    # Missing inputs (None / empty) cannot be encoded.
    if not rank or not map or not agent_picks:
        return 'Error, you probably forgot to fill out a component'
    # Extra picks used to be silently dropped, and fewer than five used to
    # fail with a misleading message.
    if len(agent_picks) != 5:
        return 'Error, please select exactly 5 agents'

    try:
        processed_data = preprocess_data([rank, map, *agent_picks])
    except ValueError:
        # A name was not found in its lookup list.
        return 'Error, you probably forgot to fill out a component'

    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        prediction = model(processed_data).item()

    # Dampen overshoot above 1 by halving the excess (e.g. 1.4 -> 1.2), then
    # clamp negative outputs to 0.
    # NOTE(review): values above 1 are dampened, not capped, so results over
    # 100% are still possible; confirm this is intended.
    if prediction > 1:
        prediction -= (prediction - 1) / 2
    prediction = 0 if prediction < 0 else prediction

    winrate = f'{round(prediction * 100)}%'

    print(f"Calculated Winrate: {winrate}")

    return winrate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import gradio_client as gc |
|
|
|
|
|
# Web UI: the three dropdowns feed make_prediction positionally
# (rank, map, agent_picks), and its string result is shown as text.
interface = gr.Interface(
    fn=make_prediction,
    inputs=[
        gr.Dropdown(label="Rank", choices=ranks),
        gr.Dropdown(label="Map", choices=maps),
        # The model encodes exactly five agent slots, so a full team is needed.
        gr.Dropdown(label="Agent Picks (5)", choices=agents, multiselect=True),
    ],
    outputs="text",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. to reuse make_prediction) has no server side effect.
if __name__ == "__main__":
    interface.launch()
|
|