
Prevent DecentralizedSGD from accidentally skipping a fraction of training batches (#218)

ploshkin 4 years ago
parent commit 9d2a40714c
1 changed file with 6 additions and 2 deletions
  1. hivemind/client/optim/simple.py (+6 −2)

hivemind/client/optim/simple.py (+6 −2)

@@ -86,6 +86,8 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
                 for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
 
+                old_local_tensors = tuple(local_tensor.cpu().detach().clone() for local_tensor in local_tensors)
+
             try:
                 if verbose:
                     logger.info(f"Starting a new averaging round with current parameters.")
@@ -93,8 +95,10 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
 
                 if group_info is not None:
                     with lock_parameters, averager.get_tensors() as averaged_tensors:
-                        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype)
+                        for local_tensor, old_local_tensor, averaged_tensor in zip(
+                            local_tensors, old_local_tensors, averaged_tensors
+                        ):
+                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype) - old_local_tensor
                     if verbose:
                         logger.info(f"Finished averaging round in with {len(group_info)} peers.")
                 else:
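
The change boils down to applying only the delta produced by averaging instead of overwriting the local parameters outright: a snapshot of the parameters is taken right before the averaging round, and once the round finishes, `local += averaged - snapshot` is applied. Gradient steps taken on batches processed while the round was in flight are therefore preserved rather than silently discarded. Below is a minimal, self-contained sketch of that update rule; the tensor values and variable names are made up for illustration and are not taken from hivemind itself.

```python
import torch

# Hypothetical tensors standing in for one model parameter (illustrative only).
local = torch.tensor([1.0, 2.0, 3.0])     # parameters the optimizer keeps training
old_local = local.detach().clone()        # snapshot taken just before averaging starts

# While the averaging round is in flight, training continues and applies
# one more gradient step to the local parameters.
gradient_step = torch.tensor([0.1, 0.1, 0.1])
local -= gradient_step                    # local is now [0.9, 1.9, 2.9]

# The averaging round returns the group average of the *snapshotted* state.
averaged = torch.tensor([1.5, 2.5, 3.5])  # placeholder for the real group average

# Old behaviour: overwrite, discarding the gradient step taken during averaging.
overwritten = averaged.clone()

# New behaviour: apply only the averaging delta, keeping the in-flight step.
local += averaged - old_local

print(overwritten)  # tensor([1.5000, 2.5000, 3.5000]) -> in-flight step lost
print(local)        # tensor([1.4000, 2.4000, 3.4000]) -> averaging shift + gradient step
```

With the old assignment, every batch processed during an averaging round had its effect erased when the averaged tensors were copied back, which is exactly the "skipped fraction of training batches" the commit title refers to; the delta form keeps both contributions.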