@@ -80,8 +80,12 @@ def __init__(
                 eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
+        else:
+            self.model.to(self._device)
+            self._rank = 0
+            self._word_size = 1
 
-        if accelerator.num_processes > 1:
+        '''if accelerator.num_processes > 1:
             assert accelerator.distributed_type in [
                 DistributedType.FSDP,
                 DistributedType.MULTI_GPU,
@@ -94,11 +98,7 @@ def __init__(
             if self.accelerator.is_local_main_process:
                 eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
             self._rank = self.accelerator.local_process_index
-            self._world_size = self.accelerator.num_processes
-        else:
-            self.model.to(self._device)
-            self._rank = 0
-            self._word_size = 1
+            self._world_size = self.accelerator.num_processes'''
 
     @property
     def config(self):
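
Net effect of the two hunks: the single-process fallback (move the model to self._device, rank 0, world size 1) is attached as an else branch to the earlier data-parallel setup, and the later `if accelerator.num_processes > 1:` block is disabled by wrapping it in a triple-quoted string rather than deleting it. Note that `self._word_size = 1`, carried over from the removed lines, reads like a typo for `self._world_size`. For context, here is a minimal, self-contained sketch of the rank/world-size pattern the patched code keeps, assuming an Accelerate-based setup; the toy `torch.nn.Linear` model and the standalone variables are stand-ins for the wrapped model and the `self._*` attributes in `__init__`:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)  # stand-in for the wrapped model
device = accelerator.device    # stand-in for self._device

if accelerator.num_processes > 1:
    # Data-parallel path: each process records its local rank and the total process count.
    if accelerator.is_local_main_process:
        print(f"Using {accelerator.num_processes} devices with data parallelism")
    rank = accelerator.local_process_index
    world_size = accelerator.num_processes
else:
    # Single-process path, mirroring the else branch added by the patch.
    model.to(device)
    rank = 0
    world_size = 1

Running the same script under `accelerate launch` with multiple processes exercises the first branch, while a plain `python` invocation (one process) takes the single-process fallback.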