@@ -1,11 +1,8 @@
 from __future__ import annotations
 
-import os
 from concurrent.futures import Future
 from dataclasses import dataclass
-from queue import Queue
-from threading import Thread
-from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -14,6 +11,7 @@ from torch.autograd.function import once_differentiable
 from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -24,7 +22,6 @@ from hivemind.utils import (
     nested_compare,
     nested_flatten,
     nested_pack,
-    switch_to_uvloop,
 )
 from hivemind.utils.mpfuture import MPFuture
 from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
@@ -88,7 +85,7 @@ class RemoteExpert(nn.Module):
     @property
     def info(self):
         if self._rpc_info is None:
-            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = _RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._rpc_info
 
@@ -96,87 +93,45 @@ class RemoteExpert(nn.Module):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
-class RemoteExpertWorker:
-    """Local thread for managing async tasks related to RemoteExpert"""
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for i in infos:
+        if i is not None:
+            experts.append(RemoteExpert(i, p2p))
+        else:
+            experts.append(None)
+    return experts
 
-    _task_queue: Queue = Queue()
-    _event_thread: Optional[Thread] = None
-    _pid: int = -1
 
-    @classmethod
-    def _run(cls):
-        loop = switch_to_uvloop()
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
 
-        async def receive_tasks():
-            while True:
-                cor, future = cls._task_queue.get()
-                try:
-                    result = await cor
-                except Exception as e:
-                    future.set_exception(e)
-                    continue
-                if not future.cancelled():
-                    future.set_result(result)
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
 
-        loop.run_until_complete(receive_tasks())
+        return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
-    @classmethod
-    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None or cls._pid != os.getpid():
-            cls._pid = os.getpid()
-            cls._event_thread = Thread(target=cls._run, daemon=True)
-            cls._event_thread.start()
+    p2p = _RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
 
-        future = Future()
-        cls._task_queue.put((coro, future))
 
-        if return_future:
-            return future
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
 
-        result = future.result()
-        return result
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
 
-    @classmethod
-    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
-        experts: List[Optional[RemoteExpert]] = []
-        for i in infos:
-            if i is not None:
-                experts.append(RemoteExpert(i, p2p))
-            else:
-                experts.append(None)
-        return experts
+    return _RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
 
-    @classmethod
-    def spawn_experts(
-        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-    ) -> Union[List[Optional[RemoteExpert]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return cls._spawn_experts(await infos_future, p2p)
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        p2p = cls.run_coroutine(dht.replicate_p2p())
-        return cls._spawn_experts(infos, p2p)
-
-    @classmethod
-    def batch_spawn_experts(
-        cls,
-        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
-        dht: DHT,
-        return_future: bool = False,
-    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
-        if return_future:
-
-            async def _unpack(infos_future: MPFuture, dht: DHT):
-                p2p = await dht.replicate_p2p()
-                return [cls._spawn_experts(i, p2p) for i in await infos_future]
-
-            return cls.run_coroutine(_unpack(infos, dht), return_future)
-
-        return [cls.spawn_experts(exps, dht) for exps in infos]
+    return [create_remote_experts(exps, dht) for exps in infos]
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
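Usage sketch for the new module-level API: it assumes a running DHT swarm and a hivemind server that already hosts an expert. The UID "expert.0", `initial_peers`, and `server_peer_info` are hypothetical placeholders, and `RemoteExpertInfo` is assumed to be the dataclass defined earlier in this file, with `uid` and `peer_info` fields.

    import torch
    from hivemind.dht import DHT
    from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts

    dht = DHT(initial_peers=initial_peers, start=True)  # placeholder: join an existing swarm

    # peer_info would normally come from a DHT lookup; assume it was obtained elsewhere
    infos = [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)]

    experts = create_remote_experts(infos, dht)  # blocking: List[Optional[RemoteExpert]]
    output = experts[0](torch.randn(2, 512))  # input shapes must match the expert's forward_schema

Note that `return_future=True` only makes sense when `infos` is itself an awaitable `MPFuture` (e.g. the pending result of a background DHT query): `_unpack` then awaits it on the worker thread and the caller receives a `concurrent.futures.Future` instead of blocking.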
@@ -266,7 +221,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
             expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )
 
@@ -279,7 +234,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
 
-        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+        deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
             expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )
 
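For context: the removed `RemoteExpertWorker` shows the pattern that `_RemoteExpertWorker`, now imported from `hivemind/moe/client/remote_expert_worker.py`, is expected to provide: a lazily started daemon thread that owns a single event loop and runs coroutines submitted from synchronous code. A minimal self-contained sketch of that pattern, using only the standard library and a stock asyncio loop in place of `switch_to_uvloop()`:

    import asyncio
    import os
    from concurrent.futures import Future
    from queue import Queue
    from threading import Thread
    from typing import Any, Awaitable, Optional


    class _Worker:
        """Daemon thread that runs submitted coroutines on one event loop."""

        _task_queue: Queue = Queue()
        _event_thread: Optional[Thread] = None
        _pid: int = -1

        @classmethod
        def _run(cls):
            loop = asyncio.new_event_loop()

            async def receive_tasks():
                while True:
                    # a blocking get() is fine here: this loop runs nothing else
                    coro, future = cls._task_queue.get()
                    try:
                        result = await coro
                    except Exception as e:
                        future.set_exception(e)
                        continue
                    if not future.cancelled():
                        future.set_result(result)

            loop.run_until_complete(receive_tasks())

        @classmethod
        def run_coroutine(cls, coro: Awaitable, return_future: bool = False) -> Any:
            # lazily start the thread; restart after fork(), since neither
            # the thread nor its event loop survives into the child process
            if cls._event_thread is None or cls._pid != os.getpid():
                cls._pid = os.getpid()
                cls._event_thread = Thread(target=cls._run, daemon=True)
                cls._event_thread.start()
            future: Future = Future()
            cls._task_queue.put((coro, future))
            return future if return_future else future.result()


    async def add(a: int, b: int) -> int:
        await asyncio.sleep(0.1)
        return a + b


    print(_Worker.run_coroutine(add(2, 3)))  # blocks until the result is ready: 5
    fut = _Worker.run_coroutine(add(4, 5), return_future=True)  # returns immediately
    print(fut.result())  # 9

Funneling every coroutine through one long-lived loop is what lets synchronous callers such as `_RemoteModuleCall.forward`/`backward` above issue p2p RPCs without each spawning an event loop of its own.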