Vectorized environment support for off policy algorithms (SAC, TD3, D… · yonkshi/stable-baselines3@1579713

```diff
@@ -89,7 +89,7 @@ def __init__(
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
         device: Union[th.device, str] = "auto",
-        support_multi_env: bool = False,
+        support_multi_env: bool = True,
         create_eval_env: bool = False,
         monitor_wrapper: bool = True,
         seed: Optional[int] = None,
```
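Note: flipping `support_multi_env` to default to `True` is what lets the off-policy base class accept a `VecEnv` with more than one environment in the first place. A minimal usage sketch of what this branch is aiming to allow (the algorithm choice, environment id, and step count below are illustrative assumptions, not part of the commit):

```python
import gym

from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

# Four copies of the same environment, stepped together as one VecEnv.
vec_env = DummyVecEnv([lambda: gym.make("Pendulum-v0") for _ in range(4)])

# With support_multi_env=True, the OffPolicyAlgorithm constructor no longer
# rejects a multi-env VecEnv for an off-policy algorithm such as SAC.
model = SAC("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10_000)
```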

```diff
@@ -98,7 +98,6 @@ def __init__(
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
     ):
-
         super(OffPolicyAlgorithm, self).__init__(
             policy=policy,
             env=env,
```

```diff
@@ -125,6 +124,7 @@ def __init__(
         self.n_episodes_rollout = n_episodes_rollout
         self.action_noise = action_noise
         self.optimize_memory_usage = optimize_memory_usage
+        self.n_envs = env.num_envs
 
         if train_freq > 0 and n_episodes_rollout > 0:
             warnings.warn(
```

```diff
@@ -152,6 +152,7 @@ def _setup_model(self) -> None:
             self.observation_space,
             self.action_space,
             self.device,
+            self.n_envs,
             optimize_memory_usage=self.optimize_memory_usage,
         )
         self.policy = self.policy_class(
```
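The replay buffer is now constructed with `self.n_envs` passed after `self.device`, so the buffer is expected to accept an `n_envs` argument and keep one storage slot per environment at each buffer position. A simplified sketch of that layout, using a hypothetical `TinyVecReplayBuffer` rather than the repo's `ReplayBuffer`:

```python
import numpy as np


class TinyVecReplayBuffer:
    """Stores transitions with an explicit environment dimension."""

    def __init__(self, buffer_size: int, obs_dim: int, action_dim: int, n_envs: int = 1):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        self.pos = 0
        self.full = False
        # One slot per (step, env) pair.
        self.observations = np.zeros((buffer_size, n_envs, obs_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
        self.rewards = np.zeros((buffer_size, n_envs), dtype=np.float32)
        self.dones = np.zeros((buffer_size, n_envs), dtype=np.float32)

    def add(self, obs, action, reward, done):
        # obs: (n_envs, obs_dim), action: (n_envs, action_dim), reward/done: (n_envs,)
        self.observations[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        self.pos = (self.pos + 1) % self.buffer_size
        if self.pos == 0:
            self.full = True
```

Each `add()` call then stores one transition per environment, which matches the batched arrays a `VecEnv.step()` call returns.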

```diff
@@ -273,7 +274,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         raise NotImplementedError()
 
     def _sample_action(
-        self, learning_starts: int, action_noise: Optional[ActionNoise] = None
+        self, learning_starts: int, num_envs=1, action_noise: Optional[ActionNoise] = None
     ) -> Tuple[np.ndarray, np.ndarray]:
         """
         Sample an action according to the exploration policy.
```

```diff
@@ -292,7 +293,7 @@ def _sample_action(
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample()])
+            unscaled_action = np.array([ self.action_space.sample() for i in range(num_envs) ])
         else:
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
```
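During the warmup phase, the random action is now drawn once per environment, so `_sample_action` returns an array of shape `(num_envs, action_dim)` that matches what a `VecEnv.step()` call expects. A standalone illustration of the shape change (the `Box` space below is an assumption for the example):

```python
import numpy as np
import gym

action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
num_envs = 4

# Previous behaviour: one sample, wrapped in a batch of size 1.
single = np.array([action_space.sample()])
assert single.shape == (1, 2)

# New behaviour: one independent sample per environment.
batched = np.array([action_space.sample() for _ in range(num_envs)])
assert batched.shape == (num_envs, 2)
```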

```diff
@@ -377,10 +378,10 @@ def collect_rollouts(
         total_steps, total_episodes = 0, 0
 
         assert isinstance(env, VecEnv), "You must pass a VecEnv"
-        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
+        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
 
         if self.use_sde:
-            self.actor.reset_noise()
+            self.actor.reset_noise(self.n_envs)
 
         callback.on_rollout_start()
         continue_training = True
```
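`reset_noise` now receives `self.n_envs` so that, when gSDE is enabled, a separate exploration matrix can be sampled for each parallel environment (the SAC actor's `reset_noise` already takes a batch size for this purpose). Conceptually it amounts to something like the sketch below, which is not the gSDE code from this repo:

```python
import torch as th


def reset_noise(latent_dim: int, action_dim: int, n_envs: int = 1) -> th.Tensor:
    # One independent exploration matrix per environment, so parallel
    # rollouts do not all reuse the same state-dependent noise sample.
    return th.randn(n_envs, latent_dim, action_dim)


exploration_matrices = reset_noise(latent_dim=64, action_dim=2, n_envs=4)
assert exploration_matrices.shape == (4, 64, 2)
```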

```diff
@@ -393,13 +394,14 @@ def collect_rollouts(
 
                 if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                     # Sample a new noise matrix
-                    self.actor.reset_noise()
+                    self.actor.reset_noise(self.n_envs)
 
                 # Select action randomly or according to policy
-                action, buffer_action = self._sample_action(learning_starts, action_noise)
-
+                action, buffer_action = self._sample_action(learning_starts, self.n_envs, action_noise)
                 # Rescale and perform action
+
                 new_obs, reward, done, infos = env.step(action)
+                done = np.all(done)  # done only when all threads are done
 
                 # Give access to local variables
                 callback.update_locals(locals())
```
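With a multi-env `VecEnv`, `env.step()` returns one `done` flag per environment, while the surrounding rollout loop still works with a single boolean. The added `np.all(done)` collapses the vector, so the inner loop only ends once every environment has reported a terminal step. A small illustration of that aggregation and of the alternative `np.any` reading:

```python
import numpy as np

# done flags returned by a 4-env VecEnv for one vectorized step
dones = np.array([True, False, False, True])

np.all(dones)  # False -> keep collecting until every env has finished
np.any(dones)  # True  -> would stop as soon as a single env finishes
```

Environments that finish earlier are auto-reset by the VecEnv and keep producing transitions while the loop waits for the remaining ones.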

```diff
@@ -429,9 +431,9 @@ def collect_rollouts(
                 if self._vec_normalize_env is not None:
                     self.last_original_obs = new_obs
 
-                self.num_timesteps += 1
-                episode_timesteps += 1
-                total_steps += 1
+                self.num_timesteps += self.n_envs
+                episode_timesteps += self.n_envs
+                total_steps += self.n_envs
                 self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
 
                 # For DQN, check if the target network should be updated
```
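One pass through the inner loop now advances `n_envs` environments at once, so the counters are bumped by `n_envs` to keep `num_timesteps` measured in individual environment frames; that keeps `learning_starts`, `total_timesteps`, and the progress-remaining schedule on the same scale as before. The bookkeeping in isolation:

```python
n_envs = 4
loop_iterations = 250

num_timesteps = 0
for _ in range(loop_iterations):
    # one vectorized env.step() == n_envs single-environment frames
    num_timesteps += n_envs

assert num_timesteps == 1000  # 4 envs x 250 vectorized steps
```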

```diff
@@ -444,6 +446,7 @@ def collect_rollouts(
                     break
 
             if done:
+                print('Episode Complete', self._episode_num )
                 total_episodes += 1
                 self._episode_num += 1
                 episode_rewards.append(episode_reward)
```