|
@@ -46,15 +46,15 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
async def _gather_inputs(
|
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
) -> Tuple[str, List[torch.Tensor], Dict]:
|
|
|
- expert_uid, metadata = None, None
|
|
|
+ block_uid, metadata = None, None
|
|
|
|
|
|
def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
|
|
|
- nonlocal expert_uid, metadata
|
|
|
+ nonlocal block_uid, metadata
|
|
|
|
|
|
- if expert_uid is None:
|
|
|
- expert_uid = req.uid
|
|
|
- elif expert_uid != req.uid:
|
|
|
- raise ValueError("Expert uids differ in one request")
|
|
|
+ if block_uid is None:
|
|
|
+ block_uid = req.uid
|
|
|
+ elif block_uid != req.uid:
|
|
|
+ raise ValueError("Block uids differ in one request")
|
|
|
|
|
|
if metadata is None:
|
|
|
metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
|
|
@@ -63,7 +63,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
tensors_stream = amap_in_executor(_unpack, requests)
|
|
|
inputs = await deserialize_tensor_stream(tensors_stream)
|
|
|
- return expert_uid, inputs, metadata
|
|
|
+ assert isinstance(block_uid, str) and isinstance(metadata, dict)
|
|
|
+ return block_uid, inputs, metadata
|
|
|
|
|
|
async def rpc_inference(
|
|
|
self,
|