@@ -86,6 +86,8 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
                     for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
                         averaged_tensor[...] = local_tensor.cpu().float()

+            old_local_tensors = tuple(local_tensor.cpu().detach().clone() for local_tensor in local_tensors)
+
             try:
                 if verbose:
                     logger.info(f"Starting a new averaging round with current parameters.")
@@ -93,8 +95,10 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):

                 if group_info is not None:
                     with lock_parameters, averager.get_tensors() as averaged_tensors:
-                        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype)
+                        for local_tensor, old_local_tensor, averaged_tensor in zip(
+                            local_tensors, old_local_tensors, averaged_tensors
+                        ):
+                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype) - old_local_tensor
                     if verbose:
                         logger.info(f"Finished averaging round in with {len(group_info)} peers.")
                 else:
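For reviewers, a minimal standalone sketch (with made-up tensor values, not hivemind's actual `local_tensors`/`averaged_tensors`) of the delta-style update introduced above: instead of overwriting a local parameter with the averaged value, the parameter is shifted by `averaged - old_local`, so optimizer steps taken locally while the averaging round was in flight are preserved rather than discarded.

```python
# Sketch only: illustrates the delta update from the diff with assumed values.
import torch

old_local = torch.tensor([1.0, 2.0])  # snapshot taken before the averaging round started
averaged = torch.tensor([1.5, 2.5])   # result returned by the averaging round
local = torch.tensor([1.1, 2.2])      # local copy, already advanced by optimizer steps
                                      # made while averaging was in flight

# Old behaviour: local[...] = averaged  -> local progress since the snapshot is discarded.
# New behaviour: shift by the averaging correction, keeping the local progress.
local += averaged.to(dtype=local.dtype) - old_local
print(local)  # ~ tensor([1.6000, 2.7000]) = local progress + averaging correction
```

The snapshot (`old_local_tensors` in the diff) is taken before `averager.step(...)`, i.e. it matches the state that was contributed to the average, so the subtraction cancels exactly that contribution.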