@@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
 
     async def rpc_inference(
         self,
@@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            max_length = metadata.get("max_length")
+
+            if not requested_uids:
+                raise ValueError("User must specify at least one block for inference, but got none")
+            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            if not 0 <= max_length <= self.inference_max_length:
+                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
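The sketch below is illustrative and not part of this patch: it shows how a client might attach "max_length" to the MSGPack-encoded request metadata that rpc_inference reads above. The helper name and argument handling are assumptions; only MSGPackSerializer and the `metadata` field of ExpertRequest come from the code being patched.

# Hypothetical client-side helper (sketch): serialize max_length into request metadata.
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

def build_inference_request(uid: str, serialized_tensors, max_length: int) -> runtime_pb2.ExpertRequest:
    # The server unpacks this dict via MSGPackSerializer.loads(request.metadata)
    # and checks 0 <= max_length <= inference_max_length before allocating caches.
    metadata = MSGPackSerializer.dumps({"max_length": max_length})
    return runtime_pb2.ExpertRequest(uid=uid, tensors=serialized_tensors, metadata=metadata)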
@@ -52,10 +61,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             )  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+
+                    if prefix_length + length_increment > max_length:
+                        raise ValueError(f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                                         f" exceeds pre-allocated maximum {max_length}")
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
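For reference, the length bookkeeping enforced above can be summarized as a small stand-alone check; this is a sketch, not code from this patch, and the function name is hypothetical.

def advance_prefix(prefix_length: int, length_increment: int, max_length: int) -> int:
    # Mirrors the server-side check: reject a step that would overflow the
    # pre-allocated KV cache, otherwise return the prefix length after this step.
    if prefix_length + length_increment > max_length:
        raise ValueError(f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                         f" exceeds pre-allocated maximum {max_length}")
    return prefix_length + length_increment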
@@ -113,7 +127,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
         serialized_output = [
@@ -193,7 +207,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
+    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -202,7 +216,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
+                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
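Since the cache descriptor above has size (2, batch_size, max_length, num_heads, head_dim), the per-block memory cost of a requested max_length can be estimated as in the sketch below; the fp16 dtype and the example head counts are assumptions, not values taken from this patch.

import torch

def kv_cache_bytes(batch_size: int, max_length: int, num_heads: int, head_dim: int,
                   dtype: torch.dtype = torch.float16) -> int:
    # One descriptor holds both keys and values, hence the leading factor of 2.
    elements = 2 * batch_size * max_length * num_heads * head_dim
    return elements * torch.finfo(dtype).bits // 8

# Example: batch 1, 2048 tokens, 32 heads of dim 128 in fp16
# -> 2 * 1 * 2048 * 32 * 128 * 2 bytes = 32 MiB per transformer block.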