
Add --{request,session,step}_timeout

Aleksandr Borzunov, 2 years ago
commit ebf07d33ed
5 changed files with 259 additions and 218 deletions
  1. cli/run_server.py (+6 -0)
  2. requirements.txt (+1 -0)
  3. src/server/cache.py (+10 -6)
  4. src/server/handler.py (+218 -197)
  5. src/server/server.py (+24 -15)

+ 6 - 0
cli/run_server.py

@@ -75,6 +75,12 @@ def main():
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
+    parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
+                        help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
+    parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
+                        help='Timeout for the whole inference session')
+    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
+                        help="Timeout for waiting the next step's inputs inside an inference session")
 
     group = parser.add_mutually_exclusive_group()
     group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
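
The three options above default to 180 s, 1800 s and 300 s respectively. As a self-contained illustration (not part of the commit), the same argparse options can be exercised in isolation to see how an operator override is parsed:

    import argparse

    # Same defaults as the diff above: 3*60, 30*60 and 5*60 seconds.
    parser = argparse.ArgumentParser()
    parser.add_argument('--request_timeout', type=float, default=3 * 60,
                        help='Timeout for a whole (streaming) forward/backward request')
    parser.add_argument('--session_timeout', type=float, default=30 * 60,
                        help='Timeout for a whole inference session')
    parser.add_argument('--step_timeout', type=float, default=5 * 60,
                        help="Timeout for waiting for the next step's inputs")

    args = parser.parse_args(['--step_timeout', '60'])  # e.g. shorten the step timeout
    print(args.request_timeout, args.session_timeout, args.step_timeout)  # -> 180 1800 60.0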

+ 1 - 0
requirements.txt

@@ -6,3 +6,4 @@ transformers==4.21.3
 protobuf>=3.20.3,<4.0dev
 git+https://github.com/learning-at-home/hivemind@be88b4280cdd87432168e1da238e532f1364078b
 humanfriendly
+async-timeout>=4.0.2
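
The new async-timeout dependency provides the timeout() async context manager that src/server/handler.py (below) now wraps around each RPC. A minimal sketch of its behaviour, assuming async-timeout>=4.0.2 (illustrative, not taken from the repo):

    import asyncio
    from async_timeout import timeout

    async def slow_operation():
        await asyncio.sleep(10)

    async def main():
        try:
            async with timeout(0.1):      # cancels the enclosed block after 0.1 s
                await slow_operation()
        except asyncio.TimeoutError:
            print("operation timed out")  # this branch runs

    asyncio.run(main())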

+ 10 - 6
src/server/cache.py

@@ -76,7 +76,9 @@ class MemoryCache:
         try:
             async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes, timeout=self.alloc_timeout)
+                    await loop.run_in_executor(
+                        None, self._wait_until_available, allocated_size_bytes, timeout=self.alloc_timeout
+                    )
                 async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     allocated_handle = int(self.handle_counter)
                     self.current_size_bytes += allocated_size_bytes
@@ -93,17 +95,19 @@ class MemoryCache:
                     self.current_size_bytes -= allocated_size_bytes
                 self._memory_freed_event.set()
 
-    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+    def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!
-        if allocated_size_bytes > self.max_size_bytes:
+        if allocated_size > self.max_size_bytes:
             raise AllocationFailed(
-                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+                f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
             )
         deadline = None if timeout is None else time.perf_counter() + timeout
-        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+        while self.current_size_bytes + allocated_size > self.max_size_bytes:
             remaining_time = deadline - time.perf_counter() if timeout is not None else None
             if not self._memory_freed_event.wait(remaining_time):
-                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+                raise AllocationFailed(
+                    f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
+                )
             self._memory_freed_event.clear()
 
     @contextlib.contextmanager
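
The refactored _wait_until_available keeps a single perf_counter-based deadline, so repeated wake-ups on _memory_freed_event cannot stretch the wait beyond alloc_timeout. A stand-alone sketch of that pattern with illustrative names (not the MemoryCache code):

    import threading
    import time
    from typing import Callable, Optional

    def wait_until(predicate: Callable[[], bool], freed: threading.Event,
                   timeout: Optional[float] = None) -> None:
        # One fixed deadline for the whole wait, no matter how often `freed` fires.
        deadline = None if timeout is None else time.perf_counter() + timeout
        while not predicate():
            remaining = None if deadline is None else deadline - time.perf_counter()
            if not freed.wait(remaining):
                raise TimeoutError(f"condition not met within {timeout} seconds")
            freed.clear()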

+ 218 - 197
src/server/handler.py

@@ -1,7 +1,9 @@
+import asyncio
 import contextlib
 from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
+from async_timeout import timeout
 from hivemind import (
     DHT,
     MSGPackSerializer,
@@ -37,13 +39,19 @@ class TransformerConnectionHandler(ConnectionHandler):
         self,
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
+        *,
         inference_max_length: int,
+        request_timeout: float,
+        session_timeout: float,
+        step_timeout: float,
         task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
     ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.inference_max_length = inference_max_length
+        self.request_timeout = request_timeout
+        self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
 
     async def _gather_inputs(
@@ -76,227 +84,240 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
 
-        request = await anext(requests)
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_inference.open", requested_uids, context)
-        try:
-            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            max_length = metadata.get("max_length")
-            points = metadata.get("points", 0)
+        async with timeout(self.session_timeout):
+            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_inference.open", requested_uids, context)
+            try:
+                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+                max_length = metadata.get("max_length")
+                points = metadata.get("points", 0)
+
+                if not requested_uids:
+                    raise ValueError("User must specify at least one block for inference, but got none")
+                assert isinstance(
+                    max_length, int
+                ), f"rpc_inference metadata must contain int max_length, got {max_length}"
+                assert isinstance(
+                    points, (float, int)
+                ), f"rpc_inference should have number of points as a number or None, got {points}"
+                if not 0 <= max_length <= self.inference_max_length:
+                    raise ValueError(
+                        f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
+                    )
 
-            if not requested_uids:
-                raise ValueError("User must specify at least one block for inference, but got none")
-            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
-            assert isinstance(
-                points, (float, int)
-            ), f"rpc_inference should have number of points as a number or None, got {points}"
-            if not 0 <= max_length <= self.inference_max_length:
-                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
-
-            point_per_piece = points / max_length if max_length > 0 else 0.0
-            batch_size = request.tensors[0].size[0] if request.tensors else 1
-
-            cache_metadata = torch.tensor(
-                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
-            )  # [cache_handle, prefix_length]
-            prefix_length = 0
-
-            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
-                assert len(cache_handles) == len(requested_backends)
-                while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-                    # Cast inputs to backend dtype
-                    hidden_states = hidden_states.to(requested_backends[0].dtype)
-                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                    # parse deep prompts (optional argument)
-                    if prompts is None or is_dummy(prompts) or is_dummy(prompts):
-                        prompts = [DUMMY] * len(requested_backends)
-                    else:
-                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-                    if not (len(requested_backends) == len(prompts)):
-                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-                    if prefix_length + length_increment > max_length:
-                        raise ValueError(
-                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                            f" exceeds pre-allocated maximum {max_length}"
-                        )
+                point_per_piece = points / max_length if max_length > 0 else 0.0
+                batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-                    # run request tensors through all requested modules, update caches
-                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
-                        if not is_dummy(prompt):
-                            hidden_states[:, : prompt.shape[1]] += prompt
-
-                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
-                        assert isinstance(
-                            hidden_states, torch.Tensor
-                        ), f"hidden states must be tensor, got {type(hidden_states)}"
-                        assert (
-                            hidden_states.ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        assert isinstance(
-                            backend.inference_pool, PrioritizedTaskPool
-                        ), "petals support only prioritized pools"
-                        priority = self._prioritizer.prioritize(
-                            cache_metadata,
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece / len(requested_backends),
-                            backend=backend,
-                            type="inference",
-                        )
-                        (hidden_states,) = await backend.inference_pool.submit_task(
-                            cache_metadata, hidden_states, hypo_ids, priority=priority
-                        )
+                cache_metadata = torch.tensor(
+                    [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+                )  # [cache_handle, prefix_length]
+                prefix_length = 0
 
-                    # serialize and send last layer outputs
-                    yield runtime_pb2.ExpertResponse(
-                        tensors=[
-                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                            )
+                async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
+                    assert len(cache_handles) == len(requested_backends)
+                    while request.tensors:  # iterate while user is willing to supply tensors
+                        hidden_states, prompts, hypo_ids = [
+                            deserialize_torch_tensor(tensor) for tensor in request.tensors
                         ]
-                    )
 
-                    # prepare for next step
-                    prefix_length += hidden_states.shape[1]
-                    request = await (anext(requests))
-        finally:
-            self._log_request("rpc_inference.close", requested_uids, context)
+                        # Cast inputs to backend dtype
+                        hidden_states = hidden_states.to(requested_backends[0].dtype)
+                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+                        # parse deep prompts (optional argument)
+                        if prompts is None or is_dummy(prompts) or is_dummy(prompts):
+                            prompts = [DUMMY] * len(requested_backends)
+                        else:
+                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+                        if not (len(requested_backends) == len(prompts)):
+                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
+                        if prefix_length + length_increment > max_length:
+                            raise ValueError(
+                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                                f" exceeds pre-allocated maximum {max_length}"
+                            )
+
+                        # run request tensors through all requested modules, update caches
+                        for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                            if not is_dummy(prompt):
+                                hidden_states[:, : prompt.shape[1]] += prompt
+
+                            cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                            assert isinstance(
+                                hidden_states, torch.Tensor
+                            ), f"hidden states must be tensor, got {type(hidden_states)}"
+                            assert (
+                                hidden_states.ndim == 3
+                            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                            assert isinstance(
+                                backend.inference_pool, PrioritizedTaskPool
+                            ), "petals support only prioritized pools"
+                            priority = self._prioritizer.prioritize(
+                                cache_metadata,
+                                hidden_states,
+                                hypo_ids,
+                                points=point_per_piece / len(requested_backends),
+                                backend=backend,
+                                type="inference",
+                            )
+                            (hidden_states,) = await backend.inference_pool.submit_task(
+                                cache_metadata, hidden_states, hypo_ids, priority=priority
+                            )
+
+                        # serialize and send last layer outputs
+                        yield runtime_pb2.ExpertResponse(
+                            tensors=[
+                                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                                for result, proto in zip(
+                                    (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+                                )
+                            ]
+                        )
+
+                        # prepare for next step
+                        prefix_length += hidden_states.shape[1]
+                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            finally:
+                self._log_request("rpc_inference.close", requested_uids, context)
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse request and prepare backends
-        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_forward", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_forward should have number of points as number or None, got {points}"
-
-        hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        async with timeout(self.request_timeout):
+            # Parse request and prepare backends
+            flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_forward", requested_uids, context)
 
-        # Serialize output and respond to client
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]
-        )
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward should have number of points as number or None, got {points}"
+
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+
+            # Serialize output and respond to client
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+                ]
+            )
 
     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        # Parse requests and prepare backends
-        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
-        requested_uids = self._check_uids(uid_str)
-        self._log_request("rpc_forward_stream", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_forward_stream should have number of points as number or None, got {points}"
-
-        hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+        async with timeout(self.request_timeout):
+            # Parse requests and prepare backends
+            uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uid_str)
+            self._log_request("rpc_forward_stream", requested_uids, context)
 
-        # Serialize the overall output
-        serialized_output = [
-            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-        ]
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-        # Split the serialized_output for streaming and respond to client
-        output_split = [
-            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert (
+                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+            ), "hidden_states must be a 3d tensor"
+
+            # Serialize the overall output
+            serialized_output = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+
+            # Split the serialized_output for streaming and respond to client
+            output_split = [
+                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
+            async for part in as_aiter(*output_split):
+                yield runtime_pb2.ExpertResponse(tensors=[part])
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse requests and prepare backends
-        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_backward", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_backward should have number of points as number or None, got {points}"
-
-        grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
+        async with timeout(self.request_timeout):
+            # Parse requests and prepare backends
+            flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_backward", requested_uids, context)
 
-        # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward should have number of points as number or None, got {points}"
 
-        grad_inputs_schema_with_prompts = (
-            requested_backends[0].args_schema * len(grads),
-            requested_backends[0].kwargs_schema,
-        )  # TODO generalize
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
 
-        # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
-        )
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            )  # TODO generalize
+
+            # Serialize the overall grad_input and respond
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+                ]
+            )
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
-        requested_uids = self._check_uids(uids_header)
-        self._log_request("rpc_backward_stream", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_backward_stream should have number of points as number or None, got {points}"
-
-        grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
+        async with timeout(self.request_timeout):
+            uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uids_header)
+            self._log_request("rpc_backward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward_stream should have number of points as number or None, got {points}"
+
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            )  # TODO generalize
+
+            # Serialize the overall grad_inputs
+            serialized_grad_inputs = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+            ]
+            # Split the serialized_grad_inputs for streaming and respond
+            output_split = [
+                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
 
-        # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-        grad_inputs_schema_with_prompts = (
-            requested_backends[0].args_schema * len(grads),
-            requested_backends[0].kwargs_schema,
-        )  # TODO generalize
-
-        # Serialize the overall grad_inputs
-        serialized_grad_inputs = [
-            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-        ]
-        # Split the serialized_grad_inputs for streaming and respond
-        output_split = [
-            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
+            async for part in as_aiter(*output_split):
+                yield runtime_pb2.ExpertResponse(tensors=[part])
 
     def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""

+ 24 - 15
src/server/server.py

@@ -62,6 +62,9 @@ class Server:
         custom_module_path=None,
         update_period: float = 30,
         expiration: Optional[float] = None,
+        request_timeout: float = 3 * 60,
+        session_timeout: float = 30 * 60,
+        step_timeout: float = 5 * 60,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         balance_quality: float = 0.75,
@@ -101,6 +104,9 @@ class Server:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.expiration = expiration
 
+        self.request_timeout = request_timeout
+        self.session_timeout, self.step_timeout = session_timeout, step_timeout
+
         self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
@@ -168,6 +174,9 @@ class Server:
                 stats_report_interval=self.stats_report_interval,
                 update_period=self.update_period,
                 expiration=self.expiration,
+                request_timeout=self.request_timeout,
+                session_timeout=self.session_timeout,
+                step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
                 use_auth_token=self.use_auth_token,
@@ -239,22 +248,17 @@ class ModuleContainer(threading.Thread):
         memory_cache: MemoryCache,
         throughput: float,
         block_indices: List[int],
-        num_handlers: Optional[int],
         min_batch_size: int,
         max_batch_size: int,
-        inference_max_length: int,
         torch_dtype: torch.dtype,
         cache_dir: Optional[str],
         device: Union[str, torch.device],
         compression: CompressionType,
-        stats_report_interval: Optional[int],
         update_period: float,
         expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
         use_auth_token: Optional[str],
         load_in_8bit: bool,
-        start: bool,
+        **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
         joining_announcer = ModuleAnnouncerThread(
@@ -328,15 +332,10 @@ class ModuleContainer(threading.Thread):
             dht,
             blocks,
             throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
             device=device,
-            stats_report_interval=stats_report_interval,
             update_period=update_period,
             expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
+            **kwargs,
         )
 
     def __init__(
@@ -345,10 +344,13 @@ class ModuleContainer(threading.Thread):
         module_backends: Dict[str, TransformerBackend],
         *,
         inference_max_length: int,
-        num_connection_handlers: int,
+        num_handlers: int,
         throughput: float,
         update_period: float,
         expiration: Optional[float] = None,
+        request_timeout: float,
+        session_timeout: float,
+        step_timeout: float,
         start: bool,
         **kwargs,
     ):
@@ -357,8 +359,15 @@ class ModuleContainer(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
-            for _ in range(num_connection_handlers)
+            TransformerConnectionHandler(
+                dht,
+                self.module_backends,
+                inference_max_length=inference_max_length,
+                request_timeout=request_timeout,
+                session_timeout=session_timeout,
+                step_timeout=step_timeout,
+            )
+            for _ in range(num_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
         self.online_announcer = ModuleAnnouncerThread(
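
After this change the three timeouts flow from Server through ModuleContainer.create's **kwargs into every TransformerConnectionHandler. A simplified sketch of that plumbing, with the classes reduced to the essentials (not the real implementations):

    class Handler:
        def __init__(self, *, request_timeout: float, session_timeout: float, step_timeout: float):
            self.request_timeout = request_timeout
            self.session_timeout, self.step_timeout = session_timeout, step_timeout

    class Container:
        def __init__(self, num_handlers: int, **kwargs):
            # kwargs carries the three timeouts through unchanged
            self.handlers = [Handler(**kwargs) for _ in range(num_handlers)]

    container = Container(num_handlers=2, request_timeout=180.0,
                          session_timeout=1800.0, step_timeout=300.0)
    print(container.handlers[0].step_timeout)  # 300.0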