Vectorized environment support for off policy algorithms (SAC, TD3, D… · yonkshi/stable-baselines3@1579713
@@ -89,7 +89,7 @@ def __init__(
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
         device: Union[th.device, str] = "auto",
-        support_multi_env: bool = False,
+        support_multi_env: bool = True,
         create_eval_env: bool = False,
         monitor_wrapper: bool = True,
         seed: Optional[int] = None,
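Flipping the default of support_multi_env to True is what lets the off-policy classes accept a VecEnv with more than one environment; otherwise the base class refuses it. A paraphrased sketch of that kind of guard (illustrative names, not the exact BaseAlgorithm code):

# Paraphrased sketch of the guard that support_multi_env controls; the real
# check and error message live in BaseAlgorithm.__init__.
def check_multi_env_support(n_envs: int, support_multi_env: bool) -> None:
    if n_envs > 1 and not support_multi_env:
        raise ValueError("This algorithm supports only a single environment, but the VecEnv has n_envs > 1.")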
@@ -98,7 +98,6 @@ def __init__(
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
     ):
-
         super(OffPolicyAlgorithm, self).__init__(
             policy=policy,
             env=env,
@@ -125,6 +124,7 @@ def __init__(
         self.n_episodes_rollout = n_episodes_rollout
         self.action_noise = action_noise
         self.optimize_memory_usage = optimize_memory_usage
+        self.n_envs = env.num_envs

         if train_freq > 0 and n_episodes_rollout > 0:
             warnings.warn(
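self.n_envs caches the number of parallel environments from the wrapped VecEnv. A minimal usage sketch, assuming gym and stable_baselines3 are importable and "Pendulum-v0" is registered:

# Minimal sketch: a VecEnv exposes num_envs, which is the value cached as self.n_envs.
import gym
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("Pendulum-v0") for _ in range(4)])
print(vec_env.num_envs)  # 4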
@@ -152,6 +152,7 @@ def _setup_model(self) -> None:
             self.observation_space,
             self.action_space,
             self.device,
+            self.n_envs,
             optimize_memory_usage=self.optimize_memory_usage,
         )
         self.policy = self.policy_class(
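Forwarding self.n_envs into the replay buffer allows its storage to carry an explicit per-environment axis. A toy sketch of that layout (this is not the stable-baselines3 ReplayBuffer; class name, shapes, and fields are illustrative):

import numpy as np

class TinyReplayBuffer:
    """Toy buffer whose arrays have shape (buffer_size, n_envs, ...)."""

    def __init__(self, buffer_size: int, obs_dim: int, action_dim: int, n_envs: int = 1):
        self.observations = np.zeros((buffer_size, n_envs, obs_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
        self.rewards = np.zeros((buffer_size, n_envs), dtype=np.float32)
        self.dones = np.zeros((buffer_size, n_envs), dtype=np.float32)
        self.buffer_size = buffer_size
        self.pos = 0

    def add(self, obs, action, reward, done) -> None:
        # Every argument is expected to carry a leading n_envs dimension.
        self.observations[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        self.pos = (self.pos + 1) % self.buffer_size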
@@ -273,7 +274,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         raise NotImplementedError()

     def _sample_action(
-        self, learning_starts: int, action_noise: Optional[ActionNoise] = None
+        self, learning_starts: int, num_envs=1, action_noise: Optional[ActionNoise] = None
     ) -> Tuple[np.ndarray, np.ndarray]:
         """
         Sample an action according to the exploration policy.
@@ -292,7 +293,7 @@ def _sample_action(
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample()])
+            unscaled_action = np.array([ self.action_space.sample() for i in range(num_envs) ])
         else:
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
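During warmup the action is now sampled once per parallel environment, producing a batch with a leading num_envs dimension. A self-contained illustration (the Box space here is illustrative):

import numpy as np
import gym

action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
num_envs = 4
# One random action per env, stacked into shape (num_envs, action_dim).
unscaled_action = np.array([action_space.sample() for _ in range(num_envs)])
print(unscaled_action.shape)  # (4, 2)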
@@ -377,10 +378,10 @@ def collect_rollouts(
         total_steps, total_episodes = 0, 0

         assert isinstance(env, VecEnv), "You must pass a VecEnv"
-        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
+        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"

         if self.use_sde:
-            self.actor.reset_noise()
+            self.actor.reset_noise(self.n_envs)

         callback.on_rollout_start()
         continue_training = True
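reset_noise now receives self.n_envs so that one exploration noise matrix is drawn per parallel environment. A generic sketch of that idea (dimensions and function are illustrative, not the SB3 gSDE implementation):

import numpy as np

def reset_noise(n_envs: int, feature_dim: int = 8, action_dim: int = 2) -> np.ndarray:
    # One (feature_dim x action_dim) exploration noise matrix per environment.
    return np.random.randn(n_envs, feature_dim, action_dim)

noise = reset_noise(n_envs=4)
print(noise.shape)  # (4, 8, 2)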
@@ -393,13 +394,14 @@ def collect_rollouts(

                 if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                     # Sample a new noise matrix
-                    self.actor.reset_noise()
+                    self.actor.reset_noise(self.n_envs)

                 # Select action randomly or according to policy
-                action, buffer_action = self._sample_action(learning_starts, action_noise)
-
+                action, buffer_action = self._sample_action(learning_starts, self.n_envs, action_noise)
                 # Rescale and perform action
+
                 new_obs, reward, done, infos = env.step(action)
+                done = np.all(done) # done only when all threads are done

                 # Give access to local variables
                 callback.update_locals(locals())
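With a VecEnv, env.step returns one done flag per environment, while the episode bookkeeping below expects a single boolean; np.all collapses the array so the rollout counts as finished only when every worker has finished. A quick illustration:

import numpy as np

done = np.array([True, False, False, True])  # per-env flags from env.step(...)
print(np.all(done))                          # False: not every env has finished
print(np.all(np.full(4, True)))              # True: all workers are done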
@@ -429,9 +431,9 @@ def collect_rollouts(
                 if self._vec_normalize_env is not None:
                     self._last_original_obs = new_obs_

-                self.num_timesteps += 1
-                episode_timesteps += 1
-                total_steps += 1
+                self.num_timesteps += self.n_envs
+                episode_timesteps += self.n_envs
+                total_steps += self.n_envs
                 self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                 # For DQN, check if the target network should be updated
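Each call to env.step now advances n_envs environment transitions, so the counters are incremented by self.n_envs to keep num_timesteps (and thus learning_starts and total_timesteps) expressed in single-environment steps. A small arithmetic sketch:

# With 4 parallel envs, 250 vectorized steps yield 1000 environment transitions.
n_envs = 4
num_timesteps = 0
for _ in range(250):
    num_timesteps += n_envs
print(num_timesteps)  # 1000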
@@ -444,6 +446,7 @@ def collect_rollouts(
                     break

             if done:
+                print('Episode Complete', self._episode_num )
                 total_episodes += 1
                 self._episode_num += 1
                 episode_rewards.append(episode_reward)