Spaces:
Sleeping
Sleeping
import torch | |
import streamlit as st | |
import pygame | |
import os | |
import torch.nn as nn | |
import torch.optim as optim | |
import pandas as pd | |
import numpy as np | |
from collections import deque | |
import random | |
from typing import List | |
from argparse import Action | |
import random | |
import sys | |
from sqlalchemy import asc | |
import math | |
import time | |
from tqdm import tqdm | |
from datetime import datetime | |
import matplotlib.pyplot as plt | |
# --- Window and world geometry (pixels) -----------------------------------
SCREEN_HEIGHT = 600
SCREEN_WIDTH = 1100
INIT_GAME_SPEED = 14
X_POS_BG_INIT = 0
Y_POS_BG = 380

# --- DQN hyper-parameters --------------------------------------------------
INIT_REPLAY_MEM_SIZE = 5_000
REPLAY_MEMORY_SIZE = 45_000
MODEL_NAME = "DINO"
MIN_REPLAY_MEMORY_SIZE = 1_000
MINIBATCH_SIZE = 64
DISCOUNT = 0.95
UPDATE_TARGET_THRESH = 5
# Lowered from the initial 0.45 so the agent explores less
# (a smaller epsilon means fewer random actions).
EPSILON_INIT = 0.25
# Lowered from the initial 0.997 so exploration fades faster.
EPSILON_DECAY = 0.75
NUM_EPISODES = 100
MIN_EPSILON = 0.05


def _load_asset(folder, filename):
    """Load one sprite from ``Assets/<folder>/<filename>``."""
    return pygame.image.load(os.path.join("Assets", folder, filename))


# --- Sprites ---------------------------------------------------------------
# Each asset is loaded exactly once; the original file repeated this whole
# section a second time, reading every image from disk twice.
RUNNING = [
    _load_asset("Dino", "DinoRun1.png"),
    _load_asset("Dino", "DinoRun2.png"),
]
DUCKING = [
    _load_asset("Dino", "DinoDuck1.png"),
    _load_asset("Dino", "DinoDuck2.png"),
]
JUMPING = _load_asset("Dino", "DinoJump.png")
SMALL_CACTUS = [
    _load_asset("Cactus", "SmallCactus1.png"),
    _load_asset("Cactus", "SmallCactus2.png"),
    _load_asset("Cactus", "SmallCactus3.png"),
]
LARGE_CACTUS = [
    _load_asset("Cactus", "LargeCactus1.png"),
    _load_asset("Cactus", "LargeCactus2.png"),
    _load_asset("Cactus", "LargeCactus3.png"),
]
BIRD = [
    _load_asset("Bird", "Bird1.png"),
    _load_asset("Bird", "Bird2.png"),
]
CLOUD = _load_asset("Other", "Cloud.png")
BACKGROUND = _load_asset("Other", "Track.png")
class NeuralNetwork(nn.Module):
    """Two-layer fully connected network: 7 inputs -> 4 hidden units -> 3 outputs."""

    def __init__(self):
        super().__init__()
        # Attribute names are part of the saved state-dict keys; keep them.
        self.fc1 = nn.Linear(in_features=7, out_features=4)
        self.fc2 = nn.Linear(in_features=4, out_features=3)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 and return the raw outputs."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use the GPU when CUDA is available, otherwise fall back to the CPU
