|
|
@@ -132,7 +132,7 @@ class TrainingAverager(DecentralizedAverager):
|
|
|
for param in param_group['params'])
|
|
|
extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
|
|
|
optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
|
|
|
- scheduler_state = self.scheduler.state_dict() if self.scheduler else None
|
|
|
+ scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None
|
|
|
|
|
|
metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
|
|
|
optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
|