
Move application of averaged params to the main thread

Anton Sinitsin 4 years ago
parent
commit
1572fd6ab0
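
At a glance, the commit splits averaging into two roles: the averaging side only computes the update and signals readiness through a threading.Event, while the optimizer applies that update from the main thread at its next step(). Below is a minimal, self-contained sketch of that handoff pattern; the class and method names (UpdateHandoff, produce, maybe_apply) are illustrative and are not hivemind APIs.

import threading
import torch

class UpdateHandoff:
    """Toy illustration: a background thread publishes per-parameter deltas, the main thread applies them."""

    def __init__(self):
        self.ready = threading.Event()  # signals that a finished update is waiting to be applied
        self.updates = None             # list of per-parameter delta tensors

    def produce(self, deltas):
        # background (averaging) thread: publish the computed deltas, then signal the main thread
        self.updates = deltas
        self.ready.set()

    def maybe_apply(self, params):
        # main thread, e.g. at the start of optimizer.step(): fold in pending deltas, if any
        if self.ready.is_set():
            with torch.no_grad():
                for param, delta in zip(params, self.updates):
                    param += delta.to(dtype=param.dtype, device=param.device)
            self.updates = None
            self.ready.clear()

if __name__ == "__main__":
    handoff = UpdateHandoff()
    params = [torch.zeros(3)]
    threading.Thread(target=handoff.produce, args=([torch.ones(3)],)).start()
    handoff.ready.wait()       # the updated test below waits on the real event the same way
    handoff.maybe_apply(params)
    print(params[0])           # tensor([1., 1., 1.])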

+ 13 - 10
hivemind/client/averaging/training.py

@@ -1,6 +1,6 @@
 """ An extension of averager that supports common optimization use cases. """
 from itertools import chain
-from threading import Lock
+from threading import Lock, Event
 from typing import Sequence, Dict, Iterator, Optional
 from contextlib import nullcontext
 
@@ -41,6 +41,8 @@ class TrainingAverager(DecentralizedAverager):
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
+        self.averaging_ready_event = Event()  # set when a finished averaging round is ready to be applied
+        self.update = None  # pending per-tensor deltas, applied on the main thread by DecentralizedOptimizer
         self.scheduler = scheduler
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
@@ -79,7 +81,7 @@ class TrainingAverager(DecentralizedAverager):
             gathered = super().step(**kwargs)
             if gathered is not None:
                 # load averaged tensors back into model
-                with data_lock, self.get_tensors() as averaged_tensors:
+                with self.get_tensors() as averaged_tensors:
                     if len(averaged_tensors) != len(local_tensors):
                         raise RuntimeError("The number of optimized parameters should not change.")
 
@@ -88,15 +90,16 @@ class TrainingAverager(DecentralizedAverager):
                             # losing local updates that might have occurred during averaging
-                            for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
-                                                                                       old_local_tensors):
-                                local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype,
-                                                                        device=local_tensor.device) - \
-                                                     old_local_tensor.to(dtype=local_tensor.dtype,
-                                                                         device=local_tensor.device)
+                            # store per-tensor deltas; DecentralizedOptimizer applies them on the main thread
+                            self.update = [averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device) -
+                                           old_local_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                                           for averaged_tensor, local_tensor, old_local_tensor
+                                           in zip(averaged_tensors, local_tensors, old_local_tensors)]
                         else:
-                            for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                                local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                            self.update = [averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device) -
+                                           local_tensor for averaged_tensor, local_tensor
+                                           in zip(averaged_tensors, local_tensors)]
 
             self.local_step += 1
+            self.averaging_ready_event.set()
             return gathered
 
     def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
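
With this change, TrainingAverager.step no longer writes into the optimized tensors itself; it only records self.update and sets averaging_ready_event, so an averaging round can safely run off the main thread while training continues. A hedged sketch of launching one round that way with plain threading (the helper name launch_averaging_round and its kwargs are illustrative; the optimizer in this repository wires up its own background averaging):

import threading

def launch_averaging_round(averager, **averaging_kwargs):
    # start one TrainingAverager.step without blocking the training loop;
    # the main thread later checks averager.averaging_ready_event and applies averager.update
    thread = threading.Thread(target=averager.step, kwargs=averaging_kwargs, daemon=True)
    thread.start()
    return thread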

+ 8 - 0
hivemind/optim/averaged.py

@@ -106,6 +106,14 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     def step(self, *args, **kwargs):
         self._sync_if_needed()
 
+        if self.averager.averaging_ready_event.is_set():  # apply the latest averaged update before the local step
+            with self.averager.lock_averager_step, torch.no_grad():
+                local_tensors = list(self.averager.local_tensors())
+                for update_tensor, local_tensor in zip(self.averager.update, local_tensors):
+                    local_tensor[...] += update_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                self.averager.update = None
+                self.averager.averaging_ready_event.clear()
+
         with self.lock_scheduler_params:
             if self.local_epoch < self.training_state.max_epoch:
                 self.local_step = 0
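
For orientation, the block above means a user-facing training loop needs no extra synchronization: whatever averaging round has finished is folded in at the start of the next opt.step(), before the local gradient step. A toy, self-contained loop illustrating where that happens (plain torch.optim.SGD stands in for DecentralizedOptimizer, which additionally takes DHT-related arguments):

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in for DecentralizedOptimizer
inputs, targets = torch.randn(32, 4), torch.randn(32, 1)

for _ in range(10):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    opt.step()       # with DecentralizedOptimizer, any ready averaged update is applied here first
    opt.zero_grad()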

+ 2 - 1
tests/test_averaging.py

@@ -463,7 +463,8 @@ def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: int =
         (x2 - target).pow(2).sum().backward()
         sgd1.step()
         sgd2.step()
-        time.sleep(time_to_wait)
+        sgd1.averager.averaging_ready_event.wait()
+        sgd2.averager.averaging_ready_event.wait()
     assert sgd1.local_epoch == sgd2.local_epoch
     assert all([x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups)])
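
A possible hardening of the new wait, offered only as a suggestion: threading.Event.wait accepts a timeout and returns False on expiry, so the test could fail with a clear message instead of hanging if an averaging round never completes (the 60-second budget is arbitrary).

assert sgd1.averager.averaging_ready_event.wait(timeout=60.0), "averaging round did not finish in time"
assert sgd2.averager.averaging_ready_event.wait(timeout=60.0), "averaging round did not finish in time"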