
Serialization fixes, support attention mask in TransformerEncoder (#126)

* Add attention mask to TransformerEncoderLayer

* Get rid of redundant import in DeterministicDropout

* Don't use compression for inputs that are not floating point

* Fix missing requires_grad_ in deserialize_torch_tensor

* RemoteMixtureOfExperts: send data to cpu before serialization
Max Ryabinin 4 years ago
parent
revision
0f7c539e14

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.14'
+__version__ = '0.8.15'

+ 4 - 2
hivemind/client/expert.py

@@ -73,7 +73,8 @@ class _RemoteModuleCall(torch.autograd.Function):
     def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
                 info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-        inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
+        # detach to avoid pickling the computation graph
+        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
@@ -90,7 +91,8 @@ class _RemoteModuleCall(torch.autograd.Function):
     @staticmethod
     @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
         serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
                               for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
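
For illustration (not part of the commit): the hunk above detaches every tensor and moves it to host memory because serialize_torch_tensor now asserts a CPU input (see the hivemind/utils/grpc.py hunk below). A minimal caller-side sketch, assuming serialize_torch_tensor is imported from the module shown in this diff:

import torch
from hivemind.utils.grpc import serialize_torch_tensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
inputs = (torch.randn(4, 16, device=device, requires_grad=True),)

# detach to avoid pickling the computation graph, then copy to host memory
inputs_cpu = tuple(tensor.cpu().detach() for tensor in inputs)
serialized = [serialize_torch_tensor(tensor) for tensor in inputs_cpu]  # default compression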

+ 20 - 13
hivemind/client/moe.py

@@ -159,7 +159,9 @@ class _RemoteCallMany(torch.autograd.Function):
                 info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
-        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
+
+        flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]
+        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
         # dispatch tasks to all remote experts collect responses
@@ -167,7 +169,7 @@ class _RemoteCallMany(torch.autograd.Function):
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
                 input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
-                                 flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
+                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
                 stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
@@ -182,8 +184,8 @@ class _RemoteCallMany(torch.autograd.Function):
         mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
         mask[alive_ii, alive_jj] = True
 
-        alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
-        # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
+        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
 
         outputs = []
         for response_stacked in alive_flat_outputs_stacked:
@@ -191,10 +193,10 @@ class _RemoteCallMany(torch.autograd.Function):
                 [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                 dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
             output[alive_ii, alive_jj] = response_stacked
-            outputs.append(output)
+            outputs.append(output.to(flat_inputs[0].device))
 
         # save individual outputs for backward pass
-        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
         ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
         return (mask,) + tuple(outputs)
 
@@ -203,12 +205,15 @@ class _RemoteCallMany(torch.autograd.Function):
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
         info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
-        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
+        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
+
         dummy_grad_mask, *flat_grad_outputs = raw_grads
+        flat_grad_outputs_cpu = [tensor.cpu() for tensor in flat_grad_outputs]
+
         num_samples, max_experts = dummy_grad_mask.shape
 
-        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
-        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
+        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
+        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
         backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
 
         # dispatch tasks to all remote experts, collect responses
@@ -230,17 +235,19 @@ class _RemoteCallMany(torch.autograd.Function):
 
         # assemble responses
         backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
-        survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
-        # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
+        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = []
         for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
                 device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
             grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
 
-            grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+            # sum gradients from each expert
+            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
         return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
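
For illustration (not part of the commit), the device-handling pattern used in both forward and backward above, sketched in plain PyTorch with made-up shapes: work on CPU copies for serialization and save_for_backward, then hand results back on the caller's original device.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
flat_inputs = [torch.randn(4, 16, device=device)]
flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]  # CPU copies for serialization / save_for_backward
assembled = torch.zeros(4, 3, 16)                           # stand-in for stacked expert responses (CPU)
outputs = assembled.to(flat_inputs[0].device)               # returned on the original input device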
 

+ 2 - 3
hivemind/server/expert_backend.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 from hivemind.server.task_pool import TaskPool
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map,\
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map, \
     BatchTensorDescriptor, DUMMY_BATCH_SIZE
 
 
@@ -106,8 +106,7 @@ class ExpertBackend(nn.Module):
             args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
                     else tensor.detach() for tensor in args]
             kwargs = {input_key: (tensor.detach().requires_grad_(True)
-                                  if tensor.dtype in (torch.half, torch.float, torch.double)
-                                  else tensor.detach())
+                                  if tensor.is_floating_point() else tensor.detach())
                       for input_key, tensor in kwargs.items()}
 
             outputs = self.expert(*args, **kwargs)
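
For illustration (not part of the commit), why the is_floating_point() guard is needed: PyTorch autograd only tracks floating-point (and complex) tensors, so calling requires_grad_ on a boolean attention mask would raise.

import torch

padding_mask = torch.zeros(2, 8, dtype=torch.bool)
try:
    padding_mask.requires_grad_(True)  # raises RuntimeError: non-float tensors cannot require gradients
except RuntimeError as error:
    print(error)

hidden = torch.zeros(2, 8)                     # float32
hidden = hidden.detach().requires_grad_(True)  # mirrors the guarded branch above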

+ 4 - 2
hivemind/server/layers/__init__.py

@@ -4,11 +4,13 @@ from hivemind.server.layers.common import FeedforwardBlock, TransformerEncoderLa
 from hivemind.server.layers.dropout import DeterministicDropout, DeterministicDropoutNetwork
 
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
-                 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
+                 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, dim_feedforward=4 * hid_dim, nhead=16),
                  'nop': lambda hid_dim: NopExpert(hid_dim),
                  'det_dropout': lambda hid_dim: DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)}
+
 name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
-                 'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim)),
+                 'transformer': lambda batch_size, hid_dim:
+                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, hid_dim), dtype=torch.bool)),
                  'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'det_dropout': lambda batch_size, hid_dim:
                  (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}

+ 8 - 2
hivemind/server/layers/common.py

@@ -39,13 +39,19 @@ class TransformerEncoderLayer(nn.Module):
 
         self.activation = torch.nn.GELU()
 
-    def forward(self, src):
-        src2 = self.self_attn(src, src, src)[0]
+    def forward(self, src, src_key_padding_mask=None):
+        # (N, S, E) -> (S, N, E)
+        src = src.transpose(0, 1)
+
+        src2 = self.self_attn(src, src, src, key_padding_mask=src_key_padding_mask)[0]
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
         src = src + self.dropout2(src2)
         src = self.norm2(src)
+
+        # (S, N, E) -> (N, S, E)
+        src = src.transpose(0, 1)
         return src
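
For illustration (not part of the commit), a usage sketch of the updated layer; constructor arguments follow the name_to_block entry in hivemind/server/layers/__init__.py, and the concrete sizes are illustrative. The interface stays batch-first, with the transpose to the (S, N, E) layout expected by nn.MultiheadAttention handled internally.

import torch
from hivemind.server.layers.common import TransformerEncoderLayer

layer = TransformerEncoderLayer(64, dim_feedforward=256, nhead=16)
src = torch.randn(2, 128, 64)                          # batch-first: (N, S, E)
padding_mask = torch.zeros(2, 128, dtype=torch.bool)   # True marks padded positions
padding_mask[:, 100:] = True
out = layer(src, src_key_padding_mask=padding_mask)    # -> (2, 128, 64), still batch-first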
 
 

+ 0 - 1
hivemind/server/layers/dropout.py

@@ -1,5 +1,4 @@
 import torch.autograd
-import torch.nn as nn
 from torch import nn as nn
 
 

+ 8 - 5
hivemind/utils/grpc.py

@@ -144,6 +144,7 @@ FP16_MAX = 65_504
 
 def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                            allow_inplace=False) -> runtime_pb2.Tensor:
+    assert tensor.device == torch.device('cpu')
     if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         assert tensor.dtype == torch.float32
 
@@ -195,20 +196,22 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
     if serialized_tensor.compression == CompressionType.NONE:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
-        tensor = torch.as_tensor(array).view(*serialized_tensor.size).requires_grad_(serialized_tensor.requires_grad)
+        tensor = torch.as_tensor(array).view(*serialized_tensor.size)
     elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         stats_size = list(serialized_tensor.size)
         stats_size[-1] = 1
         stats_count = np.prod(stats_size)
-        means, stds = serialized_tensor.buffer[-8*stats_count:-4*stats_count], serialized_tensor.buffer[-4*stats_count:]
+        means = serialized_tensor.buffer[-8 * stats_count:-4 * stats_count]
+        stds = serialized_tensor.buffer[-4 * stats_count:]
         means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
-        tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
-        tensor = torch.as_tensor(array).view(*serialized_tensor.size)\
-            .to(torch.float32).requires_grad_(serialized_tensor.requires_grad)
+        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
+
+    tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
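
For illustration (not part of the commit), a round trip under the default CompressionType.NONE; the requires_grad flag is now restored in one place after every compression branch, and the input must already be a CPU tensor because of the new assert in serialize_torch_tensor.

import torch
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

original = torch.randn(3, 5)  # CPU tensor, as the new assert requires
restored = deserialize_torch_tensor(serialize_torch_tensor(original))
assert torch.allclose(restored, original)
assert restored.requires_grad == original.requires_grad  # flag survives regardless of compression branch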

+ 2 - 1
hivemind/utils/tensor_descr.py

@@ -49,7 +49,8 @@ class BatchTensorDescriptor(TensorDescriptor):
     def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
                    device=tensor.device, requires_grad=tensor.requires_grad,
-                   pin_memory=torch.cuda.is_available() and tensor.is_pinned(), compression=compression)
+                   pin_memory=torch.cuda.is_available() and tensor.is_pinned(),
+                   compression=compression if tensor.is_floating_point() else CompressionType.NONE)
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
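
For illustration (not part of the commit): from_tensor now ignores the requested compression for non-floating-point inputs such as boolean attention masks, since the float16-based schemes only apply to float data. The CompressionType import path below is an assumption based on the runtime protobuf definitions.

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.tensor_descr import BatchTensorDescriptor

mask = torch.zeros(2, 128, dtype=torch.bool)
descr = BatchTensorDescriptor.from_tensor(mask, compression=CompressionType.FLOAT16)
assert descr.compression == CompressionType.NONE  # non-float inputs fall back to no compression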

+ 3 - 3
scripts/run_server.py

@@ -56,11 +56,11 @@ def main():
     if args.pop('increase_file_limit'):
         increase_file_limit()
 
-    compression_name = args.pop("compression")
-    if compression_name == "MEANSTD":
+    compression_type = args.pop("compression")
+    if compression_type == "MEANSTD":
         compression = CompressionType.MEANSTD_LAST_AXIS_FLOAT16
     else:
-        compression = getattr(CompressionType, compression_name)
+        compression = getattr(CompressionType, compression_type)
 
     try:
         server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True, compression=compression)