|
@@ -7,12 +7,13 @@ from typing import AsyncIterator, List, Optional
|
|
|
import torch
|
|
|
from hivemind import (
|
|
|
P2P,
|
|
|
+ MSGPackSerializer,
|
|
|
anext,
|
|
|
deserialize_torch_tensor,
|
|
|
get_logger,
|
|
|
nested_flatten,
|
|
|
serialize_torch_tensor,
|
|
|
- use_hivemind_log_handler, MSGPackSerializer,
|
|
|
+ use_hivemind_log_handler,
|
|
|
)
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import StubBase
|
|
@@ -33,8 +34,15 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
:note: this inference session is *not* fault-tolerant out of the box
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator,
|
|
|
- *, max_length: int):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ uid: ModuleUID,
|
|
|
+ rpc_info: RPCInfo,
|
|
|
+ inputs_queue: asyncio.Queue,
|
|
|
+ outputs_aiter: AsyncIterator,
|
|
|
+ *,
|
|
|
+ max_length: int,
|
|
|
+ ):
|
|
|
self.uid, self.rpc_info = uid, rpc_info
|
|
|
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
|
|
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
|
@@ -75,7 +83,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
|
|
|
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
|
|
|
],
|
|
|
- metadata=self._serialized_metadata if not self.stepped else None
|
|
|
+ metadata=self._serialized_metadata if not self.stepped else None,
|
|
|
)
|
|
|
)
|
|
|
)
|
|
@@ -145,7 +153,10 @@ class RemoteSequentialInferenceSession:
|
|
|
span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
|
|
|
inference_session = RemoteExpertWorker.run_coroutine(
|
|
|
RemoteTransformerBlockInferenceSession._create(
|
|
|
- stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout,
|
|
|
+ stub,
|
|
|
+ span_uids,
|
|
|
+ rpc_info=self.sequence_manager.rpc_info,
|
|
|
+ timeout=self.timeout,
|
|
|
)
|
|
|
)
|
|
|
self.inference_sessions.append(inference_session)
|