Browse source

Merge remote-tracking branch 'origin/master' into hivemind_optimizer_thirdtimesthecharm

justheuristic 3 years ago
parent
commit
4545f3133e
4 changed files with 89 additions and 11 deletions
  1. hivemind/averaging/averager.py (+36 -10)
  2. hivemind/optim/collaborative.py (+4 -0)
  3. hivemind/optim/simple.py (+4 -0)
  4. tests/test_averaging.py (+45 -1)

hivemind/averaging/averager.py (+36 -10)

@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import random
 import threading
 import weakref
 from dataclasses import asdict
@@ -164,7 +165,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
@@ -193,6 +193,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
+
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
@@ -221,7 +223,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if value and self.client_mode:
             raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
         else:
-            self._allow_state_sharing.value = value
+            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
+            if value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    @property
+    def state_sharing_priority(self) -> float:
+        """Others will preferentially downloading state from peers with highest priority."""
+        return float(self._state_sharing_priority.value)
+
+    @state_sharing_priority.setter
+    def state_sharing_priority(self, value: float):
+        if value and self.client_mode:
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state.")
+        else:
+            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
+            if self.allow_state_sharing and value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    async def _trigger_declare_load_state(self):
+        # note: we previously tried setting an mp.Event here instead; awaiting it in an executor caused a performance degradation on py3.9
+        self._state_updated.set()
 
     @property
     def peer_id(self) -> PeerID:
@@ -490,7 +512,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                            self.last_updated = get_dht_time()
                             self._state_updated.set()
 
                     else:
@@ -550,24 +571,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     async def _declare_for_download_periodically(self):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
+        sharing_was_allowed = self.allow_state_sharing
         while True:
-            if self.allow_state_sharing:
-                self._state_updated.clear()
-                expiration_time = get_dht_time() + self.declare_state_period
+            expiration_time = get_dht_time() + self.declare_state_period
+            if self.allow_state_sharing or sharing_was_allowed:
+                # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
-                            value=self.last_updated,
+                            value=self.state_sharing_priority if self.allow_state_sharing else None,
                             expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=expiration_time - self.request_timeout,
+                        timeout=expiration_time - get_dht_time(),
                     )
                 )
+                sharing_was_allowed = self.allow_state_sharing
+
+            # report again either in declare_state_period or when the user changes the field
+            self._state_updated.clear()
             try:
-                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
             except asyncio.TimeoutError:
                 pass
 
@@ -632,7 +658,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                PeerID(peer_id): float(info.value)
+                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                 for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
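With this change an averager no longer advertises a last_updated timestamp: it publishes its state_sharing_priority under the "{prefix}.all_averagers" key, re-declares itself whenever allow_state_sharing or the priority changes, and downloaders pick the peer with the highest (priority, random tie-breaker) pair. A minimal usage sketch, assuming two local DHT instances and the same constructor arguments used in the tests below:

import time
import torch
import hivemind

dht1 = hivemind.DHT(start=True)
dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

averager1 = hivemind.DecentralizedAverager(
    [torch.zeros(3)], dht=dht1, start=True, prefix="demo-run", target_group_size=2
)
averager2 = hivemind.DecentralizedAverager(
    [torch.zeros(3)], dht=dht2, start=True, prefix="demo-run", target_group_size=2
)

# sharing is allowed by default for non-client, non-auxiliary peers;
# raising the priority re-declares averager1 under "demo-run.all_averagers"
averager1.state_sharing_priority = 10.0

time.sleep(0.5)  # give the declaration loop a moment to store the new value
metadata, tensors = averager2.load_state_from_peers()  # prefers the highest-priority peer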

hivemind/optim/collaborative.py (+4 -0)

@@ -320,10 +320,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
+
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
                     assert grad_scaler.update()
 
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = self.local_step
+
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
         return group_info
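Here CollaborativeOptimizer advertises its local_step as the sharing priority, so joining peers prefer to download state from whoever has completed the most optimizer steps; as the averager hunk above shows, ties are broken by a random component. A small stand-alone illustration of that ordering rule (plain Python, no networking; the peer names and step counts are made up):

import random

# hypothetical advertised priorities: peer -> state_sharing_priority (here, the local step)
advertised = {"peer_a": 120.0, "peer_b": 125.0, "peer_c": 125.0}

# same rule as in load_state_from_peers: the highest (priority, random tie-breaker) pair wins
peer_priority = {peer: (value, random.random()) for peer, value in advertised.items()}
best_peer = max(peer_priority, key=peer_priority.get)
print(best_peer)  # peer_b or peer_c, chosen at random among the step-125 peers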

hivemind/optim/simple.py (+4 -0)

@@ -86,6 +86,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if self.local_step % self.averaging_step_period == 0:
                 self.update_event.set()
             self.averager.pending_updates_done.wait()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = get_dht_time()
+
             return loss
         finally:
             self.lock_parameters.acquire()
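The simple DecentralizedOptimizer uses the current DHT time instead, so the most recently updated peer is preferred. The same pattern can be applied in a custom training loop; the sketch below assumes placeholder model, averager, opt, batch, and loss_fn objects:

from hivemind.utils import get_dht_time

def training_step(model, averager, opt, batch, loss_fn):
    """Run one local step, then advertise how fresh this peer's state is (sketch)."""
    loss = loss_fn(model(batch))
    loss.backward()
    opt.step()
    opt.zero_grad()
    # peers will prefer to download state from whoever stepped most recently
    if not averager.client_mode:
        averager.state_sharing_priority = get_dht_time()
    return loss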

tests/test_averaging.py (+45 -1)

@@ -372,7 +372,6 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
-    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         dht=dht_instances[1],
@@ -381,6 +380,8 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
+    time.sleep(0.5)
+
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 1
@@ -399,7 +400,9 @@ def test_load_state_from_peers():
 
     averager1.allow_state_sharing = False
     assert averager2.load_state_from_peers() is None
+
     averager1.allow_state_sharing = True
+    time.sleep(0.5)
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 3
     assert got_metadata == super_metadata
@@ -408,6 +411,47 @@ def test_load_state_from_peers():
         instance.shutdown()
 
 
+@pytest.mark.forked
+def test_load_state_priority():
+    dht_instances = launch_dht_instances(4)
+
+    averagers = []
+    for i in range(4):
+        averager = hivemind.DecentralizedAverager(
+            [torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
+            dht=dht_instances[i],
+            start=True,
+            prefix="demo-run",
+            target_group_size=2,
+            allow_state_sharing=i != 1,
+        )
+        averager.state_sharing_priority = 5 - abs(2 - i)
+        averagers.append(averager)
+
+    time.sleep(0.5)
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 2
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    averagers[0].state_sharing_priority = 10
+    time.sleep(0.2)
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 0
+
+    averagers[1].allow_state_sharing = False
+    averagers[2].allow_state_sharing = False
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    for averager in averagers:
+        averager.shutdown()
+    for dht in dht_instances:
+        dht.shutdown()
+
+
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)