
bug fixes

dbaranchuk committed 3 years ago
parent
commit
5c19edad67
3 changed files with 3 additions and 6 deletions
  1. src/bloom/block.py (+2, −2)
  2. src/server/backend.py (+1, −4)
  3. src/server/task_pool.py (+0, −0)

+ 2 - 2
src/bloom/block.py

@@ -65,7 +65,7 @@ class BloomAttention(nn.Module):
         head_mask=None,
         use_cache=False,
         output_attentions=False,
-        DEBUG_INPLACE_PAST: bool = True
+        DEBUG_INPLACE_PAST: bool = False
     ):
         if DEBUG_INPLACE_PAST:
             past_key, past_value, past_length = layer_past
@@ -97,7 +97,7 @@ class BloomAttention(nn.Module):
         if DEBUG_INPLACE_PAST:
             past_key, past_value, past_length = layer_past
             assert past_key.dtype == key_layer.dtype
-            assert past_key.shape[1] == 2048
+            # assert past_key.shape[1] == 2048
             assert not torch.is_grad_enabled()
             past_key[:, past_length: past_length + key_layer.shape[1]] = key_layer.type_as(past_key)
             past_value[:, past_length: past_length + value_layer.shape[1]] = value_layer.type_as(past_value)
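
For context, the fast path guarded by DEBUG_INPLACE_PAST writes the new key/value slices directly into a preallocated past cache at offset past_length instead of concatenating new tensors on every step; the commented-out assert had hard-coded the old 2048-token cache size. Below is a minimal standalone sketch of that in-place write, assuming [batch, seq_len, head_dim]-shaped cache tensors; the function name and shapes are illustrative, not the exact BloomAttention internals.

import torch

def write_past_inplace(layer_past, key_layer, value_layer):
    # layer_past is (past_key, past_value, past_length); past_key and
    # past_value are preallocated up to the maximum sequence length.
    past_key, past_value, past_length = layer_past
    assert past_key.dtype == key_layer.dtype
    assert not torch.is_grad_enabled()
    new_length = past_length + key_layer.shape[1]
    # Copy the new slices into the existing buffers; no per-step allocation.
    past_key[:, past_length:new_length] = key_layer.type_as(past_key)
    past_value[:, past_length:new_length] = value_layer.type_as(past_value)
    return past_key, past_value, new_length

with torch.inference_mode():
    past = (torch.zeros(1, 512, 64), torch.zeros(1, 512, 64), 0)
    key, value = torch.randn(1, 3, 64), torch.randn(1, 3, 64)
    _, _, length = write_past_inplace(past, key, value)  # length == 3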

+ 1 - 4
src/server/backend.py

@@ -14,7 +14,7 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-MAX_LENGTH = 2048
+MAX_LENGTH = 512
 
 
 class InferenceTaskPool(TaskPool):
@@ -59,7 +59,6 @@ class TransformerBackend(ModuleBackend):
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype            
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        print('START INFERENCE STEP')
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
@@ -72,11 +71,9 @@ class TransformerBackend(ModuleBackend):
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                 layer_past = cache[0, ...], cache[1, ...], prefix_length
 
-                print("AAA")
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True, DEBUG_INPLACE_PAST=True,
                 )
-                print("BBB")
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]
                 assert new_length > prefix_length
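
For context, inference_step reads the attention-cache handle from cache_metadata[0, 0] and the current prefix length from cache_metadata[0, 1], then lets the block write the new K/V entries in place via DEBUG_INPLACE_PAST=True; with the new MAX_LENGTH = 512 cap, a session holds at most 512 tokens of past. Below is a minimal caller-side sketch of how such metadata could be packed across steps; the backend and handle objects, the single-output unpacking, and the loop itself are assumptions for illustration, not the actual server protocol.

import torch

MAX_LENGTH = 512  # mirrors the new cap in src/server/backend.py

def run_session(backend, handle, hidden_state_chunks):
    # hidden_state_chunks: iterable of [batch, seq_len, hidden] tensors that
    # all feed the same attention cache, identified by `handle`.
    prefix_length = 0
    outputs = []
    for hidden_states in hidden_state_chunks:
        assert prefix_length + hidden_states.shape[1] <= MAX_LENGTH
        # Column 0: MemoryCache handle; column 1: tokens already in the cache.
        cache_metadata = torch.tensor([[handle, prefix_length]], dtype=torch.int32)
        # Assumes the backend returns a single hidden-state tensor per step.
        (out,) = backend.inference_step(cache_metadata, hidden_states)
        outputs.append(out)
        prefix_length += hidden_states.shape[1]
    return outputs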

+ 0 - 0
src/server/task_pool.py