# dino-ai / app.py
# Source-page header (file viewer residue): Wakka2905's picture,
# commit "Update app.py" (5bcde72), raw / history / blame, 25.7 kB
import torch
import streamlit as st
import pygame
import os
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from collections import deque
import random
from typing import List
from argparse import Action
import random
import sys
from sqlalchemy import asc
import math
import time
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt
# --- Window and scene geometry (pixels) ---
SCREEN_WIDTH, SCREEN_HEIGHT = 1100, 600
X_POS_BG_INIT = 0
Y_POS_BG = 380

# --- Game pacing ---
INIT_GAME_SPEED = 14

# --- Replay memory sizing ---
INIT_REPLAY_MEM_SIZE = 5000
REPLAY_MEMORY_SIZE = 45000
MIN_REPLAY_MEMORY_SIZE = 1000
MINIBATCH_SIZE = 64

# --- Q-learning hyper-parameters ---
MODEL_NAME = "DINO"
DISCOUNT = 0.95
UPDATE_TARGET_THRESH = 5
NUM_EPISODES = 100

# --- Epsilon-greedy exploration schedule ---
# Starting epsilon; lowered from the earlier value of 0.45 so the agent
# explores less (a smaller epsilon means fewer random actions).
EPSILON_INIT = 0.25
# Per-episode multiplier applied to epsilon; lowered from the earlier value
# of 0.997 so exploration fades out much faster.
EPSILON_DECAY = 0.75
MIN_EPSILON = 0.05
def _load_asset(folder, filename):
    """Load ``Assets/<folder>/<filename>`` from disk as a pygame Surface."""
    return pygame.image.load(os.path.join("Assets/" + folder, filename))


# Sprites are read from disk once, at import time. (This section used to be
# repeated verbatim, so every image was loaded twice and the first copy was
# immediately discarded.)
RUNNING = [_load_asset("Dino", f"DinoRun{i}.png") for i in (1, 2)]
DUCKING = [_load_asset("Dino", f"DinoDuck{i}.png") for i in (1, 2)]
JUMPING = _load_asset("Dino", "DinoJump.png")

SMALL_CACTUS = [_load_asset("Cactus", f"SmallCactus{i}.png") for i in (1, 2, 3)]
LARGE_CACTUS = [_load_asset("Cactus", f"LargeCactus{i}.png") for i in (1, 2, 3)]

BIRD = [_load_asset("Bird", f"Bird{i}.png") for i in (1, 2)]

