View Source

add pseudo efficient inference: debug

dbaranchuk 3 years ago
parent
commit
443b2c0261

+ 19 - 2
src/bloom/block.py

@@ -65,9 +65,14 @@ class BloomAttention(nn.Module):
         head_mask=None,
         use_cache=False,
         output_attentions=False,
+        DEBUG_INPLACE_PAST: bool = True
     ):
+        if DEBUG_INPLACE_PAST:
+            past_key, past_value, past_length = layer_past
+            current_sequence_length = hidden_states.shape[1] + past_length
+        else:
+            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])    
         if alibi is None:
-            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
             alibi = build_alibi_tensor(
                 current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
             )
@@ -89,7 +94,17 @@ class BloomAttention(nn.Module):
         # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3  [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
-        if layer_past is not None:
+        if DEBUG_INPLACE_PAST:
+            past_key, past_value, past_length = layer_past
+            assert past_key.dtype == key_layer.dtype
+            assert past_key.shape[1] == 2048
+            assert not torch.is_grad_enabled()
+            past_key[:, past_length: past_length + key_layer.shape[1]] = key_layer.type_as(past_key)
+            past_value[:, past_length: past_length + value_layer.shape[1]] = value_layer.type_as(past_value)
+            key_layer = past_key[:, :current_sequence_length, ...]
+            value_layer = past_value[:, :current_sequence_length, ...]
+        elif layer_past is not None:
+            assert False, "TODO ENABLE INPLACE"
             past_key, past_value = layer_past
             key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
             value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
@@ -208,6 +223,7 @@ class BloomBlock(nn.Module):
         use_cache=False,
         output_attentions=False,
         alibi=None,
+        DEBUG_INPLACE_PAST=True
     ):
         # hidden_states: [batch_size, seq_length, hidden_size]
 
@@ -230,6 +246,7 @@ class BloomBlock(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            DEBUG_INPLACE_PAST=DEBUG_INPLACE_PAST
         )
 
         attention_output = attn_outputs[0]
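
The block above replaces torch.cat-based KV caching with in-place writes into a preallocated buffer: new key/value slices are written at offset past_length, and attention then reads a view over the filled prefix. A minimal standalone sketch of that pattern follows; the buffer name and the 2048 max length mirror the diff, but all shapes here are illustrative assumptions, not the repository's code.

import torch

batch, max_length, num_heads, head_dim = 1, 2048, 4, 64
past_key = torch.zeros(batch, max_length, num_heads, head_dim)  # preallocated once, like the server-side cache
past_length = 0

with torch.inference_mode():  # matches the `assert not torch.is_grad_enabled()` precondition above
    for step in range(3):
        # keys produced by the current forward pass (seq_length == 1 during generation)
        key_layer = torch.randn(batch, 1, num_heads, head_dim)

        # write the new slice into the buffer instead of concatenating tensors every step
        past_key[:, past_length: past_length + key_layer.shape[1]] = key_layer
        past_length += key_layer.shape[1]

        # attention consumes a view over the filled prefix; no copy or reallocation happens
        key_for_attention = past_key[:, :past_length]
        assert key_for_attention.shape[1] == step + 1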

+ 1 - 1
src/client/inference_session.py

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
+                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
                 )
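
The one-line change above casts each input tensor to the dtype declared in the server's forward schema before serialization, so a float32 client can feed a float16/bfloat16 backend. A small sketch of the cast, using stand-in values instead of the real rpc_info protos:

import torch

# hypothetical stand-ins for the dtypes found in the forward_schema protos
schema_dtypes = [torch.bfloat16]
inputs = [torch.randn(1, 4, 1024, dtype=torch.float32)]  # illustrative shape

# cast to the dtype the backend expects, then hand the result to serialize_torch_tensor
casted = [tensor.to(dtype) for tensor, dtype in zip(inputs, schema_dtypes)]
assert all(t.dtype == d for t, d in zip(casted, schema_dtypes))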

+ 14 - 12
src/server/backend.py

@@ -56,9 +56,10 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = InferenceTaskPool(
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
-        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype            
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        print('START INFERENCE STEP')
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
@@ -69,22 +70,23 @@ class TransformerBackend(ModuleBackend):
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+                layer_past = cache[0, ...], cache[1, ...], prefix_length
+
+                print("AAA")
                 hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True
+                    hidden_states, layer_past=layer_past, use_cache=True, DEBUG_INPLACE_PAST=True,
                 )
-
+                print("BBB")
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]
                 assert new_length > prefix_length
-                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
-                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
-                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-                assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
-                assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
-                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                # assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+                # assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+                # assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+                # assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
+                # assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
+                # cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+                # cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                 return (hidden_states,)
 
     def get_pools(self) -> Sequence[TaskPool]:
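
Instead of passing views of the filled cache prefix and copying the new keys/values back afterwards, inference_step now hands the whole preallocated buffers plus prefix_length to the block, which fills them in place (the commented-out asserts and copy-back above are the old path). A rough sketch of the two hand-off shapes, assuming cache has the [2, batch, max_length, num_heads, head_dim] layout checked by the assert; the sizes below are illustrative.

import torch

batch, max_length, num_heads, head_dim = 1, 2048, 4, 64
cache = torch.zeros(2, batch, max_length, num_heads, head_dim)
prefix_length = 5

# old hand-off: views of the filled prefix; torch.cat inside attention produced
# fresh tensors, so new_k / new_v had to be copied back into the cache afterwards
old_layer_past = (cache[0, :, :prefix_length], cache[1, :, :prefix_length])

# new hand-off: full buffers plus the current prefix length; the attention layer
# writes new keys/values directly into `cache`, so no copy-back is needed
new_layer_past = (cache[0], cache[1], prefix_length)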

+ 5 - 2
src/server/handler.py

@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+                    
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
-                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
                                 hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                             )
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 

+ 0 - 0
src/server/task_pool.py