
Fix the remaining tests for py37 (#166)

* DecentralizedAverager is now compatible with python 3.7's asyncio exceptions (see the sketch below this list)
    * the problem was: grpc.aio with python 3.7 raised concurrent.futures.CancelledError in some cases;
    * we relied on issubclass(asyncio.CancelledError, Exception) == False,
    * but on python 3.7, issubclass(concurrent.futures.CancelledError, Exception) == True
* DecentralizedAverager now shuts down if dereferenced in the main process
    * though it won't shut down if dereferenced in forks, for obvious reasons
* HIVEMIND_THREADS now actually works
* test_averaging now shuts down dht and averager instances to avoid leaking processes
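
The incompatibility in the first bullet comes down to where CancelledError sits in the exception hierarchy. A minimal sketch (standard library only; the version check is ours, not part of the commit):

    import sys
    import asyncio
    import concurrent.futures

    if sys.version_info >= (3, 8):
        # asyncio.CancelledError derives from BaseException, so `except Exception` does not catch it
        assert not issubclass(asyncio.CancelledError, Exception)
    else:
        # on python 3.7, asyncio.CancelledError *is* concurrent.futures.CancelledError and an Exception,
        # so generic `except Exception` blocks swallowed cancellation; hence the explicit
        # `except (concurrent.futures.CancelledError, asyncio.CancelledError)` clauses in this commit
        assert asyncio.CancelledError is concurrent.futures.CancelledError
        assert issubclass(asyncio.CancelledError, Exception)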

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent · commit 690c9dc32b

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.3'
+__version__ = '0.9.4'

+ 41 - 19
hivemind/client/averaging/__init__.py

@@ -6,6 +6,8 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
+import threading
+import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
@@ -123,10 +125,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._averager_endpoint: Optional[Endpoint] = None
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
-
+        # note: we create a background thread with a weakref and daemon=True so it does not prevent garbage collection
+        background_fetcher = threading.Thread(daemon=True, target=_background_thread_fetch_current_state,
+                                              args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
+        background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
-            hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)
 
     @property
     def port(self) -> Optional[Port]:
@@ -183,10 +187,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Shut down the averager process """
         """ Shut down the averager process """
         # TODO notify peers before terminating
         # TODO notify peers before terminating
         if self.is_alive():
         if self.is_alive():
+            self._pipe.send(('_SHUTDOWN', None))
             self.terminate()
             self.terminate()
         else:
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
 
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()
+
     def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
     def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
              wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
              wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         """
@@ -331,23 +340,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future
 
-    def _background_thread_fetch_current_state_if_asked(self):
-        """ Executed in the host process as a background thread. """
-        while True:
-            trigger, future = self.pipe.recv()
-            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
-            try:
-                state_metadata, state_tensors = self.get_current_state()
-                # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-                assert isinstance(state_metadata, bytes)
-                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
-                                      for tensor in state_tensors)
-                future.set_result((state_metadata, state_tensors))
-            except BaseException as e:
-                future.set_exception(e)
-                logger.warning(e)
-                continue
-
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state one of the existing peer.
@@ -439,3 +431,33 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 def is_power_of_two(n):
     """ Check whether n is a power of 2 """
     return (n != 0) and (n & (n - 1) == 0)
+
+
+def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod):
+    """
+    Executed in the host process as a background thread. Fetches the averager state when asked by peers.
+    :param pipe: DecentralizedAverager's control pipe (from host process side)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
+    """
+    while True:
+        trigger, future = pipe.recv()
+        if trigger == '_SHUTDOWN':
+            break
+
+        assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+        try:
+            get_current_state = get_current_state_ref()
+            if get_current_state is None:
+                break
+            state_metadata, state_tensors = get_current_state()
+            del get_current_state
+
+            assert isinstance(state_metadata, bytes)
+            state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                  for tensor in state_tensors)
+            # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
+            future.set_result((state_metadata, state_tensors))
+        except BaseException as e:
+            future.set_exception(e)
+            logger.warning(e)
+            continue
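
The fetcher above receives a weakref.WeakMethod instead of a bound method, so the host-side thread cannot keep the averager alive; once the averager is garbage-collected (or receives '_SHUTDOWN'), the weak reference resolves to None and the loop exits. A minimal sketch of that behaviour (hypothetical Owner class; assumes CPython's reference counting collects the object immediately):

    import weakref

    class Owner:
        def get_state(self):
            return b"state"

    owner = Owner()
    ref = weakref.WeakMethod(owner.get_state)
    assert ref()() == b"state"  # owner alive: the weak method resolves to a bound method
    del owner                   # last strong reference dropped -> collected right away in CPython
    assert ref() is None        # the background thread would see None and break out of its loop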

+ 9 - 2
hivemind/client/averaging/matchmaking.py

@@ -7,6 +7,7 @@ import random
 from dataclasses import asdict
 from math import isfinite
 from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict
+import concurrent.futures
 import asyncio
 
 import grpc
@@ -142,6 +143,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
                         continue
+                except (concurrent.futures.CancelledError, asyncio.CancelledError):
+                    break  # note: this is a compatibility layer for python3.7
                 except Exception as e:
                     if not self.assembled_group.done():
                         self.assembled_group.set_exception(e)
@@ -256,7 +259,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                 ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
                 gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
-
+        except (concurrent.futures.CancelledError, asyncio.CancelledError):
+            return  # note: this is a compatibility layer for python3.7
         except Exception as e:
             logger.exception(e)
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
@@ -445,6 +449,8 @@ class PotentialLeaders:
                     {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
                     timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
                 self.update_triggered.clear()
+        except (concurrent.futures.CancelledError, asyncio.CancelledError):
+            return  # note: this is a compatibility layer for python3.7
         except Exception as e:
             logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             raise
@@ -463,7 +469,8 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-
+            except (concurrent.futures.CancelledError, asyncio.CancelledError):
+                pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                 logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             finally:
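
Each coroutine above now guards against both CancelledError flavours before its generic `except Exception` branch. A minimal, self-contained sketch of the pattern (hypothetical worker coroutine, not part of the commit):

    import asyncio
    import concurrent.futures

    async def worker():
        try:
            await asyncio.sleep(3600)
        except (concurrent.futures.CancelledError, asyncio.CancelledError):
            return  # cancellation handled explicitly on python 3.7 and python 3.8+ alike
        except Exception as e:
            print(f"unexpected error: {e}")

    async def main():
        task = asyncio.ensure_future(worker())
        await asyncio.sleep(0.1)  # let the worker start awaiting
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)

    asyncio.run(main())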

+ 1 - 1
hivemind/utils/threading.py

@@ -12,7 +12,7 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     if os.getpid() != EXECUTOR_PID:
     if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=float(os.environ.get("HIVEMIND_THREADS", 'inf')))
         EXECUTOR_PID = os.getpid()
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
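
With this change, HIVEMIND_THREADS is converted to a number before reaching ThreadPoolExecutor; previously a set value was passed as a string and tripped up the max_workers check. A minimal usage sketch (assumes the variable is set before the first run_in_background call in the process, since the executor is created lazily):

    import os
    os.environ["HIVEMIND_THREADS"] = "4"  # must be set before the first call creates the executor

    from hivemind.utils.threading import run_in_background

    future = run_in_background(sum, range(10))  # submitted to the shared pool (at most 4 worker threads)
    assert future.result() == 45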
 
 

+ 20 - 0
tests/test_averaging.py

@@ -70,6 +70,10 @@ def test_allreduce_once():
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 def compute_mean_std(averagers, unbiased=True):
     results = []
@@ -108,6 +112,10 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_allgather():
@@ -133,6 +141,10 @@ def test_allgather():
         for endpoint in gathered:
             assert gathered[endpoint] == reference_metadata[endpoint]
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
@@ -249,6 +261,10 @@ def test_too_few_peers():
     for future in step_futures:
         assert len(future.result()) == 2
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_overcrowded():
@@ -262,6 +278,10 @@ def test_overcrowded():
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_load_state_from_peers():