@@ -5,7 +5,7 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
     amap_in_executor,
-    as_aiter,
+    iter_as_aiter,
     nested_compare,
     nested_flatten,
     nested_pack,
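`as_aiter` took its items as `*args`, which is why the call sites below had to materialize `split` into a tuple before streaming could begin; `iter_as_aiter` accepts a single iterable, so a lazy generator can be passed through unchanged. A minimal sketch of the assumed adapter semantics (not the actual hivemind implementation):

# Sketch only: assumed semantics of iter_as_aiter, not the hivemind source.
from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")


async def iter_as_aiter_sketch(iterable: Iterable[T]) -> AsyncIterator[T]:
    """Wrap a synchronous iterable as an async iterator, consuming it lazily."""
    for item in iterable:
        yield item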
@@ -147,50 +147,52 @@ class RemoteExpertWorker:
         return experts

     @classmethod
-    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+    def spawn_experts(
+        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+    ) -> Union[List[Optional[RemoteExpert]], Future]:
+        if return_future:
+
+            async def _unpack(infos_future: MPFuture, dht: DHT):
+                p2p = await dht.replicate_p2p()
+                return cls._spawn_experts(await infos_future, p2p)
+
+            return cls.run_coroutine(_unpack(infos, dht), return_future)
+
         p2p = cls.run_coroutine(dht.replicate_p2p())
         return cls._spawn_experts(infos, p2p)

     @classmethod
-    def spawn_experts_future(
-        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> Future[List[Optional[RemoteExpert]]]:
-        async def _unpack():
-            p2p = cls.run_coroutine(dht.replicate_p2p(), True)
-            return cls.spawn_experts(await infos, await p2p)
-
-        return cls.run_coroutine(_unpack, True)
+    def batch_spawn_experts(
+        cls,
+        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+        dht: DHT,
+        return_future: bool = False,
+    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+        if return_future:

-    @classmethod
-    def spawn_experts_bulk(
-        cls, infos: Sequence[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> List[List[Optional[RemoteExpert]]]:
-        return [cls.spawn_experts(exps, dht) for exps in infos]
+            async def _unpack(infos_future: MPFuture, dht: DHT):
+                p2p = await dht.replicate_p2p()
+                return [cls._spawn_experts(i, p2p) for i in await infos_future]

-    @classmethod
-    def spawn_experts_bulk_future(
-        cls, infos: Future[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
-    ) -> Future[List[List[Optional[RemoteExpert]]]]:
-        async def _unpack():
-            return cls.spawn_experts_bulk(await infos, dht)
+            return cls.run_coroutine(_unpack(infos, dht), return_future)

-        return cls.run_coroutine(_unpack, True)
+        return [cls.spawn_experts(exps, dht) for exps in infos]


 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
             lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
-            as_aiter(*split),
+            iter_as_aiter(split),
         ),
     )

     return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)


-async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
     )
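With this hunk, the old `spawn_experts_future`, `spawn_experts_bulk`, and `spawn_experts_bulk_future` variants collapse into two entry points gated by `return_future`. A hypothetical call site (the `infos`, `infos_future`, `batched_infos`, and `dht` values are placeholders assumed to come from the surrounding code, e.g. a DHT lookup):

# Hypothetical usage; infos, infos_future, batched_infos, and dht are placeholders.
experts = RemoteExpertWorker.spawn_experts(infos, dht)  # blocks, returns List[Optional[RemoteExpert]]

# With return_future=True, infos may itself be an MPFuture, and the call
# returns a concurrent.futures.Future instead of blocking:
future = RemoteExpertWorker.spawn_experts(infos_future, dht, return_future=True)

grid = RemoteExpertWorker.batch_spawn_experts(batched_infos, dht)  # one expert list per inner sequence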
@@ -210,23 +212,23 @@ async def expert_backward(
     if size >= DEFAULT_MAX_MSG_SIZE:
         return await _backward_stream(uid, serialized_tensors, stub)
     else:
-        return await _backward(uid, serialized_tensors, stub)
+        return await _backward_unary(uid, serialized_tensors, stub)


 async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
             lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
-            as_aiter(*split),
+            iter_as_aiter(split),
         ),
     )

     return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)


-async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
     )
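Replacing `tuple(...)` with a generator expression in `_forward_stream` and `_backward_stream` means chunks from `split_for_streaming` are produced on demand as `iter_as_aiter` drains them, rather than being buffered in memory before the stream starts. A toy illustration of the difference (the blob and chunk sizes are made up):

# Toy illustration, not hivemind code: eager vs. lazy chunking.
def split_chunks(blob: bytes, chunk_size: int):
    for i in range(0, len(blob), chunk_size):
        yield blob[i : i + chunk_size]


blob = bytes(10_000_000)
eager = tuple(split_chunks(blob, 1024))  # all ~9766 chunk copies held in memory at once
lazy = split_chunks(blob, 1024)          # chunks materialize one at a time as consumed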
@@ -243,7 +245,7 @@ async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions:
     if size >= DEFAULT_MAX_MSG_SIZE:
         return await _forward_stream(uid, serialized_tensors, stub)
     else:
-        return await _forward(uid, serialized_tensors, stub)
+        return await _forward_unary(uid, serialized_tensors, stub)


 class _RemoteModuleCall(torch.autograd.Function):
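The renames make the dispatch in `expert_forward` and `expert_backward` read symmetrically: payloads at or above `DEFAULT_MAX_MSG_SIZE` take the chunked streaming RPC, smaller ones take a single unary call. A self-contained sketch of that pattern (the threshold value and function names here are assumptions for illustration, not the module's exact code):

# Sketch of the size-based dispatch; MAX_UNARY_PAYLOAD_SIZE is a hypothetical
# stand-in for DEFAULT_MAX_MSG_SIZE.
MAX_UNARY_PAYLOAD_SIZE = 2 ** 21


async def dispatch_rpc(uid, serialized_tensors, stub, stream_fn, unary_fn):
    size = sum(t.ByteSize() for t in serialized_tensors)  # total protobuf payload in bytes
    if size >= MAX_UNARY_PAYLOAD_SIZE:
        return await stream_fn(uid, serialized_tensors, stub)  # large: chunk and stream
    return await unary_fn(uid, serialized_tensors, stub)  # small: one request/response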