@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 
@@ -13,12 +13,28 @@ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
 
+class _RequestUnpacker:
+
+    __slots__ = ("uid",)
+
+    def __init__(self):
+        self.uid: Optional[str] = None
+
+    def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+        if self.uid is None:
+            self.uid = request.uid
+        else:
+            assert self.uid == request.uid, "Expert uids differ in one request"
+
+        return request.tensors
+
+
 class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
@@ -59,26 +75,11 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
-    class _RequestUnpacker:
-
-        __slots__ = ("uid",)
-
-        def __init__(self):
-            self.uid = None
-
-        def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
-            if self.uid is None:
-                self.uid = request.uid
-            else:
-                assert self.uid == request.uid, "Expert uids differ in one request"
-
-            return request.tensors
-
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor]]:
-        unpacker = self._RequestUnpacker()
-        inputs = await gather_from_rpc(requests, unpacker, deserialize_torch_tensor)
+        unpacker = _RequestUnpacker()
+        inputs = await gather_from_streaming(requests, unpacker, deserialize_torch_tensor)
         return unpacker.uid, inputs
 
     async def _process_inputs(
@@ -88,8 +89,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
     ) -> List[runtime_pb2.Tensor]:
         return [
-            serialize_torch_tensor(t, p.compression, allow_inplace=True)
-            for t, p in zip(await pool.submit_task(*inputs), nested_flatten(schema))
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
         ]
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
@@ -105,9 +106,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p
-            for t in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
-            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
+            part
+            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):
@@ -128,9 +129,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p
-            for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
-            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
+            part
+            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
        ]
 
         async for part in as_aiter(*output_split):
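For context, the gather contract above works as follows: every ExpertRequest in one stream carries the same expert uid plus a chunk of serialized tensors; _RequestUnpacker pins the uid of the first message, asserts it on every later one, and yields the tensor chunks for deserialization. The sketch below illustrates that contract with plain Python stand-ins. FakeRequest and gather are hypothetical, not hivemind APIs, and gather is simplified (the real gather_from_streaming also recombines tensor parts produced by split_for_streaming); the real code path uses runtime_pb2.ExpertRequest and deserialize_torch_tensor.

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Callable, Iterable, List, Optional


@dataclass
class FakeRequest:
    # hypothetical stand-in for runtime_pb2.ExpertRequest
    uid: str
    tensors: List[bytes]


class Unpacker:
    # mirrors _RequestUnpacker: remember the first uid, reject mismatches
    def __init__(self) -> None:
        self.uid: Optional[str] = None

    def __call__(self, request: FakeRequest) -> Iterable[bytes]:
        if self.uid is None:
            self.uid = request.uid
        else:
            assert self.uid == request.uid, "Expert uids differ in one request"
        return request.tensors


async def gather(
    stream: AsyncIterator[FakeRequest],
    unpack: Callable[[FakeRequest], Iterable[bytes]],
    deserialize: Callable[[bytes], str],
) -> List[str]:
    # hypothetical stand-in for gather_from_streaming:
    # flatten each message's parts and deserialize them in arrival order
    return [deserialize(part) async for request in stream for part in unpack(request)]


async def main() -> None:
    async def stream() -> AsyncIterator[FakeRequest]:
        yield FakeRequest("expert.0", [b"t0", b"t1"])
        yield FakeRequest("expert.0", [b"t2"])

    unpacker = Unpacker()
    parts = await gather(stream(), unpacker, bytes.decode)
    print(unpacker.uid, parts)  # expert.0 ['t0', 't1', 't2']


asyncio.run(main())

Hoisting _RequestUnpacker to module level, as the diff does, also means _gather_inputs no longer reaches it through self, so the unpacker can be constructed and tested without a ConnectionHandler instance.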