
review issues fix

Pavel Samygin 3 years ago
parent
commit
56c608933b

+ 1 - 4
benchmarks/benchmark_throughput.py

@@ -247,10 +247,7 @@ if __name__ == "__main__":
         )
     elif args.preset == "minimalistic":
         benchmark_throughput(
-            num_experts=1,
-            num_clients=1,
-            num_handlers=1,
-            num_batches_per_client=args.num_batches_per_client,
+            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
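
For reference, the collapsed "minimalistic" branch is just a direct call with those four keyword arguments; a minimal sketch, assuming benchmark_throughput can be imported from the benchmark script (16 is an arbitrary batch count):

    # hypothetical import path: the function is defined in benchmarks/benchmark_throughput.py
    from benchmark_throughput import benchmark_throughput

    # direct equivalent of the "minimalistic" preset above
    benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=16)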

+ 3 - 9
hivemind/moe/client/beam_search.py

@@ -18,8 +18,7 @@ from hivemind.moe.server.expert_uid import (
     is_valid_prefix,
 )
 from hivemind.p2p import PeerInfo
-from hivemind.utils import get_dht_time, get_logger
-from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -260,10 +259,7 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        if return_future:
-            return RemoteExpertWorker.spawn_experts_future(result, self.dht)
-
-        return RemoteExpertWorker.spawn_experts(result, self.dht)
+        return RemoteExpertWorker.spawn_experts(result, self.dht, return_future)
 
     @classmethod
     async def _find_best_experts(
@@ -386,9 +382,7 @@ class MoEBeamSearcher:
             return_future,
         )
 
-        if return_future:
-            return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
-        return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)
+        return RemoteExpertWorker.batch_spawn_experts(result, self.dht, return_future)
 
     @classmethod
     async def _batch_find_best_experts(
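
From the caller's side, both lookup methods now funnel through a single spawning helper; a hedged sketch of the sync and future-based paths (assuming these hunks sit inside find_best_experts / batch_find_best_experts of a constructed MoEBeamSearcher; grid_scores is a placeholder):

    # synchronous: a list of RemoteExpert (or None) is returned directly
    experts = beam_searcher.find_best_experts(grid_scores, beam_size=4)

    # with return_future=True, the same call now returns a future resolving to that list,
    # replacing the removed spawn_experts_future helper
    experts = beam_searcher.find_best_experts(grid_scores, beam_size=4, return_future=True).result()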

+ 33 - 31
hivemind/moe/client/expert.py

@@ -5,7 +5,7 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
     amap_in_executor,
-    as_aiter,
+    iter_as_aiter,
     nested_compare,
     nested_flatten,
     nested_pack,
@@ -147,50 +147,52 @@ class RemoteExpertWorker:
         return experts
 
     @classmethod
-    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+    def spawn_experts(
+        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+    ) -> Union[List[Optional[RemoteExpert]], Future]:
+        if return_future:
+
+            async def _unpack(infos_future: MPFuture, dht: DHT):
+                p2p = await dht.replicate_p2p()
+                return cls._spawn_experts(await infos_future, p2p)
+
+            return cls.run_coroutine(_unpack(infos, dht), return_future)
+
         p2p = cls.run_coroutine(dht.replicate_p2p())
         return cls._spawn_experts(infos, p2p)
 
     @classmethod
-    def spawn_experts_future(
-        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> Future[List[Optional[RemoteExpert]]]:
-        async def _unpack():
-            p2p = cls.run_coroutine(dht.replicate_p2p(), True)
-            return cls.spawn_experts(await infos, await p2p)
-
-        return cls.run_coroutine(_unpack, True)
+    def batch_spawn_experts(
+        cls,
+        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+        dht: DHT,
+        return_future: bool = False,
+    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+        if return_future:
 
-    @classmethod
-    def spawn_experts_bulk(
-        cls, infos: Sequence[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> List[List[Optional[RemoteExpert]]]:
-        return [cls.spawn_experts(exps, dht) for exps in infos]
+            async def _unpack(infos_future: MPFuture, dht: DHT):
+                p2p = await dht.replicate_p2p()
+                return [cls._spawn_experts(i, p2p) for i in await infos_future]
 
-    @classmethod
-    def spawn_experts_bulk_future(
-        cls, infos: Future[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
-    ) -> Future[List[List[Optional[RemoteExpert]]]]:
-        async def _unpack():
-            return cls.spawn_experts_bulk(await infos, dht)
+            return cls.run_coroutine(_unpack(infos, dht), return_future)
 
-        return cls.run_coroutine(_unpack, True)
+        return [cls.spawn_experts(exps, dht) for exps in infos]
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
             lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
-            as_aiter(*split),
+            iter_as_aiter(split),
         ),
     )
 
     return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
 
 
-async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
     )
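
The refactor keeps one entry point per input shape instead of four separate methods; a hedged sketch of the resulting call patterns (dht is a running hivemind.DHT, the infos variables are placeholders for lookup results):

    # blocking: infos is a plain sequence of RemoteExpertInfo
    experts = RemoteExpertWorker.spawn_experts(infos, dht)

    # non-blocking: infos_future is an MPFuture and a concurrent.futures.Future comes back
    experts = RemoteExpertWorker.spawn_experts(infos_future, dht, return_future=True).result()

    # the batched variant follows the same convention for nested sequences
    expert_grid = RemoteExpertWorker.batch_spawn_experts(batch_infos, dht)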
@@ -210,23 +212,23 @@ async def expert_backward(
         if size >= DEFAULT_MAX_MSG_SIZE:
             return await _backward_stream(uid, serialized_tensors, stub)
     else:
-        return await _backward(uid, serialized_tensors, stub)
+        return await _backward_unary(uid, serialized_tensors, stub)
 
 
 async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
             lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
-            as_aiter(*split),
+            iter_as_aiter(split),
         ),
     )
 
     return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)
 
 
-async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
     )
@@ -243,7 +245,7 @@ async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions:
         if size >= DEFAULT_MAX_MSG_SIZE:
             return await _forward_stream(uid, serialized_tensors, stub)
     else:
-        return await _forward(uid, serialized_tensors, stub)
+        return await _forward_unary(uid, serialized_tensors, stub)
 
 
 class _RemoteModuleCall(torch.autograd.Function):
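
Two changes run through the streaming helpers above: the pre-split chunks are now a lazy generator fed through iter_as_aiter rather than a tuple unpacked into as_aiter (so the full split is never materialized), and the small-payload fallbacks are renamed so the unary/stream pairing is explicit. The dispatch implied by the context lines boils down to this sketch (size and stub are placeholders from the surrounding code):

    # payloads at or above DEFAULT_MAX_MSG_SIZE are chunked and streamed;
    # smaller ones go through a single unary RPC
    if size >= DEFAULT_MAX_MSG_SIZE:
        outputs = await _forward_stream(uid, serialized_tensors, stub)
    else:
        outputs = await _forward_unary(uid, serialized_tensors, stub)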

+ 1 - 3
hivemind/moe/server/dht_handler.py

@@ -84,9 +84,7 @@ def get_experts(
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-    if return_future:
-        return RemoteExpertWorker.spawn_experts_future(result, dht)
-    return RemoteExpertWorker.spawn_experts(result, dht)
+    return RemoteExpertWorker.spawn_experts(result, dht, return_future)
 
 
 async def _get_experts(
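
A hedged usage sketch of the simplified helper (assuming a running hivemind.DHT; the uids are placeholders):

    # blocking lookup: one RemoteExpert (or None) per requested uid
    experts = get_experts(dht, ["expert.0", "expert.1"])

    # non-blocking: a single future now covers both the DHT query and expert spawning
    experts = get_experts(dht, ["expert.0", "expert.1"], return_future=True).result()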

+ 3 - 4
hivemind/moe/server/server.py

@@ -27,7 +27,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
@@ -37,10 +37,9 @@ class Server(threading.Thread):
     Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
     After creation, a server should be started: see Server.run or Server.run_in_background.
 
-    A working server does 3 things:
+    A working server does two things:
      - processes incoming forward/backward requests via Runtime (created by the server)
      - publishes updates to expert status every :update_period: seconds
-     - follows orders from HivemindController - if it exists
 
     :type dht: DHT.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
@@ -182,7 +181,7 @@ class Server(threading.Thread):
         optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
-        sample_input = name_to_input[expert_cls](3, hidden_dim)
+        sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
         if isinstance(sample_input, tuple):
             args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
         else:
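
The hard-coded batch size 3 gives way to the shared constant; a minimal sketch of the schema construction it feeds into, assuming torch and a placeholder hidden_dim (the real sample comes from name_to_input[expert_cls]):

    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

    hidden_dim = 1024  # placeholder
    sample_input = torch.empty(DUMMY_BATCH_SIZE, hidden_dim)  # stand-in for name_to_input[expert_cls](...)
    args_schema = (BatchTensorDescriptor.from_tensor(sample_input, CompressionType.NONE),)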

+ 7 - 1
hivemind/utils/asyncio.py

@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -29,6 +29,12 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
+async def iter_as_aiter(iterable: Iterable[T]) -> AsyncIterator[T]:
+    """create an asynchronous iterator from single iterable"""
+    for elem in iterable:
+        yield elem
+
+
 async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
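
A hedged, self-contained usage sketch of the two helpers side by side (assuming both are importable from hivemind.utils.asyncio):

    import asyncio
    from hivemind.utils.asyncio import as_aiter, iter_as_aiter

    async def demo():
        # as_aiter takes individual values, so a generator would have to be unpacked first
        async for item in as_aiter(1, 2, 3):
            print(item)
        # iter_as_aiter wraps an existing (possibly lazy) iterable without unpacking it
        async for item in iter_as_aiter(range(3)):
            print(item)

    asyncio.run(demo())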