File size: 1,015 Bytes
894cec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np

from gym.vector.sync_vector_env import SyncVectorEnv
from stable_baselines3.common.vec_env.base_vec_env import tile_images
from typing import Optional

from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecotarableWrapper,
)


class SyncVectorEnvRenderCompat(VecotarableWrapper):
    """Wrapper that provides ``render`` for environments backed by ``SyncVectorEnv``.

    When the unwrapped environment is a ``SyncVectorEnv``, each sub-environment
    is rendered individually and the frames are combined into a single image
    with ``tile_images``. For any other environment, rendering is delegated to
    the wrapped env unchanged.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the wrapped environment.

        Args:
            mode: ``"human"`` shows the tiled image in an OpenCV window named
                ``"vecenv"``; ``"rgb_array"`` returns the tiled image. For
                environments that are not a ``SyncVectorEnv`` the mode is
                passed through to the wrapped env's own ``render``.

        Returns:
            The tiled image for ``"rgb_array"``, ``None`` for ``"human"``, or
            whatever the wrapped env's ``render`` returns when delegating.

        Raises:
            NotImplementedError: If the base env is a ``SyncVectorEnv`` and
                ``mode`` is neither ``"human"`` nor ``"rgb_array"``.
        """
        base_env = self.env.unwrapped
        if not isinstance(base_env, SyncVectorEnv):
            return self.env.render(mode=mode)

        # Reject unsupported modes before doing any rendering work.
        # NotImplementedError is the exception class; the NotImplemented
        # constant is not callable and would surface as a TypeError.
        if mode not in ("human", "rgb_array"):
            raise NotImplementedError(f"Render mode {mode} is not supported")

        # Sub-envs are always asked for frames, whatever the outer mode is,
        # so they can be tiled into one image.
        imgs = [env.render(mode="rgb_array") for env in base_env.envs]
        bigimg = tile_images(imgs)
        if mode == "rgb_array":
            return bigimg

        # Imported lazily so OpenCV is only required for on-screen display.
        import cv2

        # Reverse the last axis before display; the frames were requested as
        # "rgb_array" and cv2.imshow presumably expects BGR channel order.
        cv2.imshow("vecenv", bigimg[:, :, ::-1])
        cv2.waitKey(1)
        return None