|
@@ -7,6 +7,7 @@ import contextlib
|
|
|
import ctypes
|
|
|
import multiprocessing as mp
|
|
|
import os
|
|
|
+import random
|
|
|
import threading
|
|
|
import weakref
|
|
|
from dataclasses import asdict
|
|
@@ -164,7 +165,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
self._averaged_tensors = tuple(averaged_tensors)
|
|
|
self.lock_averaged_tensors = mp.Lock()
|
|
|
- self.last_updated: DHTExpiration = -float("inf")
|
|
|
for tensor in self._averaged_tensors:
|
|
|
assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
|
|
|
tensor.share_memory_()
|
|
@@ -193,6 +193,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon
|
|
|
|
|
|
self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
|
|
|
+ self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
|
|
|
+
|
|
|
if allow_state_sharing is None:
|
|
|
allow_state_sharing = not client_mode and not auxiliary
|
|
|
self.allow_state_sharing = allow_state_sharing
|
|
@@ -221,7 +223,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
if value and self.client_mode:
|
|
|
raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
|
|
|
else:
|
|
|
- self._allow_state_sharing.value = value
|
|
|
+ old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
|
|
|
+ if value != old_value:
|
|
|
+ self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
|
|
|
+
|
|
|
+ @property
|
|
|
+ def state_sharing_priority(self) -> float:
|
|
|
+ """Others will preferentially downloading state from peers with highest priority."""
|
|
|
+ return float(self._state_sharing_priority.value)
|
|
|
+
|
|
|
+ @state_sharing_priority.setter
|
|
|
+ def state_sharing_priority(self, value: float):
|
|
|
+ if value and self.client_mode:
|
|
|
+ raise ValueError("State sharing priority is unused: averager in client mode cannot share its state.")
|
|
|
+ else:
|
|
|
+ old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
|
|
|
+ if self.allow_state_sharing and value != old_value:
|
|
|
+ self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
|
|
|
+
|
|
|
+ async def _trigger_declare_load_state(self):
|
|
|
+ # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
|
|
|
+ self._state_updated.set()
|
|
|
|
|
|
@property
|
|
|
def peer_id(self) -> PeerID:
|
|
@@ -490,7 +512,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
|
|
|
# all-reduce is performed asynchronously while iterating
|
|
|
tensor.add_(update, alpha=self._averaging_alpha)
|
|
|
- self.last_updated = get_dht_time()
|
|
|
self._state_updated.set()
|
|
|
|
|
|
else:
|
|
@@ -550,24 +571,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
async def _declare_for_download_periodically(self):
|
|
|
download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
|
|
|
+ sharing_was_allowed = self.allow_state_sharing
|
|
|
while True:
|
|
|
- if self.allow_state_sharing:
|
|
|
- self._state_updated.clear()
|
|
|
- expiration_time = get_dht_time() + self.declare_state_period
|
|
|
+ expiration_time = get_dht_time() + self.declare_state_period
|
|
|
+ if self.allow_state_sharing or sharing_was_allowed:
|
|
|
+ # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
|
|
|
asyncio.create_task(
|
|
|
asyncio.wait_for(
|
|
|
self.dht.store(
|
|
|
download_key,
|
|
|
subkey=self.peer_id.to_bytes(),
|
|
|
- value=self.last_updated,
|
|
|
+ value=self.state_sharing_priority if self.allow_state_sharing else None,
|
|
|
expiration_time=expiration_time,
|
|
|
return_future=True,
|
|
|
),
|
|
|
- timeout=expiration_time - self.request_timeout,
|
|
|
+ timeout=expiration_time - get_dht_time(),
|
|
|
)
|
|
|
)
|
|
|
+ sharing_was_allowed = self.allow_state_sharing
|
|
|
+
|
|
|
+            # report again either in declare_state_period or after the field was changed by the user
|
|
|
+ self._state_updated.clear()
|
|
|
try:
|
|
|
- await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
|
|
|
+ await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
|
|
|
except asyncio.TimeoutError:
|
|
|
pass
|
|
|
|
|
@@ -632,7 +658,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
key_manager = self._matchmaking.group_key_manager
|
|
|
peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
|
|
|
peer_priority = {
|
|
|
- PeerID(peer_id): float(info.value)
|
|
|
+ PeerID(peer_id): (float(info.value), random.random()) # using randomness as a tie breaker
|
|
|
for peer_id, info in peer_priority.items()
|
|
|
if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
|
|
|
}
|