@@ -1,6 +1,6 @@
 """ An extension of averager that supports common optimization use cases. """
 from itertools import chain
-from threading import Lock
+from threading import Lock, Event
 from typing import Sequence, Dict, Iterator, Optional
 from contextlib import nullcontext
 
@@ -41,6 +41,8 @@ class TrainingAverager(DecentralizedAverager):
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
+        self.averaging_ready_event = Event()
+        self.update = None
         self.scheduler = scheduler
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
@@ -79,7 +81,7 @@ class TrainingAverager(DecentralizedAverager):
             gathered = super().step(**kwargs)
             if gathered is not None:
                 # load averaged tensors back into model
-                with data_lock, self.get_tensors() as averaged_tensors:
+                with self.get_tensors() as averaged_tensors:
                     if len(averaged_tensors) != len(local_tensors):
                         raise RuntimeError("The number of optimized parameters should not change.")
 
@@ -88,15 +90,16 @@ class TrainingAverager(DecentralizedAverager):
                     # losing local updates that might have occurred during averaging
                     for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
                                                                                old_local_tensors):
-                        local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype,
+                        self.update = averaged_tensor.to(dtype=local_tensor.dtype,
                                                                 device=local_tensor.device) - \
                                              old_local_tensor.to(dtype=local_tensor.dtype,
                                                                  device=local_tensor.device)
                     else:
                         for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                            self.update = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
 
         self.local_step += 1
+        self.averaging_ready_event.set()
         return gathered
 
     def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
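
Usage sketch (editor's assumption, not shown in the diff): with this change, step() stops writing averaged values back into the local tensors in place. It stores the most recent result in `self.update` and sets `averaging_ready_event` at the end of every step, so a caller can overlap averaging with local training and apply the result when ready. Note that the loops above overwrite `self.update` on every iteration, so only the last tensor's result survives; the sketch below therefore assumes a single tracked tensor, and that the snapshot branch was taken, so `update` is an additive delta. `averager` names an already constructed TrainingAverager and is hypothetical:

    import threading

    import torch

    # Hypothetical consumer: run the averaging step on a background thread
    # so local training can continue while peers negotiate and average.
    thread = threading.Thread(target=averager.step, daemon=True)
    thread.start()

    # ... local training continues here ...

    averager.averaging_ready_event.wait()   # set at the end of every step()
    averager.averaging_ready_event.clear()  # re-arm for the next averaging round
    if averager.update is not None:         # None until the first successful round
        with torch.no_grad():
            local_tensor = next(averager.local_tensors())
            # assumes the delta branch: update = averaged result minus the
            # pre-averaging snapshot, so adding it keeps local progress made
            # while averaging was running
            local_tensor.add_(averager.update)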