CLOUD = _load_asset("Other", "Cloud.png")
BACKGROUND = _load_asset("Other", "Track.png")
class NeuralNetwork(nn.Module):
    """Small fully connected network: 7 inputs -> 4 hidden units -> 3 outputs.

    Elsewhere in this file the input is the 7-value game state and the three
    outputs are read as one score per action (the caller takes the argmax).
    """

    def __init__(self):
        super().__init__()
        # The attribute names become the keys of the saved state_dict, so
        # they must stay ``fc1`` / ``fc2`` for existing checkpoints to load.
        self.fc1 = nn.Linear(7, 4)
        self.fc2 = nn.Linear(4, 3)

    def forward(self, x):
        """Return the raw (un-activated) outputs of the second layer for *x*."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)
# Run the networks on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class DQNAgent:
    """Deep Q-learning agent: online network, target network and replay memory.

    Transitions are stored as ``(state, action, reward, next_state, done)``
    tuples. The first ``INIT_REPLAY_MEM_SIZE`` transitions ever seen go into
    ``init_replay_memory`` and stay there; everything after that goes into
    the rolling ``late_replay_memory``.
    """

    def __init__(self):
        # Both networks live on ``device`` (GPU when available).
        self.model = NeuralNetwork().to(device)
        self.target_model = NeuralNetwork().to(device)
        # The target network starts as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_function = nn.MSELoss()

        self.init_replay_memory = deque(maxlen=INIT_REPLAY_MEM_SIZE)
        self.late_replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)

        # Number of terminal training steps since the last target sync.
        self.target_update_counter = 0

    def update_replay_memory(self, transition):
        """Store one transition tuple in the appropriate replay buffer."""
        if len(self.init_replay_memory) < INIT_REPLAY_MEM_SIZE:
            self.init_replay_memory.append(transition)
        else:
            self.late_replay_memory.append(transition)

    def get_qs(self, state):
        """Return the online network's outputs for *state* as a numpy array.

        The forward pass runs without gradient tracking and the result is
        moved back to the CPU before conversion.
        """
        state_tensor = torch.Tensor(state).to(device)
        with torch.no_grad():
            return self.model(state_tensor).cpu().numpy()

    def train(self, terminal_state, step):
        """Run one optimisation step on a random minibatch of transitions.

        Does nothing until ``init_replay_memory`` holds at least
        ``MIN_REPLAY_MEMORY_SIZE`` transitions.

        Args:
            terminal_state: truthy when the step that triggered this call
                ended an episode; only such calls advance the target-sync
                counter.
            step: step index within the episode (accepted for interface
                compatibility; not used).
        """
        if len(self.init_replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        replay_pool = list(self.init_replay_memory)
        replay_pool.extend(self.late_replay_memory)
        minibatch = random.sample(replay_pool, MINIBATCH_SIZE)

        states = torch.tensor(
            np.array([transition[0] for transition in minibatch], dtype=np.float32)
        ).to(device)
        next_states = torch.tensor(
            np.array([transition[3] for transition in minibatch], dtype=np.float32)
        ).to(device)

        # Targets are constants for the loss, so compute them without
        # building an autograd graph (previously both forward passes were
        # tracked and the graph output was overwritten in place).
        with torch.no_grad():
            targets = self.model(states).clone()
            max_future_qs = torch.max(self.target_model(next_states), dim=1)[0]

        # NOTE(review): the fifth field is interpreted as "episode ended".
        # Confirm the caller does not store a "still running" flag here,
        # which would invert the bootstrapping below.
        for row, (_, action, reward, _, done) in enumerate(minibatch):
            if done:
                targets[row, action] = reward
            else:
                targets[row, action] = reward + DISCOUNT * max_future_qs[row]

        self.optimizer.zero_grad()
        predictions = self.model(states)
        loss = self.loss_function(predictions, targets)
        loss.backward()
        self.optimizer.step()

        # Sync the target network after more than UPDATE_TARGET_THRESH
        # terminal training steps.
        if terminal_state:
            self.target_update_counter += 1
        if self.target_update_counter > UPDATE_TARGET_THRESH:
            self.target_model.load_state_dict(self.model.state_dict())
            self.target_update_counter = 0
class Obstacle:
    """Base class for obstacles that enter at the right edge and scroll left.

    Attributes:
        image: list of sprite surfaces for this obstacle family.
        type: index into ``image`` selecting the sprite that is drawn.
        rect: bounding rectangle of the selected sprite, used for movement
            and collision checks.
    """

    def __init__(self, image: "List[pygame.Surface]", type: int) -> None:
        # ``type`` shadows the builtin; the name is kept because it is part
        # of the constructor's existing signature.
        self.image = image
        self.type = type
        self.rect = self.image[self.type].get_rect()
        self.rect.x = SCREEN_WIDTH

    def update(self, obstacles: list, game_speed: int):
        """Move left by *game_speed* and leave *obstacles* once off-screen.

        The obstacle removes *itself* from the list. The previous
        ``obstacles.pop()`` removed whatever element happened to be last,
        which is only correct while the list holds a single obstacle.
        """
        self.rect.x -= game_speed
        if self.rect.x < -self.rect.width and self in obstacles:
            obstacles.remove(self)

    def draw(self, SCREEN: "pygame.Surface"):
        """Blit the selected sprite onto *SCREEN* at the current position."""
        SCREEN.blit(self.image[self.type], self.rect)
class Dino(DQNAgent):
    """The player sprite, which is also the learning agent.

    Inherits the networks and replay memory from ``DQNAgent`` and adds the
    run / duck / jump state machine plus the sprite rectangle.
    """

    X_POS = 80
    Y_POS = 310
    Y_DUCK_POS = 340
    JUMP_VEL = 8.5

    def __init__(self) -> None:
        super().__init__()

        # Sprite sets for each pose.
        self.duck_img = DUCKING
        self.run_img = RUNNING
        self.jump_img = JUMPING

        # The dino starts out running.
        self.dino_duck = False
        self.dino_run = True
        self.dino_jump = False

        self.step_index = 0
        self.jump_vel = self.JUMP_VEL

        self.image = self.run_img[0]
        self.dino_rect = self.image.get_rect()
        self.dino_rect.x = self.X_POS
        self.dino_rect.y = self.Y_POS

        self.score = 0

    def _advance_animation(self) -> None:
        """Run the handler of every active pose, then wrap the frame counter."""
        # Order matters: a jump that finishes switches to running, and the
        # run handler then executes within the same frame.
        if self.dino_duck:
            self.duck()
        if self.dino_jump:
            self.jump()
        if self.dino_run:
            self.run()
        if self.step_index >= 20:
            self.step_index = 0

    def _apply_command(self, wants_jump, wants_duck) -> None:
        """Switch pose according to the requested command.

        Commands are ignored while a jump is in progress.
        """
        if self.dino_jump:
            return
        if wants_jump:
            self.dino_jump, self.dino_run, self.dino_duck = True, False, False
        elif wants_duck:
            self.dino_duck, self.dino_run, self.dino_jump = True, False, False
        else:
            self.dino_run, self.dino_jump, self.dino_duck = True, False, False

    def update(self, move: pygame.key.ScancodeWrapper):
        """Advance one frame using keyboard state (UP jumps, DOWN ducks)."""
        self._advance_animation()
        self._apply_command(move[pygame.K_UP], move[pygame.K_DOWN])

    def update_auto(self, move):
        """Advance one frame using an action index.

        ``0`` requests a jump, ``1`` requests a duck, any other value runs.
        """
        self._advance_animation()
        self._apply_command(move == 0, move == 1)

    def duck(self) -> None:
        """Show the current ducking frame at the ducking height."""
        self.image = self.duck_img[self.step_index // 10]
        rect = self.image.get_rect()
        rect.x = self.X_POS
        rect.y = self.Y_DUCK_POS
        self.dino_rect = rect
        self.step_index += 1

    def run(self) -> None:
        """Show the current running frame at the normal height."""
        self.image = self.run_img[self.step_index // 10]
        rect = self.image.get_rect()
        rect.x = self.X_POS
        rect.y = self.Y_POS
        self.dino_rect = rect
        self.step_index += 1

    def jump(self) -> None:
        """Advance the jump arc by one frame and end it when it completes."""
        self.image = self.jump_img
        if self.dino_jump:
            self.dino_rect.y -= self.jump_vel * 3
            self.jump_vel -= 0.6
        # The arc is over once the velocity has swung past the negative of
        # the launch velocity; go back to running and re-arm the jump.
        if self.jump_vel < -self.JUMP_VEL:
            self.dino_jump = False
            self.dino_run = True
            self.jump_vel = self.JUMP_VEL

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current sprite onto *SCREEN*."""
        position = (self.dino_rect.x, self.dino_rect.y)
        SCREEN.blit(self.image, position)
