|
@@ -31,6 +31,7 @@ class RemoteTransformerBlock(RemoteExpert):
|
|
|
|
|
|
class RemoteTransformerBlockInferenceSession:
|
|
|
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
|
|
+
|
|
|
def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
|
|
|
self.uid, self.info = uid, info
|
|
|
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
|
@@ -41,7 +42,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
|
@classmethod
|
|
|
async def _create(
|
|
|
- cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
|
|
|
+ cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
|
|
|
) -> RemoteTransformerBlockInferenceSession:
|
|
|
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
inputs_queue = asyncio.Queue()
|
|
@@ -64,12 +65,17 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
raise Exception("Session is closed, cannot perform step")
|
|
|
# serialize inputs and put them into the queue
|
|
|
inputs = (new_hidden_states,)
|
|
|
- outputs_serialized = RemoteExpertWorker.run_coroutine(self._step(
|
|
|
- runtime_pb2.ExpertRequest(uid=self.uid, tensors=[
|
|
|
- serialize_torch_tensor(tensor, proto.compression)
|
|
|
- for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
|
|
|
- ])
|
|
|
- ))
|
|
|
+ outputs_serialized = RemoteExpertWorker.run_coroutine(
|
|
|
+ self._step(
|
|
|
+ runtime_pb2.ExpertRequest(
|
|
|
+ uid=self.uid,
|
|
|
+ tensors=[
|
|
|
+ serialize_torch_tensor(tensor, proto.compression)
|
|
|
+ for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
|
|
|
+ ],
|
|
|
+ )
|
|
|
+ )
|
|
|
+ )
|
|
|
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
|
|
|
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
|
|
|
return outputs[0]
|
|
@@ -119,10 +125,11 @@ def get_remote_module(
|
|
|
"""
|
|
|
assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
|
|
|
infos = dht.run_coroutine(
|
|
|
- partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time),
|
|
|
- return_future)
|
|
|
+ partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time), return_future
|
|
|
+ )
|
|
|
|
|
|
if return_future:
|
|
|
+
|
|
|
async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
|
p2p = await dht.replicate_p2p()
|
|
|
return _create_remote_modules_from_infos(await infos_future, p2p)
|
|
@@ -148,8 +155,9 @@ async def _get_remote_module_infos(
|
|
|
return experts
|
|
|
|
|
|
|
|
|
-def _create_remote_modules_from_infos(infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
|
|
|
- ) -> List[Optional[RemoteTransformerBlock]]:
|
|
|
+def _create_remote_modules_from_infos(
|
|
|
+ infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
|
|
|
+) -> List[Optional[RemoteTransformerBlock]]:
|
|
|
experts: List[Optional[RemoteTransformerBlock]] = []
|
|
|
for info in infos:
|
|
|
if info is not None:
|