
Remove pickle.loads in averager (#160)

* Security update: remove pickle.loads in averager
* add py37 to circleci config

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
ca4b8c5df6
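
For context, here is a minimal sketch (not part of the commit) of the serialization round-trip that replaces pickle for peer-supplied metadata. It assumes MSGPackSerializer.dumps/loads behave the way the diff below uses them via self.serializer; the example payload is illustrative only.

from hivemind.utils.serializer import MSGPackSerializer

# metadata of the kind the averager ships between peers: plain scalars and strings
metadata = dict(group_key='0110', learning_rate=1e-3, step=1000)

# sender side: produce an opaque bytestring
metadata_bytes = MSGPackSerializer.dumps(metadata)
assert isinstance(metadata_bytes, bytes)

# receiver side: msgpack reconstructs only basic containers and scalars,
# so a crafted payload cannot execute code the way pickle.loads can
restored = MSGPackSerializer.loads(metadata_bytes)
assert restored == metadata

Unlike pickle, which can be coerced into running arbitrary callables during deserialization, msgpack only rebuilds plain data, which is the security fix this commit delivers.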

+ 1 - 0
.circleci/config.yml

@@ -65,5 +65,6 @@ jobs:
 workflows:
   main:
     jobs:
+      - build-and-test-py37
       - build-and-test-py38
       - build-and-test-py39

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.1'
+__version__ = '0.9.2'

+ 26 - 10
hivemind/client/averaging/__init__.py

@@ -21,7 +21,7 @@ from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
     serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
-from hivemind.utils.serializer import PickleSerializer, MSGPackSerializer
+from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils import Endpoint, Port, MPFuture, get_logger
 
 # flavour types
@@ -303,7 +303,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         Get the up-to-date trainer state from a peer.
         The state consists of two parts: (metadata, tensors)
 
-         - metadata is a small pickle-serialized entry meant to store scalars and hyperparameters
+         - metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
@@ -317,13 +317,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[bytes, Sequence[torch.Tensor]]:
         """
        Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
-        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        :returns: a tuple of (serialized_metadata, sequence of torch tensors)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return self.serializer.dumps(dict(group_key=self.get_group_bits())), tensors
 
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
@@ -338,8 +338,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             assert trigger == '_TRIGGER_GET_CURRENT_STATE'
             try:
                 state_metadata, state_tensors = self.get_current_state()
-                # note: serialize here to avoid initializing cuda in the guest process
-                state_metadata = PickleSerializer.dumps(state_metadata)
+                # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
+                assert isinstance(state_metadata, bytes)
                 state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
                                       for tensor in state_tensors)
                 future.set_result((state_metadata, state_tensors))
@@ -348,8 +348,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 logger.warning(e)
                 continue
 
-    def load_state_from_peers(self, wait=True) -> Optional[Any]:
-        """ Try to download the latest optimizer state one of the existing peer """
+    def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
+        """
+        Try to download the latest optimizer state from one of the existing peers.
+        :returns: on success, a 2-tuple (serialized_metadata, tensors), where
+
+        - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
+        - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
+
+        The exact contents of both serialized_metadata and tensors are determined by the get_current_state method.
+        """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
         return future.result() if wait else future
@@ -376,7 +384,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     current_tensor_parts, tensors = [], []
                     async for message in stream:
                         if message.metadata:
-                            metadata = PickleSerializer.loads(message.metadata)
+                            metadata = self.serializer.loads(message.metadata)
                         if message.tensor_part.dtype and current_tensor_parts:
                             # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
@@ -398,6 +406,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
+        """
+        :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
+        :returns: averager's current group key bits (without prefix)
+        """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_group_bits', [], dict(future=_future)))
         return future.result() if wait else future
@@ -406,6 +418,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         future.set_result(self._matchmaking.group_key_manager.group_bits)
 
     def set_group_bits(self, group_bits: str, wait: bool = True):
+        """
+        :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
+        :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
+        """
         future, _future = MPFuture.make_pair()
         assert all(bit in '01' for bit in group_bits)
         self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
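
The practical consequence of this diff for subclasses: get_current_state must now return already-serialized metadata (bytes), and load_state_from_peers hands those same bytes back for the caller to decode. Below is a hedged sketch of that pattern; TrainerAverager and the learning_rate field are hypothetical, while self.serializer, get_tensors and get_group_bits come from the code above.

from typing import Sequence, Tuple

import torch
from hivemind.client.averaging import DecentralizedAverager


class TrainerAverager(DecentralizedAverager):
    def get_current_state(self) -> Tuple[bytes, Sequence[torch.Tensor]]:
        # serialize on the host side so the averager process never unpickles anything
        metadata = dict(group_key=self.get_group_bits(), learning_rate=1e-3)
        with self.get_tensors() as tensors:
            return self.serializer.dumps(metadata), tensors


# on a peer that wants to catch up with the group:
# loaded = averager.load_state_from_peers()
# if loaded is not None:
#     metadata_bytes, tensors = loaded
#     metadata = averager.serializer.loads(metadata_bytes)

Keeping the metadata opaque on the wire is what lets rpc_download_state stream it without PickleSerializer: the averager process only forwards bytes and tensor parts.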

+ 1 - 6
tests/test_averaging.py

@@ -277,7 +277,7 @@ def test_load_state_from_peers():
             """
             nonlocal num_calls, super_metadata, super_tensors
             num_calls += 1
-            return super_metadata, super_tensors
+            return self.serializer.dumps(super_metadata), super_tensors
 
     dht_root = hivemind.DHT(start=True)
     initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
@@ -308,11 +308,6 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))
 
-    # check that normal averaging still works
-    # futures = [averager.step(wait=False) for averager in [averager1, averager2]]
-    # for future in futures:
-    #     future.result()
-
 
 @pytest.mark.forked
 def test_getset_bits():
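
For completeness, a small usage sketch of the group-bits accessors documented in this commit; averager stands for a hypothetical running DecentralizedAverager and the bit string is arbitrary.

# group bits must be a string of '0'/'1' characters (enforced by the assert in set_group_bits)
averager.set_group_bits('0110')
print(averager.get_group_bits())   # expected to show '0110' once the averager confirms the update

# with wait=False, the call returns an MPFuture instead of blocking
bits_future = averager.get_group_bits(wait=False)
print(bits_future.result())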