class DQNAgent:
    """Deep Q-learning agent: online network, target network and replay memory.

    Transitions are stored as ``(state, action, reward, next_state, done)``
    tuples. The first ``INIT_REPLAY_MEM_SIZE`` transitions go into
    ``init_replay_memory`` and are never displaced; later ones go into the
    rolling ``late_replay_memory`` buffer.
    """

    def __init__(self):
        # Both networks live on ``device`` (GPU when available).
        self.model = NeuralNetwork().to(device)
        self.target_model = NeuralNetwork().to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_function = nn.MSELoss()
        self.init_replay_memory = deque(maxlen=INIT_REPLAY_MEM_SIZE)
        self.late_replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)
        self.target_update_counter = 0

    def update_replay_memory(self, transition):
        """Store one transition, filling the initial buffer before the late one."""
        if len(self.init_replay_memory) < INIT_REPLAY_MEM_SIZE:
            self.init_replay_memory.append(transition)
        else:
            self.late_replay_memory.append(transition)

    def get_qs(self, state):
        """Return the online network's Q-values for ``state`` as a NumPy array.

        ``state`` may be a list, NumPy array or tensor; it is converted to a
        float32 tensor on ``device`` before the forward pass.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            return self.model(state_tensor).cpu().numpy()

    def train(self, terminal_state, step):
        """Run one optimisation step on a random minibatch.

        Does nothing until ``init_replay_memory`` holds at least
        ``MIN_REPLAY_MEMORY_SIZE`` transitions.

        Args:
            terminal_state: Truthy when the step that triggered this call
                ended an episode; it advances the target-network sync counter.
            step: Step index within the episode (currently unused; kept for
                interface compatibility).
        """
        if len(self.init_replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        total_mem = list(self.init_replay_memory)
        total_mem.extend(self.late_replay_memory)
        minibatch = random.sample(total_mem, MINIBATCH_SIZE)

        states, actions, rewards, next_states, dones = zip(*minibatch)
        states_t = torch.as_tensor(np.array(states, dtype=np.float32), device=device)
        next_states_t = torch.as_tensor(np.array(next_states, dtype=np.float32), device=device)
        actions_t = torch.as_tensor(np.array(actions, dtype=np.int64), device=device)
        rewards_t = torch.as_tensor(np.array(rewards, dtype=np.float32), device=device)
        done_t = torch.as_tensor(np.array(dones, dtype=bool), device=device)

        # Targets are constants for the loss, so build them without autograd.
        with torch.no_grad():
            targets = self.model(states_t)
            max_future_q = self.target_model(next_states_t).max(dim=1).values
            # A transition flagged done gets only its reward; otherwise the
            # discounted best next-state value is added.
            new_q = torch.where(done_t, rewards_t, rewards_t + DISCOUNT * max_future_q)
            rows = torch.arange(len(minibatch), device=device)
            # Only the taken action's entry changes; the other entries equal
            # the current prediction and therefore contribute zero error.
            targets[rows, actions_t] = new_q

        self.optimizer.zero_grad()
        output = self.model(states_t)
        loss = self.loss_function(output, targets)
        loss.backward()
        self.optimizer.step()

        if terminal_state:
            self.target_update_counter += 1
        # Sync the target network once enough episodes have ended.
        if self.target_update_counter > UPDATE_TARGET_THRESH:
            self.target_model.load_state_dict(self.model.state_dict())
            self.target_update_counter = 0
class Obstacle:
    """Base class for anything the dino can collide with.

    Args:
        image: List of candidate surfaces for this obstacle.
        type: Index into ``image`` selecting the surface used for the rect
            and for drawing.
    """

    def __init__(self, image: List[pygame.Surface], type: int) -> None:
        self.image = image
        self.type = type
        self.rect = self.image[self.type].get_rect()
        # New obstacles start at the right edge of the screen.
        self.rect.x = SCREEN_WIDTH

    def update(self, obstacles: list, game_speed: int):
        """Move left by ``game_speed`` and leave ``obstacles`` once off-screen."""
        self.rect.x -= game_speed
        # Remove this obstacle specifically; a bare pop() would drop whatever
        # happens to be last in the list (or raise on an empty list).
        if self.rect.x < -self.rect.width and self in obstacles:
            obstacles.remove(self)

    def draw(self, SCREEN: pygame.Surface):
        """Blit the selected surface at the obstacle's current rect."""
        SCREEN.blit(self.image[self.type], self.rect)
