
RemoteTransformerBlock

justheuristic 3 years ago
parent
commit
3e9fd63a02

+ 1 - 0
src/client/__init__.py

@@ -0,0 +1 @@
+from hivemind.moe.client import RemoteExpert

+ 58 - 0
src/client/remote_block.py

@@ -0,0 +1,58 @@
+from concurrent.futures import Future
+from functools import partial
+from typing import List, Optional, Union, Sequence
+
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.moe.server.dht_handler import _get_experts
+from hivemind.p2p import StubBase, P2P
+from hivemind.proto.runtime_pb2 import ExpertInfo
+from hivemind.dht import DHTExpiration, DHT
+from hivemind.utils import MPFuture
+from src.server.handler import TransformerConnectionHandler
+
+
+class RemoteTransformerBlock(RemoteExpert):
+    @property
+    def stub(self) -> StubBase:
+        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
+
+
+def get_remote_module(
+    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
+    """
+    :param uids: find experts with these ids from across the DHT
+    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock if found else None]
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_module(result, dht, return_future)
+
+
+def create_remote_module(
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
+    if return_future:
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
+    experts: List[Optional[RemoteTransformerBlock]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteTransformerBlock(info, p2p))
+        else:
+            experts.append(None)
+    return experts
+
+
+
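
Not part of the commit: a minimal client-side usage sketch of the new helper, assuming a hivemind DHT swarm is already running and a server announces blocks under uids such as "dummy_block.0" (as src/server/server.py does below). The empty initial_peers is a placeholder.

from hivemind import DHT

from src.client.remote_block import get_remote_module

# placeholder bootstrap: pass the multiaddrs of an existing swarm in practice
dht = DHT(initial_peers=[], start=True)

uids = ["dummy_block.0", "dummy_block.1"]
blocks = get_remote_module(dht, uids)  # one RemoteTransformerBlock (or None if not found) per uid
for uid, block in zip(uids, blocks):
    print(uid, "->", "found" if block is not None else "missing")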

+ 4 - 21
src/server/backend.py

@@ -19,32 +19,15 @@ from src.server.cache import MemoryCache
 
 class BloomBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
-    def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
-        object().__init__()  # to bypass super.__init__
-        self.name, self.module = name, module
+    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+        super().__init__(*args, **kwargs)  # ExpertBackend.__init__ sets up schemas and task pools from kwargs
         self.memory_cache = memory_cache
 
-        for name, param in module.named_parameters():
+        for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
-        for name, buf in module.named_buffers():
+        for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.args_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
-        self.kwargs_schema = {}
-        self.outputs_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
-
-        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
-        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
-
-        self.grad_inputs_schema = self.forward_schema  # outputs from backward have same shape as inputs for forward
-        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
-        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
-
-    @property
-    def expert(self):
-        #TODO un-hardcode this naming from hivemind
-        return self.module
-
     def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
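
Not part of the commit: since BloomBlockBackend now defers schema and task-pool setup to ExpertBackend.__init__, the schemas must be passed explicitly at construction. A sketch mirroring the server.py change below; the DistributedBloomConfig() defaults and the MemoryCache() constructor arguments are assumptions.

from hivemind import BatchTensorDescriptor

from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.backend import BloomBlockBackend
from src.server.cache import MemoryCache

block_config = DistributedBloomConfig()  # assumption: real configs come from a pretrained checkpoint
block = BloomBlock(block_config, layer_number=0)
for param in block.parameters():
    param.requires_grad = False  # the backend asserts that no parameter accumulates gradients

memory_cache = MemoryCache()  # assumption: see src/server/cache.py for the actual constructor
backend = BloomBlockBackend(
    "dummy_block.0",
    block,
    memory_cache=memory_cache,
    args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size),),
    kwargs_schema={},
    outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size),),
)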

+ 1 - 1
src/server/handler.py

@@ -7,7 +7,7 @@ from hivemind.proto import runtime_pb2
 from src.bloom.block import BloomBlock
 
 
-class BloomConnectionHandler(ConnectionHandler):
+class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
     def __init__(self, *args, **kwargs):
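
Not part of the commit: on the client side, the renamed handler is what RemoteTransformerBlock.stub resolves to, so a block fetched via get_remote_module already talks to this handler. A minimal sketch; the empty initial_peers is a placeholder and the RPC method names are not shown in this diff.

from hivemind import DHT

from src.client.remote_block import get_remote_module

dht = DHT(initial_peers=[], start=True)  # placeholder: join a real swarm in practice
remote_block = get_remote_module(dht, ["dummy_block.0"])[0]
if remote_block is not None:
    stub = remote_block.stub  # a TransformerConnectionHandler stub bound to the serving peer
    # forward, backward and forward-incremental requests are issued through this stub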

+ 8 - 7
src/server/server.py

@@ -15,7 +15,7 @@ from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.backend import BloomBlockBackend
-from src.server.handler import BloomConnectionHandler
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -31,7 +31,7 @@ class Server(threading.Thread):
     ):
         threading.Thread.__init__(self)
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
@@ -105,15 +105,17 @@ class Server(threading.Thread):
         blocks = {}
         for i in range(num_blocks):
             module_uid = f"dummy_block.{i}"
-            HARDCODCED_LENGTH = 2048
+            block = BloomBlock(block_config, layer_number=i)
+            for param in block.parameters():
+                param.requires_grad = False
 
             blocks[module_uid] = BloomBlockBackend(
                 module_uid,
-                BloomBlock(block_config, layer_number=i),
+                block,
                 memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
+                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
+                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
@@ -121,7 +123,6 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
-            cache_size_bytes=cache_size_bytes,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,