Browse Source

Fix pickle vulnerability (#386)

Denis Mazur 3 years ago
parent
commit
b44236972a

+ 2 - 3
hivemind/moe/client/expert.py

@@ -1,4 +1,3 @@
-import pickle
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -7,7 +6,7 @@ from torch.autograd.function import once_differentiable
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -60,7 +59,7 @@ class RemoteExpert(nn.Module):
     def info(self):
         if self._info is None:
             outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = pickle.loads(outputs.serialized_info)
+            self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):

+ 2 - 3
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,5 @@
 import multiprocessing as mp
 import os
-import pickle
 from typing import Dict
 
 import grpc
@@ -9,7 +8,7 @@ import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
@@ -61,7 +60,7 @@ class ConnectionHandler(mp.context.ForkProcess):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
-        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

+ 50 - 1
hivemind/utils/tensor_descr.py

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.serializer import MSGPackSerializer
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
@@ -52,6 +53,18 @@ class TensorDescriptor(DescriptorBase):
         return torch.empty(**properties)
 
 
+def _str_to_torch_type(name: str, torch_type: type):
+    try:
+        value = getattr(torch, name.split(".")[-1])
+    except AttributeError:
+        raise ValueError(f"Invalid dtype: torch has no attribute {name}")
+    if not isinstance(value, torch_type):
+        raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}")
+
+    return value
+
+
+@MSGPackSerializer.ext_serializable(0x51)
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
     """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
@@ -70,13 +83,49 @@ class BatchTensorDescriptor(TensorDescriptor):
             device=tensor.device,
             requires_grad=tensor.requires_grad,
             pin_memory=_safe_check_pinned(tensor),
-            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
     def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
+    def packb(self) -> bytes:
+        obj_dict = asdict(self)
+
+        obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None
+        obj_dict["layout"] = str(self.layout) if self.layout is not None else None
+
+        device = obj_dict.pop("device")
+        device_type, device_index = (device.type, device.index) if device is not None else (None, None)
+        obj_dict.update(
+            device_type=device_type,
+            device_index=device_index,
+        )
+
+        return MSGPackSerializer.dumps(obj_dict)
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> BatchTensorDescriptor:
+        obj_dict = MSGPackSerializer.loads(raw)
+
+        if obj_dict["dtype"] is not None:
+            obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype)
+
+        if obj_dict["layout"] is not None:
+            obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout)
+
+        if obj_dict["device_type"] is not None:
+            obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"])
+        else:
+            obj_dict["device"] = None
+
+        del obj_dict["device_type"], obj_dict["device_index"]
+
+        size = obj_dict.pop("size")[1:]
+
+        return BatchTensorDescriptor(*size, **obj_dict)
+
 
 def _safe_check_pinned(tensor: torch.Tensor) -> bool:
     """Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error."""

+ 16 - 1
tests/test_util_modules.py

@@ -13,7 +13,7 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -521,3 +521,18 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+def test_batch_tensor_descriptor_msgpack():
+    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
+    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))
+
+    assert (
+        tensor_descr.size == tensor_descr_roundtrip.size
+        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
+        and tensor_descr.layout == tensor_descr_roundtrip.layout
+        and tensor_descr.device == tensor_descr_roundtrip.device
+        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
+        and tensor_descr.pin_memory == tensor_descr.pin_memory
+        and tensor_descr.compression == tensor_descr.compression
+    )