|
@@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
dht: DHT,
|
|
|
module_backends: Dict[str, TransformerBackend],
|
|
|
*,
|
|
|
+ adapters: Optional[Sequence[str]],
|
|
|
dht_prefix: str,
|
|
|
push_manager: multiprocessing.managers.SyncManager,
|
|
|
session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue
|
|
@@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
for module_backend in self.module_backends.values():
|
|
|
assert isinstance(module_backend, TransformerBackend)
|
|
|
self.dht_prefix = dht_prefix
|
|
|
+ self.adapters = adapters
|
|
|
self._push_manager = push_manager
|
|
|
self._session_queues = session_queues
|
|
|
self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues
|
|
@@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
max_length = metadata.get("max_length")
|
|
|
- active_adapter = metadata.get("active_adapter", "")
|
|
|
+ active_adapter = self._get_active_adapter(metadata)
|
|
|
points = metadata.get("points", 0)
|
|
|
session_id = metadata.get("session_id")
|
|
|
|
|
@@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- active_adapter = metadata.get("active_adapter", "")
|
|
|
+ active_adapter = self._get_active_adapter(metadata)
|
|
|
points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
@@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
self._log_request("rpc_forward_stream", requested_uids, context)
|
|
|
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
- active_adapter = metadata.get("active_adapter", "")
|
|
|
+ active_adapter = self._get_active_adapter(metadata)
|
|
|
points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
@@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- active_adapter = metadata.get("active_adapter", "")
|
|
|
+ active_adapter = self._get_active_adapter(metadata)
|
|
|
points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
@@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
self._log_request("rpc_backward_stream", requested_uids, context)
|
|
|
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
- active_adapter = metadata.get("active_adapter", "")
|
|
|
+ active_adapter = self._get_active_adapter(metadata)
|
|
|
points = metadata.get("points", 0)
|
|
|
assert isinstance(
|
|
|
points, (float, int)
|
|
@@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[part])
|
|
|
|
|
|
+    def _get_active_adapter(self, metadata: dict) -> str:
+        """Extract and validate the requested adapter name from request metadata.
+
+        Returns "" (no adapter) when the client did not request one; otherwise
+        returns the requested adapter name after checking it is among the
+        adapters this handler was configured to serve (``self.adapters``).
+
+        :param metadata: deserialized request metadata dict (MSGPack payload)
+        :raises KeyError: if a non-empty adapter name is requested but not served here
+        """
|
|
|
+        active_adapter = metadata.get("active_adapter", "")
|
|
|
+        # Empty string means "no adapter" and is always accepted; only a
+        # non-empty name is checked against the configured adapter list.
+        # NOTE(review): ``adapters`` is typed Optional[Sequence[str]] in __init__;
+        # if it is ever None, the membership test below raises TypeError rather
+        # than the intended KeyError — confirm callers always pass a sequence.
+        if active_adapter and (active_adapter not in self.adapters):
|
|
|
+            raise KeyError(f"adapter {active_adapter} not found")
|
|
|
+        return active_adapter
|
|
|
+
|
|
|
def _serialize_grads(
|
|
|
self,
|
|
|
grads: Sequence[torch.Tensor],
|