Michael Diskin, 4 years ago
parent d96950bb09
commit
1 changed file with 6 additions and 1 deletion

hivemind/averaging/training.py  +6 -1

@@ -68,17 +68,20 @@ class TrainingAverager(DecentralizedAverager):
         """
         if not wait:
             return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)
+        logger.debug(f"AAAAA 1")
 
         # if data_lock is supplied, tensors might change during averaging, so we need to copy them
         use_old_local_tensors = data_lock is not None
         if data_lock is None:
             data_lock = nullcontext()
-
+        logger.debug(f"AAAAA 2")
         local_tensors = list(self.local_tensors())
+        logger.debug(f"AAAAA 3")
         with self.lock_averager_step, torch.no_grad():
             # fill averager's tensors with current local tensors
             self.pending_updates_done.clear()
             with data_lock, self.get_tensors() as averaged_tensors:
+                logger.debug(f"AAAAA 4")
                 if use_old_local_tensors:
                     old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                 assert len(local_tensors) == len(
@@ -86,7 +89,9 @@ class TrainingAverager(DecentralizedAverager):
                 ), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
+            logger.debug(f"AAAAA 5")
             self.pending_updates_done.set()
+            logger.debug(f"AAAAA 84")
 
             # find a group and hopefully average tensors with peers, use batch sizes as weights
             gathered = super().step(**kwargs)
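
For context, the step being instrumented above copies the local model tensors into the averager's buffers under a lock, signals that the copy is done, and only then runs the group averaging round. Below is a minimal standalone sketch of that copy-under-lock-then-average flow with the same kind of debug logging; the `ToyAverager` class and its return value are illustrative stand-ins, not hivemind's actual API.

import logging
import threading
from contextlib import nullcontext

import torch

logger = logging.getLogger(__name__)


class ToyAverager:
    """Illustrative stand-in for the copy-under-lock-then-average flow shown in the diff."""

    def __init__(self, local_tensors):
        self.local_tensors = list(local_tensors)
        # averager-side buffers that the averaging round operates on
        self.averaged_tensors = [t.cpu().float().clone() for t in self.local_tensors]
        self.lock_averager_step = threading.Lock()
        self.pending_updates_done = threading.Event()

    def step(self, data_lock=None):
        # if data_lock is supplied, tensors might change during averaging, so hold it while copying
        if data_lock is None:
            data_lock = nullcontext()
        logger.debug("starting averaging step")

        with self.lock_averager_step, torch.no_grad():
            self.pending_updates_done.clear()
            with data_lock:
                # snapshot local tensors into the averager's buffers
                for averaged, local in zip(self.averaged_tensors, self.local_tensors):
                    averaged[...] = local.cpu().float()
            self.pending_updates_done.set()
            logger.debug("local tensors copied, running group averaging")
            # placeholder for the real peer-to-peer averaging round (super().step(**kwargs) above)
            return [t.clone() for t in self.averaged_tensors]


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    averager = ToyAverager([torch.zeros(3), torch.ones(2)])
    print(averager.step(data_lock=threading.Lock()))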