@@ -123,9 +123,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         expert = self.experts[request.uid]
         return runtime_pb2.ExpertResponse(
-            tensors=await self._process_inputs(
-                inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema
-            )
+            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )
 
     async def rpc_backward_partial(
@@ -134,9 +132,14 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            p
+            for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
             for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part, ])
+            yield runtime_pb2.ExpertResponse(
+                tensors=[
+                    part,
+                ]
+            )
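
For context, a minimal round-trip sketch (not part of the diff above) of the split/combine streaming scheme that rpc_backward_partial relies on: each serialized tensor is cut into protobuf parts small enough to stream, and the receiver stitches them back together. It assumes hivemind's serialize_torch_tensor/deserialize_torch_tensor and split_for_streaming/combine_from_streaming helpers; import paths vary between hivemind versions, and CHUNK_SIZE here is a stand-in for DEFAULT_MAX_MSG_SIZE // 2.

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

CHUNK_SIZE = 2**16  # stand-in for DEFAULT_MAX_MSG_SIZE // 2 used in the handler

tensor = torch.randn(1024, 1024)
serialized = serialize_torch_tensor(tensor)  # one runtime_pb2.Tensor message

# Server side: split the serialized tensor into streamable parts,
# mirroring the list comprehension in rpc_backward_partial.
parts = list(split_for_streaming(serialized, CHUNK_SIZE))

# Client side: combine the received parts back into one message, then deserialize.
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)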