Artem Chumachenko 3 years ago
parent commit 5af1c9e3b4
3 changed files with 8 additions and 8 deletions
  1. + 1 - 1  src/client/inference_session.py
  2. + 4 - 2  src/server/backend.py
  3. + 3 - 5  src/server/handler.py

+ 1 - 1
src/client/inference_session.py

@@ -91,7 +91,7 @@ class RemoteTransformerBlockInferenceSession:
             assert prompts.shape[3] == new_hidden_states.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = torch.arange(len(new_hidden_states))
+            hypo_ids = DUMMY
         else:
             assert len(hypo_ids) == len(new_hidden_states)
             assert hypo_ids.dtype == torch.int64
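
For context: the client now sends a DUMMY sentinel instead of materializing torch.arange(len(new_hidden_states)), so the server can tell "no reordering requested" apart from an explicit identity permutation and skip the cache gather entirely. The server-side diff below imports is_dummy from src.utils.misc; a minimal sketch of what that helper and the sentinel plausibly look like (assumed definitions — not shown in this commit):

    import torch

    # Assumed sketch of src/utils/misc: an empty tensor is a cheap sentinel,
    # since no real hypo_ids tensor ever has zero elements.
    DUMMY = torch.empty(0)

    def is_dummy(tensor: torch.Tensor) -> bool:
        # Zero elements means "no value was provided".
        return tensor.numel() == 0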

+ 4 - 2
src/server/backend.py

@@ -10,6 +10,7 @@ from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -73,8 +74,9 @@ class TransformerBackend(ModuleBackend):
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
-                layer_past = past_k, past_v = cache[0, :prefix_length], cache[1, :prefix_length]
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
+                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 

+ 3 - 5
src/server/handler.py

@@ -95,9 +95,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert isinstance(
                             hidden_states, torch.Tensor
                         ), f"hidden states must be tensor, got {type(hidden_states)}"
-                        assert (
-                            hidden_states.ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                        assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                         (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
 
                     # serialize and send last layer outputs
@@ -105,13 +103,13 @@ class TransformerConnectionHandler(ConnectionHandler):
                         tensors=[
                             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                             )
                         ]
                     )
 
                     # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
+                    prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")