@@ -14,7 +14,7 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-MAX_LENGTH = 2048
+MAX_LENGTH = 512
 
 
 class InferenceTaskPool(TaskPool):
@@ -59,7 +59,6 @@ class TransformerBackend(ModuleBackend):
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        print('START INFERENCE STEP')
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
@@ -72,11 +71,9 @@ class TransformerBackend(ModuleBackend):
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                 layer_past = cache[0, ...], cache[1, ...], prefix_length
 
-                print("AAA")
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True, DEBUG_INPLACE_PAST=True,
                 )
-                print("BBB")
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]
                 assert new_length > prefix_length