class Dino(DQNAgent):
    """The player sprite; also carries the DQN networks via ``DQNAgent``.

    Exactly one of ``dino_run`` / ``dino_jump`` / ``dino_duck`` is meant to be
    set at a time and selects which animation routine runs each frame.
    """

    X_POS = 80
    Y_POS = 310
    Y_DUCK_POS = 340
    JUMP_VEL = 8.5

    def __init__(self) -> None:
        # Set up the agent part (networks, optimizer, replay buffers).
        super().__init__()

        # Sprite sheets for each pose.
        self.duck_img = DUCKING
        self.run_img = RUNNING
        self.jump_img = JUMPING

        # The dino starts out running.
        self.dino_duck = False
        self.dino_run = True
        self.dino_jump = False

        self.step_index = 0
        self.jump_vel = self.JUMP_VEL
        self.image = self.run_img[0]
        self.dino_rect = self.image.get_rect()
        self.dino_rect.x = self.X_POS
        self.dino_rect.y = self.Y_POS
        self.score = 0

    def _animate(self) -> None:
        """Run the routine(s) for the active pose and wrap the frame counter."""
        if self.dino_duck:
            self.duck()
        if self.dino_jump:
            self.jump()
        # Checked after jump() on purpose: a jump that just ended re-enables
        # running within the same frame.
        if self.dino_run:
            self.run()
        if self.step_index >= 20:
            self.step_index = 0

    def _apply_move(self, jump_requested, duck_requested) -> None:
        """Switch pose flags according to the requested move."""
        if jump_requested and not self.dino_jump:
            self.dino_duck, self.dino_run, self.dino_jump = False, False, True
        elif duck_requested and not self.dino_jump:
            self.dino_duck, self.dino_run, self.dino_jump = True, False, False
        elif not (self.dino_jump or duck_requested):
            self.dino_duck, self.dino_run, self.dino_jump = False, True, False

    def update(self, move: pygame.key.ScancodeWrapper):
        """Advance one frame using keyboard state (UP jumps, DOWN ducks)."""
        self._animate()
        self._apply_move(move[pygame.K_UP], move[pygame.K_DOWN])

    def update_auto(self, move):
        """Advance one frame using an action index (0 jumps, 1 ducks, else run)."""
        self._animate()
        self._apply_move(move == 0, move == 1)

    def _show_ground_frame(self, frames, y_pos) -> None:
        """Pick the current frame of a two-frame ground animation and place it."""
        self.image = frames[self.step_index // 10]
        self.dino_rect = self.image.get_rect()
        self.dino_rect.x = self.X_POS
        self.dino_rect.y = y_pos
        self.step_index += 1

    def duck(self) -> None:
        """Show the ducking animation at the ducking height."""
        self._show_ground_frame(self.duck_img, self.Y_DUCK_POS)

    def run(self) -> None:
        """Show the running animation at the standing height."""
        self._show_ground_frame(self.run_img, self.Y_POS)

    def jump(self) -> None:
        """Advance the jump arc by one frame and end it when the arc completes."""
        self.image = self.jump_img
        if self.dino_jump:
            self.dino_rect.y -= self.jump_vel * 3
            self.jump_vel -= 0.6
        # Velocity has swung past -JUMP_VEL: the jump is over, go back to running.
        if self.jump_vel < -self.JUMP_VEL:
            self.dino_jump = False
            self.dino_run = True
            self.jump_vel = self.JUMP_VEL

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current sprite at the dino's position."""
        SCREEN.blit(self.image, (self.dino_rect.x, self.dino_rect.y))
class LargeCactus(Obstacle):
    """Large cactus obstacle using one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        self.type = variant
        super().__init__(image, variant)
        self.rect.y = 300
class SmallCactus(Obstacle):
    """Small cactus obstacle using one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        self.type = variant
        super().__init__(image, variant)
        self.rect.y = 325
class Bird(Obstacle):
    """Flying obstacle animated by alternating between two sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        self.type = 0
        super().__init__(image, 0)
        self.rect.y = SCREEN_HEIGHT - 340
        self.index = 0

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current animation frame and advance the frame counter."""
        # The counter wraps at 19, so index // 10 is always 0 or 1.
        frame_counter = 0 if self.index >= 19 else self.index
        SCREEN.blit(self.image[frame_counter // 10], self.rect)
        self.index = frame_counter + 1
class Cloud:
    """Decorative cloud that drifts left and respawns beyond the right edge."""

    def __init__(self) -> None:
        self._respawn()
        self.image = CLOUD
        self.width = self.image.get_width()

    def _respawn(self) -> None:
        """Place the cloud at a random spot to the right of the screen."""
        self.x = SCREEN_WIDTH + random.randint(800, 1000)
        self.y = random.randint(50, 100)

    def update(self, game_speed: int):
        """Move left by ``game_speed``; respawn once fully off-screen."""
        self.x -= game_speed
        if self.x < -self.width:
            self._respawn()

    def draw(self, SCREEN: pygame.Surface):
        """Blit the cloud at its current position."""
        SCREEN.blit(self.image, (self.x, self.y))
class Game:
    """Headless Dino game loop plus the DQN training driver.

    Args:
        epsilon: Initial exploration rate for epsilon-greedy action selection.
        load_model: When true (and ``model_path`` is given) the online
            network's weights are loaded from ``model_path``.
        model_path: Path of a state dict saved with ``torch.save``.
    """

    def __init__(self, epsilon, load_model=False, model_path=None):
        # Render off-screen: SDL's dummy video driver needs no display.
        os.environ["SDL_VIDEODRIVER"] = "dummy"
        pygame.init()
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.obstacles = []
        self.run = True
        self.clock = pygame.time.Clock()
        self.cloud = Cloud()
        self.game_speed = INIT_GAME_SPEED
        self.font = pygame.font.Font("freesansbold.ttf", 20)
        self.dino = Dino()
        if load_model and model_path:
            self.dino.model.load_state_dict(torch.load(model_path, map_location=device))
        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.epsilon = epsilon
        self.ep_rewards = [-200]
        self.high_score = 0
        self.best_score = 0
        # Set once a best-score checkpoint has been written by play_auto().
        self.best_model_filename = None

    def reset(self):
        """Start a new round while keeping everything the agent has learned."""
        self.game_speed = INIT_GAME_SPEED
        old_dino = self.dino
        self.dino = Dino()
        self.dino.init_replay_memory = old_dino.init_replay_memory
        self.dino.late_replay_memory = old_dino.late_replay_memory
        self.dino.target_update_counter = old_dino.target_update_counter
        # Carry the networks AND the optimizer over by reference. Copying only
        # the weights into fresh networks would silently discard the Adam
        # optimizer state at every episode boundary.
        self.dino.model = old_dino.model
        self.dino.target_model = old_dino.target_model
        self.dino.optimizer = old_dino.optimizer
        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.clock = pygame.time.Clock()

    def get_dist(self, pos_a: tuple, pos_b: tuple):
        """Return the Euclidean distance between two ``(x, y)`` points."""
        return math.hypot(pos_a[0] - pos_b[0], pos_a[1] - pos_b[1])

    def update_background(self):
        """Scroll and draw the ground track; return the new x offset."""
        image_width = BACKGROUND.get_width()
        # Two copies side by side give a seamless scrolling track.
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg, Y_POS_BG))
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))
        if self.x_pos_bg <= -image_width:
            self.x_pos_bg = 0
        self.x_pos_bg -= self.game_speed
        return self.x_pos_bg

    def get_state(self):
        """Build the 7-element feature list fed to the network.

        Order: dino height term, normalised distance to the next obstacle,
        obstacle top relative to ``Y_DUCK_POS``, game speed / 24, obstacle
        width relative to ``SMALL_CACTUS[2]``, cactus flag, bird flag.
        """
        # NOTE(review): operator precedence makes this (y / Y_DUCK_POS) + 10.
        # Left untouched because saved models were trained on this value.
        dino_height = self.dino.dino_rect.y / self.dino.Y_DUCK_POS + 10
        dino_pos = (self.dino.dino_rect.x, self.dino.dino_rect.y)
        screen_diagonal = math.sqrt(SCREEN_HEIGHT**2 + SCREEN_WIDTH**2)
        bird = 0
        cactus = 0
        if not self.obstacles:
            # No obstacle yet: pretend one sits just beyond the right edge.
            dist = self.get_dist(dino_pos, (SCREEN_WIDTH + 10, self.dino.Y_POS)) / screen_diagonal
            obs_height = 0
            obj_width = 0
        else:
            nearest = self.obstacles[0]
            dist = self.get_dist(dino_pos, nearest.rect.midtop) / screen_diagonal
            obs_height = nearest.rect.midtop[1] / self.dino.Y_DUCK_POS
            obj_width = nearest.rect.width / SMALL_CACTUS[2].get_rect().width
            # isinstance avoids building throwaway cactus objects just to
            # compare classes (which also consumed random numbers).
            if isinstance(nearest, (SmallCactus, LargeCactus)):
                cactus = 1
            else:
                bird = 1
        return [
            dino_height,
            dist,
            obs_height,
            self.game_speed / 24,
            obj_width,
            cactus,
            bird,
        ]

    def update_score(self):
        """Add a point, speed up every 200 points and draw the score text."""
        self.points += 1
        if self.points % 200 == 0:
            self.game_speed += 1
        if self.points > self.high_score:
            self.high_score = self.points
        text = self.font.render(f"Points: {self.points} Highscore: {self.high_score}", True, (0, 0, 0))
        textRect = text.get_rect()
        textRect.center = (SCREEN_WIDTH - textRect.width // 2 - 10, 40)
        self.SCREEN.blit(text, textRect)

    def create_obstacle(self):
        """Maybe spawn an obstacle; birds only appear after 300 points."""
        obstacle_prob = random.randint(0, 50)
        if obstacle_prob == 0:
            self.obstacles.append(SmallCactus(SMALL_CACTUS))
        elif obstacle_prob == 1:
            self.obstacles.append(LargeCactus(LARGE_CACTUS))
        elif obstacle_prob == 2 and self.points > 300:
            self.obstacles.append(Bird(BIRD))

    def update_game(self, moves, user_input=None):
        """Draw and advance the dino, background, cloud and score for one frame.

        ``user_input`` (keyboard state) takes precedence over ``moves`` (an
        action index) when provided.
        """
        self.dino.draw(self.SCREEN)
        if user_input is not None:
            self.dino.update(user_input)
        else:
            self.dino.update_auto(moves)
        self.update_background()
        self.cloud.draw(self.SCREEN)
        self.cloud.update(self.game_speed)
        self.update_score()
        self.clock.tick(30)

    def play_manual(self):
        """Run the keyboard-controlled game until the dino hits an obstacle."""
        while self.run:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()
            self.SCREEN.fill((255, 255, 255))
            user_input = pygame.key.get_pressed()
            if not self.obstacles:
                self.create_obstacle()
            # Iterate over a copy: update() may remove the obstacle.
            for obstacle in list(self.obstacles):
                obstacle.draw(SCREEN=self.SCREEN)
                obstacle.update(self.obstacles, self.game_speed)
                if self.dino.dino_rect.colliderect(obstacle.rect):
                    self.dino.score = self.points
                    pygame.quit()
                    if obstacle in self.obstacles:
                        self.obstacles.remove(obstacle)
                    print("Game over!")
                    return
            self.update_game(user_input=user_input, moves=2)
            pygame.display.update()

    def _choose_action(self, state):
        """Pick an action epsilon-greedily (0 = jump, 1 = duck, 2 = run)."""
        if np.random.random() > self.epsilon:
            return int(np.argmax(self.dino.get_qs(state)))
        # Exploration is biased: 10% jump, 30% duck, 60% run.
        num = np.random.randint(0, 10)
        if num == 0:
            return 0
        if num <= 3:
            return 1
        return 2

    def play_auto(self, episode_info):
        """Train the agent for ``NUM_EPISODES`` episodes.

        Args:
            episode_info: Streamlit placeholder; its ``text`` method is called
                with a status line after every episode.

        Side effects: writes the best-scoring model to
        ``models/highscore/BestScore_model.pth``, a checkpoint every 50
        episodes under ``models/episodes/``, and on exit (even after an
        interruption) a timestamped copy of the best model.
        """
        import shutil

        try:
            points_label = 0
            for episode in tqdm(range(1, NUM_EPISODES + 1), ascii=True, unit='episodes'):
                episode_reward = 0
                step = 1
                current_state = self.get_state()
                self.run = True
                while self.run:
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            sys.exit()
                    self.SCREEN.fill((255, 255, 255))
                    if not self.obstacles:
                        self.create_obstacle()

                    action = self._choose_action(current_state)
                    self.update_game(moves=action)
                    next_state = self.get_state()
                    reward = 0
                    # Iterate over a copy: update() may remove the obstacle.
                    for obstacle in list(self.obstacles):
                        obstacle.draw(SCREEN=self.SCREEN)
                        obstacle.update(self.obstacles, self.game_speed)
                        next_state = self.get_state()
                        # Reward clearing an obstacle ...
                        if self.dino.dino_rect.x > obstacle.rect.x + obstacle.rect.width:
                            reward = 3
                        # ... and discourage jumping while it is still far away.
                        if action == 0 and obstacle.rect.x > SCREEN_WIDTH // 2:
                            reward = -1
                        if self.dino.dino_rect.colliderect(obstacle.rect):
                            self.dino.score = self.points
                            if obstacle in self.obstacles:
                                self.obstacles.remove(obstacle)
                            # Capture the score now: reset() zeroes self.points.
                            points_label = self.points
                            self.reset()
                            reward = -10
                            self.run = False
                            break

                    episode_reward += reward
                    # The stored flag must mean "episode ended". Storing
                    # self.run (True while alive) inverted the agent's
                    # bootstrapping logic.
                    done = not self.run
                    self.dino.update_replay_memory((current_state, action, reward, next_state, done))
                    self.dino.train(done, step=step)
                    current_state = next_state
                    step += 1
                    pygame.display.update()

                # Report the finished episode with its real score and reward.
                episode_info.text(f'Escenario: {episode}, Puntuación actual: {points_label}, Recompensa del episodio: {episode_reward}')

                # Compare against the score captured before reset(); reading
                # self.points here would always see 0.
                if points_label > self.best_score:
                    self.best_score = points_label
                    # This file is overwritten by each new best model.
                    self.best_model_filename = 'models/highscore/BestScore_model.pth'
                    os.makedirs(os.path.dirname(self.best_model_filename), exist_ok=True)
                    torch.save(self.dino.model.state_dict(), self.best_model_filename)

                self.ep_rewards.append(episode_reward)

                # Periodic checkpoint every 50 episodes.
                if episode % 50 == 0:
                    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                    filename = f'models/episodes/{points_label}_Points,Episode_{episode}_Date_{current_time}_model.pth'
                    os.makedirs(os.path.dirname(filename), exist_ok=True)
                    torch.save(self.dino.model.state_dict(), filename)

                # Decay exploration, never dropping below the configured floor.
                self.epsilon = max(MIN_EPSILON, self.epsilon * EPSILON_DECAY)
        finally:
            # Runs even when training is interrupted: keep a timestamped copy
            # of the best model reached so far.
            if self.best_model_filename is not None:
                current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                final_model_filename = f'models/highscore/{self.best_score}_BestScore_Final_{current_time}_model.pth'
                shutil.copy(self.best_model_filename, final_model_filename)
                print(f"Modelo duplicado guardado como: {final_model_filename}")
def plot_rewards(ep_rewards):
    """Render the per-episode reward curve in the Streamlit page.

    Args:
        ep_rewards: Sequence of episode rewards, plotted against its index.
    """
    # Use an explicit Figure instead of handing the pyplot module to
    # Streamlit, and close it afterwards so repeated runs do not pile up
    # open figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ep_rewards)
    ax.set_title("Recompensas por Episodio")
    ax.set_xlabel("Episodio")
    ax.set_ylabel("Recompensa")
    st.pyplot(fig)
    plt.close(fig)
# Streamlit UI
def streamlit_ui():
    """Build the Streamlit page: settings sidebar, model picker and start button.

    Pressing the button trains the agent with the sidebar settings and then
    plots the reward obtained in each episode.
    """
    # play_auto() reads these module-level settings when it runs, so the
    # sidebar values are applied by rebinding them before training starts.
    global NUM_EPISODES, EPSILON_DECAY

    st.title('Juego del Dinosaurio con IA')

    # Sidebar with the training settings.
    with st.sidebar:
        st.header("Configuraciones")
        epsilon_init = st.slider("Epsilon Inicial", 0.025, 0.975, EPSILON_INIT)
        epsilon_decay = st.slider("Epsilon Decay", 0.025, 0.975, EPSILON_DECAY)
        num_episodes = st.slider("Número de Episodios", 1, 500, NUM_EPISODES)

    # Model picker. Create the directory on first run so listing it cannot fail.
    model_directory = 'models/highscore/'
    os.makedirs(model_directory, exist_ok=True)
    model_files = sorted(os.listdir(model_directory))
    selected_model_file = st.selectbox('Elige un modelo para cargar', model_files)

    # Metric placeholders, filled in later via .empty().
    score_col, highscore_col = st.columns(2)
    with score_col:
        score = st.empty()
    with highscore_col:
        high_score = st.empty()

    # Placeholder for the per-episode status line.
    episode_info = st.empty()

    # Start button.
    if st.button('Iniciar Juego con IA'):
        NUM_EPISODES = num_episodes
        EPSILON_DECAY = epsilon_decay
        if selected_model_file is not None:
            model_path = os.path.join(model_directory, selected_model_file)
            game = Game(epsilon_init, load_model=True, model_path=model_path)
        else:
            # No saved model available yet: train from scratch.
            game = Game(epsilon_init)
        game.play_auto(episode_info)
        # Plot once play_auto has finished.
        if game.ep_rewards:
            plot_rewards(game.ep_rewards)


# Run the UI
streamlit_ui()