@@ -68,17 +68,22 @@ class TrainingAverager(DecentralizedAverager):
         """
         if not wait:
             return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)
+        logger.debug("TrainingAverager.step: starting averaging step")
 
         # if data_lock is supplied, tensors might change during averaging, so we need to copy them
         use_old_local_tensors = data_lock is not None
         if data_lock is None:
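+            # no lock was provided, so use a no-op context manager (tensors are assumed to stay unchanged)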
             data_lock = nullcontext()
-
+        logger.debug("TrainingAverager.step: collecting local tensors")
         local_tensors = list(self.local_tensors())
+        logger.debug("TrainingAverager.step: acquiring averager step lock")
         with self.lock_averager_step, torch.no_grad():
             # fill averager's tensors with current local tensors
             self.pending_updates_done.clear()
             with data_lock, self.get_tensors() as averaged_tensors:
+                logger.debug("TrainingAverager.step: copying local tensors into the averager")
                 if use_old_local_tensors:
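+                    # snapshot the current local tensors: training may modify them while averaging is in flight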
                     old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                 assert len(local_tensors) == len(
@@ -86,7 +91,10 @@ class TrainingAverager(DecentralizedAverager):
                 ), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
+            logger.debug("TrainingAverager.step: local tensors copied")
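+            # signal waiting threads that local tensors have been copied into the averager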
             self.pending_updates_done.set()
+            logger.debug("TrainingAverager.step: looking for a group of peers to average with")
 
             # find a group and hopefully average tensors with peers, use batch sizes as weights
             gathered = super().step(**kwargs)