|
import torch |
|
import gradio as gr |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class NeuralNetwork(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """
        Build the layers of the network.

        Args:
            input_size (int): Number of input features.
            hidden_size (int): Width of the single hidden layer.
            output_size (int): Number of output values.
        """
        super().__init__()
        # The attribute names fc1 / fc2 become the state_dict keys
        # ("fc1.weight", "fc1.bias", ...), so they must match any checkpoint
        # later passed to load_state_dict.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, input_image):
        """
        Run one forward pass through the network.

        Args:
            input_image (torch.Tensor): Tensor whose last dimension is
                ``input_size``. Despite the parameter name, this file feeds
                it a flat 1-D feature vector rather than an image.

        Returns:
            torch.Tensor: Tensor whose last dimension is ``output_size``.
        """
        hidden = self.relu(self.fc1(input_image))
        return self.fc2(hidden)
|
|
|
|
|
# Instantiate the network with the layer sizes the saved weights were trained
# with (14 input features, 64 hidden units, 2 outputs) and load those weights.
model = NeuralNetwork(14, 64, 2)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; predict() builds CPU tensors, so the model runs on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
# Inference mode. This is a no-op for Linear/ReLU but makes the intent explicit.
model.eval()
|
|
|
|
|
# Map names offered in the UI. predict() feeds a map to the model as its index
# in this list, so reordering or inserting entries changes the model's input
# encoding.
maps = [
    "Ascent", "Bind", "Breeze", "Fracture", "Haven",
    "Icebox", "Lotus", "Pearl", "Split", "Sunset",
]
|
|
|
# Agent names offered in the UI. predict() feeds each agent to the model as
# its index in this list, so reordering or inserting entries changes the
# model's input encoding.
agents = [
    "Brimstone", "Viper", "Omen", "Killjoy", "Cypher",
    "Sova", "Sage", "Phoenix", "Jett", "Reyna",
    "Raze", "Breach", "Skye", "Yoru", "Astra",
    "Kayo", "Chamber", "Neon", "Fade", "Harbor",
    "Gekko", "Deadlock", "Iso",
]
|
|
|
|
|
|
|
|
|
def predict(*args):
    """
    Predict a match scoreline from the values of the Gradio input widgets.

    Args:
        *args: The 14 widget values in the order they are wired to this
            function: year, month, day, map name, the five Team 1 agent
            names, then the five Team 2 agent names.

    Returns:
        str: ``"Predicted score: A - B"``. A and B come from the first and
        second model outputs respectively. The leading side is raised to at
        least 13, both values are rounded to whole numbers, and the other
        side is kept between 0 and two below the leader.

    Raises:
        ValueError: If the map or an agent name is not in ``maps`` / ``agents``.
    """

    def encode(values):
        # Replace the map and agent names with their positions in the
        # module-level `maps` / `agents` lists; year, month and day
        # (indices 0-2) pass through unchanged.
        values[3] = maps.index(values[3])
        values[4:9] = [agents.index(name) for name in values[4:9]]
        values[9:14] = [agents.index(name) for name in values[9:14]]
        return values

    data = torch.tensor(encode(list(args)), dtype=torch.float32)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(data)

    # Post-process plain floats instead of mutating the output tensor in place.
    scores = outputs.tolist()
    winner = int(torch.argmax(outputs).item())
    loser = 1 - winner  # assumes the model has exactly two outputs

    if scores[winner] < 13:
        # Raise the leading side to at least 13.
        scores[winner] = 13
    elif scores[loser] < scores[winner] - 2:
        scores[loser] = scores[winner] - 2

    rounded = [round(score) for score in scores]
    # The adjustment above only ever raises the trailing side and happens
    # before rounding. Draws or one-round margins (13-13, 13-12, 14-14) could
    # therefore slip through, and negative model outputs could surface.
    # Keep the trailing side at most two below the leader and never below 0.
    rounded[loser] = max(0, min(rounded[loser], rounded[winner] - 2))

    score_a, score_b = rounded[0], rounded[1]
    return f'Predicted score: {score_a} - {score_b}'
|
|
|
|
|
|
|
|
|
# Gradio UI: match date and map on the first row, the two teams' agents side
# by side on the second, then the predict button and its result.
with gr.Blocks() as demo:

    with gr.Row():
        # min_width is an integer pixel count in Gradio (a "0px" string is not
        # a valid value); 0 removes the minimum so these date columns can stay
        # narrow.
        with gr.Column(min_width=0, scale=1):
            # NOTE(review): the default of 23 suggests a two-digit year is
            # expected, presumably matching the training data -- confirm.
            year_input = gr.Number(label="Year", value=23)
        with gr.Column(min_width=0, scale=1):
            month_input = gr.Number(label="Month", value=2)
        with gr.Column(min_width=0, scale=1):
            day_input = gr.Number(label="Day", value=23)
        with gr.Column(scale=3):
            map_input = gr.Dropdown(maps, label="Map", value='Ascent')
        with gr.Column(scale=3):
            pass  # empty spacer column

    with gr.Row():
        with gr.Column():
            team1_agent1_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 1", value='Brimstone')
            team1_agent2_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 2", value='Viper')
            team1_agent3_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 3", value='Omen')
            team1_agent4_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 4", value='Killjoy')
            team1_agent5_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 5", value='Cypher')

        with gr.Column():
            team2_agent1_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 1", value='Sova')
            team2_agent2_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 2", value='Sage')
            team2_agent3_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 3", value='Phoenix')
            team2_agent4_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 4", value='Jett')
            team2_agent5_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 5", value='Reyna')

    with gr.Column():
        # The button runs the score prediction and the textbox shows
        # "Predicted score: A - B", so both labels say what they do.
        translate_btn = gr.Button(value="Predict")
        score_difference_output = gr.Textbox(label="Predicted Score")

    # The input order must match what predict() expects: date, map, the
    # Team 1 agents, then the Team 2 agents.
    translate_btn.click(
        fn=predict,
        inputs=[
            year_input, month_input, day_input, map_input,
            team1_agent1_input, team1_agent2_input, team1_agent3_input,
            team1_agent4_input, team1_agent5_input,
            team2_agent1_input, team2_agent2_input, team2_agent3_input,
            team2_agent4_input, team2_agent5_input,
        ],
        outputs=score_difference_output,
    )

print('Launching interface!')
demo.launch()