
cast incoming tensors to backend dtype

dbaranchuk, 3 years ago
Parent commit: beacf25c5c
3 changed files with 25 additions and 9 deletions
  1. src/server/backend.py (+3 -2)
  2. src/server/handler.py (+18 -4)
  3. src/server/server.py (+4 -3)

src/server/backend.py (+3 -2)

@@ -1,6 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Sequence, Tuple
+from typing import Sequence, Tuple, Optional
 
 import torch
 from hivemind import use_hivemind_log_handler
@@ -44,7 +44,7 @@ class InferenceTaskPool(TaskPool):
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -56,6 +56,7 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = InferenceTaskPool(
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
+        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
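
A minimal sketch (not part of the commit) of the dtype resolution the new `backend_dtype` argument enables: an explicitly requested dtype wins, otherwise the backend falls back to the dtype of the block's own weights, mirroring the `self.module.input_layernorm.weight.dtype` fallback above. Names here are illustrative.

    import torch
    from typing import Optional

    def resolve_backend_dtype(module: torch.nn.Module, backend_dtype: Optional[torch.dtype] = None) -> torch.dtype:
        # An explicitly requested dtype takes precedence; otherwise infer the dtype
        # from the module's own parameters (the commit reads the input_layernorm weight).
        if backend_dtype is not None:
            return backend_dtype
        return next(module.parameters()).dtype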

src/server/handler.py (+18 -4)

@@ -81,6 +81,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -93,7 +96,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.type(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
@@ -106,6 +109,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -117,7 +123,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
@@ -134,6 +140,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
@@ -154,7 +164,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
             ]
         )
@@ -167,6 +177,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
@@ -186,7 +200,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
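
The handler changes above apply one pattern to all four RPC paths: cast incoming tensors to the first backend's dtype once, run the chain, then cast results back to the dtype declared in the outputs schema before serialization. A minimal sketch of that pattern, with `forward_fn` standing in for the task-pool forward call (illustrative names, not the real handler API):

    import torch

    def forward_chain(hidden_states: torch.Tensor, backends, forward_fn, out_dtype: torch.dtype) -> torch.Tensor:
        # Cast incoming tensors to the backend dtype once, up front.
        hidden_states = hidden_states.to(backends[0].dtype)
        for backend in backends:
            # forward_fn stands in for the pooled forward call used by the real handler.
            hidden_states = forward_fn(backend, hidden_states)
        # Cast the result back to the schema dtype before serialize_torch_tensor(...)
        # sends it to the client.
        return hidden_states.to(out_dtype)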

src/server/server.py (+4 -3)

@@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import BloomConfig, declare_active_modules
+from src import BloomConfig, declare_active_modules, BloomBlock
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
@@ -194,9 +194,10 @@ class Server(threading.Thread):
                 module_uid,
                 block,
                 memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                backend_dtype=None if torch_dtype == 'auto' else torch_dtype,
+                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression),),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression),),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
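
In server.py, the wire schemas are now pinned to float32 while the block itself may compute in another dtype. An illustrative note (assuming `torch_dtype` is either the string 'auto' or a concrete `torch.dtype`, as in the hunk above) on how `backend_dtype` is derived:

    # 'auto' defers the choice to TransformerBackend, which reads the dtype of the
    # loaded block's weights; any concrete dtype is forwarded as-is.
    backend_dtype = None if torch_dtype == 'auto' else torch_dtype

Clients therefore keep exchanging float32 tensors regardless of the compute dtype, with the handler performing the casts at the boundary.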