qgallouedec HF staff commited on
Commit
9575e16
1 Parent(s): 96424ac

update demo

Browse files
texts/getting_my_agent_evaluated.md CHANGED
@@ -93,7 +93,7 @@ class Agent(nn.Module):
93
  agent = Agent(policy) # instantiate the agent
94
 
95
  # A few tests to check if the agent is working
96
- observations = torch.tensor(env.observation_space.sample()).unsqueeze(0) # dummy batch of observations
97
  actions = agent(observations)
98
  actions = actions.numpy()[0]
99
  assert env.action_space.contains(actions)
@@ -109,10 +109,9 @@ from huggingface_hub import metadata_save, HfApi
109
 
110
  # Save model along with its card
111
  metadata_save("model_card.md", {"tags": ["reinforcement-learning", env_id]})
112
- dummy_input = torch.tensor(env.observation_space.sample()).unsqueeze(0) # dummy batch of observations
113
  agent = torch.jit.trace(agent.eval(), dummy_input)
114
- agent = torch.jit.freeze(agent) # required for for the model not to depend on the training library
115
- agent = torch.jit.optimize_for_inference(agent)
116
  torch.jit.save(agent, "agent.pt")
117
 
118
  # Upload model and card to the 🤗 Hub
 
93
  agent = Agent(policy) # instantiate the agent
94
 
95
  # A few tests to check if the agent is working
96
+ observations = torch.randn(env.observation_space.shape).unsqueeze(0) # dummy batch of observations
97
  actions = agent(observations)
98
  actions = actions.numpy()[0]
99
  assert env.action_space.contains(actions)
 
109
 
110
  # Save model along with its card
111
  metadata_save("model_card.md", {"tags": ["reinforcement-learning", env_id]})
112
+ dummy_input = torch.randn(env.observation_space.shape).unsqueeze(0) # dummy batch of observations
113
  agent = torch.jit.trace(agent.eval(), dummy_input)
114
+ agent = torch.jit.freeze(agent) # required for the model not to depend on the training library
 
115
  torch.jit.save(agent, "agent.pt")
116
 
117
  # Upload model and card to the 🤗 Hub