|
@@ -36,7 +36,7 @@ class TrainingAverager(DecentralizedAverager):
|
|
|
|
|
|
def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
|
|
|
average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
|
|
|
- initialize_optimizer: bool = True, **kwargs):
|
|
|
+ initialize_optimizer: bool = True, compression_type=None,**kwargs):
|
|
|
|
|
|
self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
|
|
|
self.opt_statistics = tuple(average_opt_statistics)
|