class LargeCactus(Obstacle):
    """Tall cactus obstacle drawn with one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        # Obstacle.__init__ stores the variant as ``self.type``.
        super().__init__(image, variant)
        self.rect.y = 300
class SmallCactus(Obstacle):
    """Short cactus obstacle drawn with one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        # Obstacle.__init__ stores the variant as ``self.type``.
        super().__init__(image, variant)
        self.rect.y = 325
class Bird(Obstacle):
    """Flying obstacle that alternates between two wing frames while drawn."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        # Obstacle.__init__ stores 0 as ``self.type`` (used for the rect).
        super().__init__(image, 0)
        self.rect.y = SCREEN_HEIGHT - 340
        self.index = 0  # animation frame counter

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current wing frame and advance the animation counter."""
        # Wrap the counter so ``index // 10`` stays within the two frames.
        self.index = 0 if self.index >= 19 else self.index
        frame = self.image[self.index // 10]
        SCREEN.blit(frame, self.rect)
        self.index += 1
class Cloud:
    """Decorative cloud that drifts left and respawns beyond the right edge."""

    def __init__(self) -> None:
        self.image = CLOUD
        self.width = self.image.get_width()
        # Start well off-screen to the right, at a random height.
        self.x = SCREEN_WIDTH + random.randint(800, 1000)
        self.y = random.randint(50, 100)

    def update(self, game_speed: int):
        """Move left by *game_speed*; respawn once fully off-screen."""
        self.x -= game_speed
        if self.x >= -self.width:
            return
        self.x = SCREEN_WIDTH + random.randint(800, 1000)
        self.y = random.randint(50, 100)

    def draw(self, SCREEN: pygame.Surface):
        """Blit the cloud onto *SCREEN* at its current position."""
        position = (self.x, self.y)
        SCREEN.blit(self.image, position)
class Game:
    """The dino runner game plus the DQN training loop that drives it.

    Action indices used throughout: ``0`` = jump, ``1`` = duck, anything
    else (``2``) = keep running.
    """

    def __init__(self, epsilon, load_model=False, model_path=None):
        """Set up pygame, the scene objects and the agent.

        Args:
            epsilon: initial probability of taking a random action.
            load_model: when true (and *model_path* is given) the agent's
                weights are loaded from *model_path*.
            model_path: path of a ``state_dict`` checkpoint.
        """
        # Use SDL's dummy video driver so pygame does not need a display.
        os.environ["SDL_VIDEODRIVER"] = "dummy"
        pygame.init()
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.obstacles = []
        self.run = True
        self.clock = pygame.time.Clock()
        self.cloud = Cloud()
        self.game_speed = INIT_GAME_SPEED
        self.font = pygame.font.Font("freesansbold.ttf", 20)
        self.dino = Dino()

        if load_model and model_path:
            self.dino.model.load_state_dict(torch.load(model_path, map_location=device))
            # Keep the target network in sync with the loaded weights;
            # otherwise it would keep its random initial weights until the
            # first scheduled sync.
            self.dino.target_model.load_state_dict(self.dino.model.state_dict())

        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.epsilon = epsilon
        self.ep_rewards = [-200]
        self.high_score = 0  # best score shown on screen
        self.best_score = 0  # best score for which a model was saved

    def reset(self):
        """Start a fresh run while keeping everything the agent has learned."""
        self.game_speed = INIT_GAME_SPEED
        previous = self.dino
        self.dino = Dino()
        self.dino.init_replay_memory = previous.init_replay_memory
        self.dino.late_replay_memory = previous.late_replay_memory
        self.dino.target_update_counter = previous.target_update_counter
        # Reuse the trained networks together with their optimizer. Copying
        # only the weights into the new Dino's networks would throw away the
        # Adam optimizer state after every episode.
        self.dino.model = previous.model
        self.dino.target_model = previous.target_model
        self.dino.optimizer = previous.optimizer
        self.dino.loss_function = previous.loss_function
        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.clock = pygame.time.Clock()

    def get_dist(self, pos_a: tuple, pos_b: tuple):
        """Return the Euclidean distance between two ``(x, y)`` points."""
        dx = pos_a[0] - pos_b[0]
        dy = pos_a[1] - pos_b[1]
        return math.sqrt(dx**2 + dy**2)

    def update_background(self):
        """Draw the scrolling track and return the new background offset."""
        image_width = BACKGROUND.get_width()
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg, Y_POS_BG))
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))
        if self.x_pos_bg <= -image_width:
            self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))
            self.x_pos_bg = 0
        self.x_pos_bg -= self.game_speed
        return self.x_pos_bg

    def get_state(self):
        """Return the 7-value state fed to the network.

        In order: dino height feature, distance to the next obstacle
        (divided by the screen diagonal), obstacle top ``y`` (divided by
        ``Y_DUCK_POS``), game speed / 24, obstacle width relative to the
        third small-cactus sprite, cactus flag, bird flag.
        """
        state = []
        # NOTE(review): this evaluates as (y / Y_DUCK_POS) + 10; it may have
        # been meant as y / (Y_DUCK_POS + 10). Left unchanged because saved
        # models were trained on this exact feature.
        state.append(self.dino.dino_rect.y / self.dino.Y_DUCK_POS + 10)

        dino_pos = (self.dino.dino_rect.x, self.dino.dino_rect.y)
        screen_diagonal = math.sqrt(SCREEN_HEIGHT**2 + SCREEN_WIDTH**2)
        bird = 0
        cactus = 0
        if len(self.obstacles) == 0:
            # No obstacle yet: pretend one sits just beyond the right edge.
            dist = self.get_dist(dino_pos, (SCREEN_WIDTH + 10, self.dino.Y_POS)) / screen_diagonal
            obs_height = 0
            obj_width = 0
        else:
            nearest = self.obstacles[0]
            dist = self.get_dist(dino_pos, nearest.rect.midtop) / screen_diagonal
            obs_height = nearest.rect.midtop[1] / self.dino.Y_DUCK_POS
            obj_width = nearest.rect.width / SMALL_CACTUS[2].get_rect().width
            # isinstance avoids constructing throwaway obstacles (and
            # consuming random numbers) just to compare classes.
            if isinstance(nearest, (SmallCactus, LargeCactus)):
                cactus = 1
            else:
                bird = 1

        state.append(dist)
        state.append(obs_height)
        state.append(self.game_speed / 24)
        state.append(obj_width)
        state.append(cactus)
        state.append(bird)
        return state

    def update_score(self):
        """Add one point, speed up every 200 points and draw the score text."""
        self.points += 1
        if self.points % 200 == 0:
            self.game_speed += 1
        if self.points > self.high_score:
            self.high_score = self.points
        text = self.font.render(f"Points: {self.points} Highscore: {self.high_score}", True, (0, 0, 0))
        textRect = text.get_rect()
        textRect.center = (SCREEN_WIDTH - textRect.width // 2 - 10, 40)
        self.SCREEN.blit(text, textRect)

    def create_obstacle(self):
        """Randomly spawn at most one obstacle (birds only after 300 points)."""
        obstacle_prob = random.randint(0, 50)
        if obstacle_prob == 0:
            self.obstacles.append(SmallCactus(SMALL_CACTUS))
        elif obstacle_prob == 1:
            self.obstacles.append(LargeCactus(LARGE_CACTUS))
        elif obstacle_prob == 2 and self.points > 300:
            self.obstacles.append(Bird(BIRD))

    def update_game(self, moves, user_input=None):
        """Advance the dino, background, cloud and score by one frame.

        Args:
            moves: action index used when no keyboard input is supplied.
            user_input: pressed-key state; when given it takes precedence
                over *moves*.
        """
        self.dino.draw(self.SCREEN)
        if user_input is not None:
            self.dino.update(user_input)
        else:
            self.dino.update_auto(moves)
        self.update_background()
        self.cloud.draw(self.SCREEN)
        self.cloud.update(self.game_speed)
        self.update_score()
        self.clock.tick(30)

    def play_manual(self):
        """Run the keyboard-controlled game until the dino hits an obstacle."""
        while self.run is True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()
            self.SCREEN.fill((255, 255, 255))
            user_input = pygame.key.get_pressed()
            if len(self.obstacles) == 0:
                self.create_obstacle()
            # Iterate over a copy: obstacles may remove themselves in update().
            for obstacle in list(self.obstacles):
                obstacle.draw(SCREEN=self.SCREEN)
                obstacle.update(self.obstacles, self.game_speed)
                if self.dino.dino_rect.colliderect(obstacle.rect):
                    self.dino.score = self.points
                    pygame.quit()
                    if obstacle in self.obstacles:
                        self.obstacles.remove(obstacle)
                    print("Game over!")
                    return
            self.update_game(user_input=user_input, moves=2)
            pygame.display.update()

    def _choose_action(self, state):
        """Pick an action epsilon-greedily for *state*.

        Random picks are biased: 1 in 10 jumps, 3 in 10 duck, 6 in 10 run.
        """
        if np.random.random() > self.epsilon:
            return int(np.argmax(self.dino.get_qs(torch.Tensor(state))))
        roll = np.random.randint(0, 10)
        if roll == 0:
            return 0
        if roll <= 3:
            return 1
        return 2

    def _save_model(self, path):
        """Save the online network's weights to *path*, creating its folder."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.dino.model.state_dict(), path)

    def _decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below the floor.

        The previous logic dropped epsilon to 0 for one episode when it
        crossed ``MIN_EPSILON`` before settling on the floor.
        """
        self.epsilon = max(MIN_EPSILON, self.epsilon * EPSILON_DECAY)

    def play_auto(self, episode_info, metrics_container, logs_container):
        """Train the agent for ``NUM_EPISODES`` episodes, reporting to Streamlit.

        Args:
            episode_info: Streamlit placeholder updated with the live score.
            metrics_container: container receiving the reward plot every
                10 episodes.
            logs_container: container receiving one summary line per episode.

        The best-scoring model is saved as it is found; on exit (normal or
        not) that checkpoint is duplicated under a timestamped name.
        """
        try:
            points_label = 0
            self.ep_rewards = []
            for episode in range(1, NUM_EPISODES + 1):
                self.points = 0
                episode_reward = 0
                step = 1
                current_state = self.get_state()
                self.run = True

                while self.run is True:
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            sys.exit()
                    self.SCREEN.fill((255, 255, 255))
                    if len(self.obstacles) == 0:
                        self.create_obstacle()

                    action = self._choose_action(current_state)
                    self.update_game(moves=action)
                    next_state = self.get_state()

                    reward = 0
                    # Iterate over a copy: obstacles may remove themselves.
                    for obstacle in list(self.obstacles):
                        obstacle.draw(SCREEN=self.SCREEN)
                        obstacle.update(self.obstacles, self.game_speed)
                        next_state = self.get_state()
                        # Obstacle is behind the dino: it was cleared.
                        if self.dino.dino_rect.x > obstacle.rect.x + obstacle.rect.width:
                            reward = 3
                        # Discourage jumping while the obstacle is still far away.
                        if action == 0 and obstacle.rect.x > SCREEN_WIDTH // 2:
                            reward = -1
                        if self.dino.dino_rect.colliderect(obstacle.rect):
                            self.dino.score = self.points
                            if obstacle in self.obstacles:
                                self.obstacles.remove(obstacle)
                            points_label = self.points
                            reward = -10
                            self.run = False
                            break

                    episode_reward += reward
                    # The agent reads the fifth field as "episode ended", so
                    # store the negation of the running flag (storing
                    # ``self.run`` itself inverted the Q-target bootstrapping).
                    done = not self.run
                    self.dino.update_replay_memory((current_state, action, reward, next_state, done))
                    self.dino.train(done, step=step)
                    current_state = next_state
                    step += 1

                    # Live progress in the Streamlit UI.
                    episode_info.text(f'Episodio: {episode}, Puntuación actual: {self.points}, Recompensa del episodio: {episode_reward}')

                # End of episode: keep the weights of the best run so far.
                # This file is overwritten by each new best model.
                if self.points > self.best_score:
                    self.best_score = self.points
                    self.best_model_filename = 'models/highscore/BestScore_model.pth'
                    self._save_model(self.best_model_filename)

                pygame.display.update()
                self.ep_rewards.append(episode_reward)

                if episode % 10 == 0:
                    with metrics_container:
                        plot_rewards(self.ep_rewards)
                with logs_container:
                    st.text(f"Resumen del Episodio {episode}: Puntuación final {points_label}, Recompensa total {episode_reward}")

                self.reset()

                # Periodic checkpoint every 50 episodes.
                if episode % 50 == 0:
                    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                    filename = f'models/episodes/{points_label}_Points,Episode_{episode}_Date_{current_time}_model.pth'
                    self._save_model(filename)

                self._decay_epsilon()
        finally:
            # Runs even when training is interrupted: duplicate the best
            # checkpoint reached so far under a timestamped name.
            best_path = getattr(self, 'best_model_filename', None)
            if best_path and os.path.exists(best_path):
                import shutil

                current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                final_model_filename = f'models/highscore/{self.best_score}_BestScore_Final_{current_time}_model.pth'
                shutil.copy(best_path, final_model_filename)
                print(f"Modelo duplicado guardado como: {final_model_filename}")
def plot_rewards(ep_rewards):
    """Render the per-episode reward curve into the current Streamlit context.

    Args:
        ep_rewards: sequence of total rewards, one entry per episode.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ep_rewards)
    ax.set_title("Recompensas por Episodio")
    ax.set_xlabel("Episodio")
    ax.set_ylabel("Recompensa")
    # Hand Streamlit the explicit figure rather than the pyplot module, then
    # close it: this function is called repeatedly during training and open
    # figures would otherwise accumulate.
    st.pyplot(fig)
    plt.close(fig)
# Streamlit UI
def streamlit_ui():
    """Build the Streamlit page and start AI training when the button is pressed."""
    st.title('Juego del Dinosaurio con IA')

    # Sidebar with the training settings.
    with st.sidebar:
        st.header("Configuraciones")
        epsilon_init = st.slider("Epsilon Inicial", 0.025, 0.975, EPSILON_INIT)
        # NOTE(review): these two values are collected but not consumed;
        # Game.play_auto reads the module-level EPSILON_DECAY and
        # NUM_EPISODES instead.
        epsilon_decay = st.slider("Epsilon Decay", 0.025, 0.975, EPSILON_DECAY)
        num_episodes = st.slider("Número de Episodios", 1, 500, NUM_EPISODES)

    # Model selection. Create the folder on first run so listing it cannot
    # fail; sort the names so the dropdown order is stable.
    model_directory = 'models/highscore/'
    os.makedirs(model_directory, exist_ok=True)
    model_files = sorted(os.listdir(model_directory))
    selected_model_file = st.selectbox('Elige un modelo para cargar', model_files)

    # Metric placeholders (filled in later via .empty()).
    score_col, highscore_col = st.columns(2)
    with score_col:
        score = st.empty()
    with highscore_col:
        high_score = st.empty()

    # Containers for the reward plot and the per-episode log lines.
    metrics_container = st.container()
    logs_container = st.container()
    # Placeholder showing the current episode number and score.
    episode_info = st.empty()

    if st.button('Iniciar Juego con IA'):
        if selected_model_file:
            model_path = os.path.join(model_directory, selected_model_file)
            # Use the slider value; previously the constant EPSILON_INIT was
            # passed and the slider had no effect.
            game = Game(epsilon_init, load_model=True, model_path=model_path)
        else:
            # No saved model available yet: start from untrained weights.
            game = Game(epsilon_init)
        game.play_auto(episode_info, metrics_container, logs_container)


# Run the UI
streamlit_ui()