
Fix the remaining tests for py37 (#166)

* DecentralizedAverager is now compatible with Python 3.7's asyncio exceptions (see the sketch below the commit message)
    * the problem was: grpc.aio under Python 3.7 raised concurrent.futures.CancelledError in some cases;
    * we relied on issubclass(asyncio.CancelledError, Exception) == False,
    * but on Python 3.7, issubclass(concurrent.futures.CancelledError, Exception) == True, so cancellations were caught by generic except-Exception blocks
* DecentralizedAverager now shuts down if dereferenced in the main process
    * though it won't shut down if dereferenced in forks, for obvious reasons
* HIVEMIND_THREADS now actually works
* test_averaging now shuts down dht and averager instances to avoid leaking processes

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
commit 690c9dc32b
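
A minimal sketch (not part of the commit) of the class-hierarchy difference behind the first fix; it only illustrates the issubclass checks and the catch pattern used in matchmaking.py, not the grpc.aio behaviour itself:

```python
import asyncio
import concurrent.futures
import sys

print(sys.version_info[:2])

# On Python 3.7, asyncio.CancelledError is an alias of concurrent.futures.CancelledError
# and derives from Exception; on 3.8+, asyncio.CancelledError derives from BaseException.
print(issubclass(asyncio.CancelledError, Exception))             # True on 3.7, False on 3.8+
print(issubclass(concurrent.futures.CancelledError, Exception))  # True on 3.7

# Hence the fix: catch both classes explicitly before any generic `except Exception`.
try:
    raise concurrent.futures.CancelledError()
except (concurrent.futures.CancelledError, asyncio.CancelledError):
    print("cancellation handled explicitly on both 3.7 and 3.8+")
```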

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.3'
+__version__ = '0.9.4'

+ 41 - 19
hivemind/client/averaging/__init__.py

@@ -6,6 +6,8 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
+import threading
+import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
@@ -123,10 +125,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._averager_endpoint: Optional[Endpoint] = None
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
-
+        # note: the state fetcher runs as a daemon thread and holds only a weakref to self, so the averager can still be garbage-collected
+        background_fetcher = threading.Thread(daemon=True, target=_background_thread_fetch_current_state,
+                                              args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
+        background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
-            hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)
 
     @property
     def port(self) -> Optional[Port]:
@@ -183,10 +187,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Shut down the averager process """
         # TODO notify peers before terminating
         if self.is_alive():
+            self._pipe.send(('_SHUTDOWN', None))
             self.terminate()
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()
+
     def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
              wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
@@ -331,23 +340,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future
 
-    def _background_thread_fetch_current_state_if_asked(self):
-        """ Executed in the host process as a background thread. """
-        while True:
-            trigger, future = self.pipe.recv()
-            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
-            try:
-                state_metadata, state_tensors = self.get_current_state()
-                # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-                assert isinstance(state_metadata, bytes)
-                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
-                                      for tensor in state_tensors)
-                future.set_result((state_metadata, state_tensors))
-            except BaseException as e:
-                future.set_exception(e)
-                logger.warning(e)
-                continue
-
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
@@ -439,3 +431,33 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 def is_power_of_two(n):
     """ Check whether n is a power of 2 """
     return (n != 0) and (n & (n - 1) == 0)
+
+
+def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod):
+    """
+    Executed in the host process as a background thread. Fetches the averager state when asked by peers.
+    :param pipe: DecentralizedAverager's control pipe (from host process side)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
+    """
+    while True:
+        trigger, future = pipe.recv()
+        if trigger == '_SHUTDOWN':
+            break
+
+        assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+        try:
+            get_current_state = get_current_state_ref()
+            if get_current_state is None:
+                break
+            state_metadata, state_tensors = get_current_state()
+            del get_current_state
+
+            assert isinstance(state_metadata, bytes)
+            state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                  for tensor in state_tensors)
+            # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
+            future.set_result((state_metadata, state_tensors))
+        except BaseException as e:
+            future.set_exception(e)
+            logger.warning(e)
+            continue
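
For illustration, a hedged usage sketch of the new shutdown-on-dereference behaviour; the constructor arguments are abbreviated and may not match this version's exact signature:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)], dht=dht, prefix='demo', target_group_size=2, start=True)

averager.shutdown()  # explicit shutdown, as the updated tests now do

# __del__ also shuts the averager down if it is still alive when the last
# host-process reference is dropped: it sends '_SHUTDOWN' over the control pipe,
# which stops the background state-fetcher thread. Forked copies are unaffected.
del averager
dht.shutdown()
```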

+ 9 - 2
hivemind/client/averaging/matchmaking.py

@@ -7,6 +7,7 @@ import random
 from dataclasses import asdict
 from math import isfinite
 from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict
+import concurrent.futures
 import asyncio
 
 import grpc
@@ -142,6 +143,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
                         continue
+                except (concurrent.futures.CancelledError, asyncio.CancelledError):
+                    break  # note: this is a compatibility layer for python3.7
                 except Exception as e:
                     if not self.assembled_group.done():
                         self.assembled_group.set_exception(e)
@@ -256,7 +259,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                 ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
                 gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
-
+        except (concurrent.futures.CancelledError, asyncio.CancelledError):
+            return  # note: this is a compatibility layer for python3.7
         except Exception as e:
             logger.exception(e)
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
@@ -445,6 +449,8 @@ class PotentialLeaders:
                     {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
                     timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
                 self.update_triggered.clear()
+        except (concurrent.futures.CancelledError, asyncio.CancelledError):
+            return  # note: this is a compatibility layer for python3.7
         except Exception as e:
             logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             raise
@@ -463,7 +469,8 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-
+            except (concurrent.futures.CancelledError, asyncio.CancelledError):
+                pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                 logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             finally:

+ 1 - 1
hivemind/utils/threading.py

@@ -12,7 +12,7 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=float(os.environ.get("HIVEMIND_THREADS", 'inf')))
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
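
Why the one-line change makes HIVEMIND_THREADS effective (a sketch, not part of the diff): os.environ.get returns a string whenever the variable is set, and ThreadPoolExecutor compares max_workers against 0, so the old code raised a TypeError as soon as HIVEMIND_THREADS was defined; converting with float() handles both the environment value and the 'inf' default:

```python
import os
from concurrent.futures import ThreadPoolExecutor

os.environ["HIVEMIND_THREADS"] = "4"

old_value = os.environ.get("HIVEMIND_THREADS", float('inf'))  # old code: the string "4"
new_value = float(os.environ.get("HIVEMIND_THREADS", 'inf'))  # new code: the float 4.0

try:
    ThreadPoolExecutor(max_workers=old_value)   # "4" <= 0 raises TypeError
except TypeError as e:
    print("old behaviour:", e)

executor = ThreadPoolExecutor(max_workers=new_value)  # a usable worker limit
print(executor.submit(lambda: "ok").result())
executor.shutdown()
```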
 

+ 20 - 0
tests/test_averaging.py

@@ -70,6 +70,10 @@ def test_allreduce_once():
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 def compute_mean_std(averagers, unbiased=True):
     results = []
@@ -108,6 +112,10 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_allgather():
@@ -133,6 +141,10 @@ def test_allgather():
         for endpoint in gathered:
             assert gathered[endpoint] == reference_metadata[endpoint]
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
@@ -249,6 +261,10 @@ def test_too_few_peers():
     for future in step_futures:
         assert len(future.result()) == 2
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_overcrowded():
@@ -262,6 +278,10 @@ def test_overcrowded():
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
+    for averager in averagers:
+        averager.shutdown()
+    dht.shutdown()
+
 
 @pytest.mark.forked
 def test_load_state_from_peers():