
Support various backend dtypes & async serialization (#38)

Dmitry Baranchuk 3 years ago
parent
commit
04a2b6f5e3

+ 1 - 1
requirements.txt

@@ -2,5 +2,5 @@ torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
 bitsandbytes-cuda113==0.26.0
-https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
+https://github.com/learning-at-home/hivemind/archive/28261470e44f2ae4157d08b563b4d2771f3a9549.zip
 https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 0 - 1
src/bloom/model.py

@@ -584,7 +584,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 )
 
         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
         loss = None
         if labels is not None:
             if self.config.problem_type is None:

+ 7 - 7
src/client/remote_model.py

@@ -32,7 +32,7 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
 
 
 class DistributedBloomModel(BloomModel):
@@ -110,11 +110,11 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
     def __init__(self, config):
         super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        self.pre_seq_len = config.pre_seq_len
 
-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
@@ -161,7 +161,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
@@ -194,7 +194,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
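
A minimal, self-contained sketch of the prompt-tuning pattern behind pre_seq_len (the class and method names other than pre_seq_len, prompt_embeddings, prefix_tokens and get_prompt are assumptions, not the repo's API): pre_seq_len trainable embeddings are prepended to the token embeddings before the transformer runs.

    import torch
    import torch.nn as nn

    class PrefixPromptSketch(nn.Module):
        """Learnable prompt embeddings prepended along the sequence dimension."""

        def __init__(self, pre_seq_len: int, hidden_size: int):
            super().__init__()
            assert pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = pre_seq_len
            self.prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
            self.register_buffer("prefix_tokens", torch.arange(pre_seq_len).long())

        def get_prompt(self, batch_size: int) -> torch.Tensor:
            prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
            return self.prompt_embeddings(prefix_tokens)  # [batch_size, pre_seq_len, hidden_size]

        def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
            # Prepend the trainable prefix to the regular token embeddings
            prompts = self.get_prompt(inputs_embeds.shape[0])
            return torch.cat([prompts, inputs_embeds], dim=1)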

+ 17 - 9
src/client/sequential_autograd.py

@@ -39,14 +39,17 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
-    # TODO: figure out whether we should use run_in_executor here
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, proto.compression)
-        for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+        )
     )
+
     deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
-
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -67,10 +70,15 @@ async def run_expert_backward(
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
 
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, proto.compression)
-        for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
     )
+
     deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs
 
@@ -187,7 +195,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
-        input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
 
         sequence_manager.rpc_info  # lazy init
         outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
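
The hunks above replace lazy, synchronous serialization with serialization offloaded to a thread pool, and cast each tensor to the dtype declared in the RPC schema before serializing. A self-contained sketch of that pattern, with blocking_serialize standing in for hivemind's serialize_torch_tensor:

    import asyncio
    import torch

    def blocking_serialize(tensor: torch.Tensor) -> bytes:
        # Stand-in for serialize_torch_tensor: CPU-bound work that would otherwise block the event loop
        return tensor.cpu().numpy().tobytes()

    async def serialize_all(tensors, wire_dtype=torch.float32):
        # Offload each serialization to the default thread-pool executor and await the batch together
        loop = asyncio.get_running_loop()
        return await asyncio.gather(
            *(loop.run_in_executor(None, blocking_serialize, t.to(wire_dtype)) for t in tensors)
        )

    if __name__ == "__main__":
        payloads = asyncio.run(serialize_all([torch.randn(2, 3), torch.randn(4)]))
        print([len(p) for p in payloads])  # serialized size of each tensor in bytes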

+ 3 - 2
src/server/backend.py

@@ -1,6 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 from hivemind import use_hivemind_log_handler
@@ -44,7 +44,7 @@ class InferenceTaskPool(TaskPool):
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -56,6 +56,7 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = InferenceTaskPool(
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
+        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
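
The new backend_dtype argument lets a server declare its compute dtype explicitly; when it is omitted, the backend adopts the dtype of the loaded block's weights (read here from input_layernorm). An illustrative sketch of the same fallback, where resolve_backend_dtype is an assumed helper name:

    from typing import Optional

    import torch
    import torch.nn as nn

    def resolve_backend_dtype(module: nn.Module, backend_dtype: Optional[torch.dtype] = None) -> torch.dtype:
        # An explicit dtype wins; otherwise use whatever precision the weights were loaded in
        return backend_dtype if backend_dtype else next(module.parameters()).dtype

    block = nn.Linear(8, 8).half()                      # stand-in for a BloomBlock loaded in fp16
    print(resolve_backend_dtype(block))                 # torch.float16
    print(resolve_backend_dtype(block, torch.float32))  # torch.float32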

+ 19 - 4
src/server/handler.py

@@ -81,6 +81,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -93,7 +96,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
@@ -106,6 +109,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -117,7 +123,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
@@ -134,6 +140,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
@@ -154,7 +164,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
             ]
         )
@@ -162,11 +172,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+
         uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
         inputs, grads = inputs_and_grads
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
@@ -186,7 +201,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
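
Taken together, these casts implement a simple dtype round trip: tensors arrive in the schema (wire) dtype, are cast to the backend's compute dtype for the forward/backward pass, and are cast back to the schema dtype before serialization. A minimal sketch under assumed names (handle_forward and the elementwise stand-in block are illustrative, not the handler's actual API):

    import torch

    SCHEMA_DTYPE = torch.float32  # dtype declared in the forward/outputs schema

    def handle_forward(block, backend_dtype: torch.dtype, hidden_states: torch.Tensor) -> torch.Tensor:
        x = hidden_states.to(backend_dtype)  # cast inputs to the backend's compute dtype
        y = block(x)                         # the block runs in its native precision
        return y.to(SCHEMA_DTYPE)            # cast back before serialize_torch_tensor

    block = lambda x: 0.5 * x                # elementwise stand-in for a block kept in fp16
    out = handle_forward(block, torch.float16, torch.randn(1, 4, 16))
    print(out.dtype)                         # torch.float32 on the wire, fp16 inside the block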

+ 11 - 2
src/server/server.py

@@ -194,9 +194,18 @@ class Server(threading.Thread):
                 module_uid,
                 block,
                 memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                args_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                outputs_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )