@@ -26,7 +26,9 @@ class TransformerConnectionHandler(ConnectionHandler):
assert isinstance(module_backend, TransformerBackend)
async def rpc_inference(
- self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+ self,
+ requests: AsyncIterator[runtime_pb2.ExpertRequest],
+ context: P2PContext,
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
try:
@@ -35,17 +37,19 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_uids = self._check_header(request)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
- cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
+ batch_size = request.tensors[0].size[0] if request.tensors else 1
+
+ cache_metadata = torch.tensor([[-1, -1] for _ in range(batch_size)], dtype=torch.int64) # [cache_handle, prefix_length]
prefix_length = 0
- async with self._allocate_caches(requested_backends) as cache_handles:
+ async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# run request tensors through all requested modules, update caches
for backend, cache_handle in zip(requested_backends, cache_handles):
- cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
+ cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert (
len(hidden_states) == 1 and hidden_states[0].ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
@@ -213,7 +217,7 @@ class TransformerConnectionHandler(ConnectionHandler):
return tuple(uids)
@contextlib.asynccontextmanager
- async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
+ async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
"""Allocate memory caches for each transformer block, return cache handles"""
async with contextlib.AsyncExitStack() as stack:
handles = []
@@ -221,7 +225,7 @@ class TransformerConnectionHandler(ConnectionHandler):
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
- cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+ cache_descriptor = TensorDescriptor(size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
# [key_or_value, batch_size, max_length, num_heads, head_dim]
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
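
The descriptor change means each block now reserves an attention cache for the full batch rather than a single sequence, at the cost of proportionally more server memory per open inference session. A rough per-block footprint under the new shape (2, batch_size, MAX_LENGTH, num_heads, head_dim) in float32; the dimensions below are assumed for illustration and are not taken from this diff.

    # assumed example dimensions, not values from the diff
    batch_size, MAX_LENGTH, num_heads, head_dim = 4, 2048, 32, 128

    numel = 2 * batch_size * MAX_LENGTH * num_heads * head_dim   # key and value planes
    cache_bytes = numel * 4                                      # float32 = 4 bytes per element
    print(f"~{cache_bytes / 2**20:.0f} MiB of cache per block")  # ~256 MiB for these values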