@@ -6,6 +6,8 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
+import threading
+import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator

@@ -123,10 +125,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._averager_endpoint: Optional[Endpoint] = None
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
-
+        # note: the background thread uses a weakref and daemon=True so it does not prevent garbage collection
+        background_fetcher = threading.Thread(daemon=True, target=_background_thread_fetch_current_state,
+                                              args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
+        background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
-            hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)

     @property
     def port(self) -> Optional[Port]:
@@ -183,10 +187,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Shut down the averager process """
         # TODO notify peers before terminating
         if self.is_alive():
+            self._pipe.send(('_SHUTDOWN', None))
             self.terminate()
         else:
             logger.warning("Averager shutdown has no effect: the process is not alive")

+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()
+
     def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
              wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
@@ -331,23 +340,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future

-    def _background_thread_fetch_current_state_if_asked(self):
-        """ Executed in the host process as a background thread. """
-        while True:
-            trigger, future = self.pipe.recv()
-            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
-            try:
-                state_metadata, state_tensors = self.get_current_state()
-                # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-                assert isinstance(state_metadata, bytes)
-                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
-                                      for tensor in state_tensors)
-                future.set_result((state_metadata, state_tensors))
-            except BaseException as e:
-                future.set_exception(e)
-                logger.warning(e)
-                continue
-
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
@@ -439,3 +431,33 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 def is_power_of_two(n):
     """ Check whether n is a power of 2 """
     return (n != 0) and (n & (n - 1) == 0)
+
+
+def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod):
+    """
+    Executed in the host process as a background thread. Fetches the averager state when asked by peers.
+    :param pipe: DecentralizedAverager's control pipe (from the host process side)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
+    """
+    while True:
+        trigger, future = pipe.recv()
+        if trigger == '_SHUTDOWN':
+            break
+
+        assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+        try:
+            get_current_state = get_current_state_ref()
+            if get_current_state is None:
+                break
+            state_metadata, state_tensors = get_current_state()
+            del get_current_state
+
+            assert isinstance(state_metadata, bytes)
+            # note: we cast tensors to CPU on the host side to avoid initializing cuda in the guest process
+            state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                  for tensor in state_tensors)
+            future.set_result((state_metadata, state_tensors))
+        except BaseException as e:
+            future.set_exception(e)
+            logger.warning(e)
+            continue
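
Note (not part of the patch): for readers unfamiliar with the pattern this diff introduces, below is a minimal, self-contained sketch of the same idea outside the hivemind codebase: a daemon thread serves requests over a multiprocessing pipe while holding only a weakref.WeakMethod to its owner, and a '_SHUTDOWN' sentinel (or a dead weakref) ends the loop. The Owner, request_state, and _background_fetch names are hypothetical stand-ins, not library API.

import multiprocessing as mp
import threading
import weakref
from multiprocessing.connection import Connection


class Owner:
    """ Hypothetical stand-in for DecentralizedAverager on the host side """

    def __init__(self):
        # _pipe: command end used by requesters and shutdown(); pipe: end read by the fetcher thread
        self._pipe, self.pipe = mp.Pipe(duplex=True)
        # daemon=True and WeakMethod ensure the thread neither blocks interpreter exit
        # nor keeps the Owner instance alive
        fetcher = threading.Thread(daemon=True, target=_background_fetch,
                                   args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
        fetcher.start()

    def get_current_state(self):
        return b'metadata', (1.0, 2.0, 3.0)

    def request_state(self):
        self._pipe.send('_TRIGGER_GET_CURRENT_STATE')
        return self._pipe.recv()

    def shutdown(self):
        self._pipe.send('_SHUTDOWN')


def _background_fetch(pipe: Connection, get_current_state_ref: weakref.WeakMethod):
    while True:
        trigger = pipe.recv()
        if trigger == '_SHUTDOWN':
            break
        get_current_state = get_current_state_ref()  # None once the Owner has been garbage-collected
        if get_current_state is None:
            break
        pipe.send(get_current_state())
        del get_current_state  # drop the strong reference between requests


if __name__ == '__main__':
    owner = Owner()
    print(owner.request_state())  # (b'metadata', (1.0, 2.0, 3.0))
    owner.shutdown()

This also shows why the patch replaces the old bound-method thread: a thread whose target is a bound method keeps a strong reference to the instance forever, so __del__ (and hence shutdown) could never run.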
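One more note on the tensor handling in _background_thread_fetch_current_state: .detach() returns a tensor with requires_grad=False, which is why the flag is restored explicitly via .requires_grad_(tensor.requires_grad) after the copy is moved to CPU. A quick illustration (values are arbitrary):

import torch

tensor = torch.randn(3, requires_grad=True)
copy = tensor.cpu().detach().requires_grad_(tensor.requires_grad)

assert copy.device.type == 'cpu' and copy.requires_grad
assert copy.grad_fn is None  # no longer attached to the autograd graph

Because only CPU tensors are sent through the pipe, the receiving (guest) process can deserialize them without ever initializing CUDA.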