@@ -18,7 +18,7 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -37,7 +37,9 @@ class TransformerBackend(ModuleBackend):
         self.backward_pool = PrioritizedTaskPool(
             self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
         )
-        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+
+        assert backend_dtype is not None
+        self.dtype = backend_dtype
         self.inference_schema = (
             (
                 *self.args_schema,
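
A minimal caller-side sketch of what this change implies (an assumption, not code from this diff): since `__init__` no longer falls back to `self.module.input_layernorm.weight.dtype`, whoever constructs a `TransformerBackend` must resolve the dtype before calling it. The `block` and `requested_dtype` names below are illustrative stand-ins.

```python
import torch
from torch import nn

# Stand-in for the module whose input_layernorm weight dtype was the old fallback.
block = nn.LayerNorm(8)
requested_dtype = None  # e.g. a dtype parsed from server config; may be unset

# The fallback now lives at the call site instead of inside __init__.
backend_dtype = requested_dtype if requested_dtype is not None else block.weight.dtype
assert backend_dtype is not None  # mirrors the new assertion in __init__
```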