|
@@ -366,6 +366,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
|
|
|
self.pending_update = self.step_executor.submit(
|
|
|
self._do,
|
|
|
+ wait_for_trigger,
|
|
|
optimizer_step,
|
|
|
zero_grad,
|
|
|
averaging_round,
|
|
@@ -397,6 +398,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
|
|
|
def _do(
|
|
|
self,
|
|
|
+ wait_for_trigger: Optional[Callable[[], NoneType]],
|
|
|
optimizer_step: bool,
|
|
|
zero_grad: bool,
|
|
|
averaging_round: bool,
|
|
@@ -408,6 +410,8 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
This method is meant to be called in the background executor.
|
|
|
"""
|
|
|
try:
|
|
|
+ if wait_for_trigger is not None:
|
|
|
+ wait_for_trigger()
|
|
|
if optimizer_step:
|
|
|
with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
|
|
|
logger.log(self.status_loglevel, f"Running optimizer step")
|