Aleksandr Borzunov 2 years ago
parent
commit
a005b93678
2 changed files with 8 additions and 5 deletions
  1. +4 -2 src/server/backend.py
  2. +4 -3 src/server/server.py

+ 4 - 2
src/server/backend.py

@@ -18,7 +18,7 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -37,7 +37,9 @@ class TransformerBackend(ModuleBackend):
         self.backward_pool = PrioritizedTaskPool(
             self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
         )
-        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+
+        assert backend_dtype is not None
+        self.dtype = backend_dtype
         self.inference_schema = (
             (
                 *self.args_schema,

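The fallback removed from TransformerBackend does not disappear: the same resolution now happens once at the call site in server.py (next diff). A minimal sketch of that logic, assuming a block that exposes an input_layernorm as BloomBlock does; the helper name resolve_backend_dtype is hypothetical, not part of the codebase:

    import torch
    import torch.nn as nn

    def resolve_backend_dtype(torch_dtype, block: nn.Module) -> torch.dtype:
        # "auto" means: run the backend in whatever dtype the block's weights
        # were loaded in, probed via the input LayerNorm exactly as in the diff.
        if torch_dtype == "auto":
            return block.input_layernorm.weight.dtype
        return torch_dtype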
+ 4 - 3
src/server/server.py

@@ -292,20 +292,21 @@ class ModuleContainer(threading.Thread):
                 for param in block.parameters():
                     param.requires_grad = False
 
+                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
                     memory_cache=memory_cache,
-                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    backend_dtype=backend_dtype,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,
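With the dtype resolved once up front, the same concrete value now reaches both TransformerBackend and the two BatchTensorDescriptor schemas, which previously received the unresolved torch_dtype (possibly the string "auto"). A quick check of the sketch above, using a toy module in place of a real BloomBlock (hypothetical, illustration only):

    # Any module exposing an input_layernorm attribute stands in for a BloomBlock here.
    toy_block = nn.Sequential()
    toy_block.input_layernorm = nn.LayerNorm(16, dtype=torch.bfloat16)

    assert resolve_backend_dtype("auto", toy_block) == torch.bfloat16        # inferred from weights
    assert resolve_backend_dtype(torch.float32, toy_block) == torch.float32  # passed through unchanged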