
rpc_inference works!

justheuristic, 3 years ago
parent commit fee63bd440
3 changed files with 21 additions and 17 deletions
  1. src/client/remote_block.py (+9, -11)
  2. src/server/backend.py (+5, -2)
  3. src/server/handler.py (+7, -4)

+ 9 - 11
src/client/remote_block.py

@@ -12,19 +12,17 @@ from hivemind.proto.runtime_pb2 import ExpertInfo
 from hivemind.dht import DHT
 from hivemind.utils import MPFuture, DHTExpiration
 
-from src import DistributedBloomConfig
-from src.server.backend import MAX_LENGTH
 from src.server.handler import TransformerConnectionHandler
 
 
-class RemoteTransformerBlockSession(RemoteExpert):
+class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a specific remote server for forward/backward or inference"""
 
-    def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
+    def __init__(self, info: ExpertInfo, p2p: P2P):
         super().__init__(info, p2p)
-        self._config = config
-        self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
-        self._active_stream: Optional[RemoteTransformerStream] = None
+        # self._config = config
+        # self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
+        # self._active_stream: Optional[RemoteTransformerStream] = None
 
     @property
     def stub(self) -> StubBase:
@@ -34,7 +32,7 @@ class RemoteTransformerBlockSession(RemoteExpert):
 
 def get_remote_module(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
+) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -48,7 +46,7 @@ def get_remote_module(
 
 def create_remote_module(
     infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
+) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
     if return_future:
 
         async def _unpack(infos_future: MPFuture, dht: DHT):
@@ -61,10 +59,10 @@ def create_remote_module(
 
 
 def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
-    experts: List[Optional[RemoteTransformerBlockSession]] = []
+    experts: List[Optional[RemoteTransformerBlock]] = []
     for info in infos:
         if info is not None:
-            experts.append(RemoteTransformerBlockSession(info, p2p))
+            experts.append(RemoteTransformerBlock(info, p2p))
         else:
             experts.append(None)
     return experts

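A minimal usage sketch for the renamed class (not part of this commit): it fetches a RemoteTransformerBlock through get_remote_module and runs a plain forward pass. The block UID "bloom.block.0", the hidden size 4096, and the standalone DHT are placeholder assumptions for illustration; in practice the DHT must be bootstrapped with the serving swarm's initial_peers.

    import torch
    from hivemind import DHT

    from src.client.remote_block import get_remote_module

    # In practice, pass initial_peers pointing at the swarm that serves the block.
    dht = DHT(start=True)

    # "bloom.block.0" is an illustrative UID; get_remote_module returns None for
    # UIDs that are not currently announced in the DHT.
    (block,) = get_remote_module(dht, ["bloom.block.0"])
    if block is not None:
        hidden_states = torch.randn(1, 8, 4096)   # (batch, seq_len, hidden_size) placeholder
        outputs = block(hidden_states)             # ordinary forward via RemoteExpert
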
+ 5 - 2
src/server/backend.py

@@ -5,6 +5,7 @@ import torch
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
+from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 
 MAX_LENGTH = 2048
@@ -15,6 +16,7 @@ class TransformerBackend(ModuleBackend):
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
         super().__init__(*args, **kwargs)
+        assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
 
         for name, param in self.module.named_parameters():
@@ -24,13 +26,14 @@ class TransformerBackend(ModuleBackend):
 
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
-    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: torch.IntTensor) -> Tuple[torch.Tensor, ...]:
+    def inference_step(self, attention_cache_handle: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 
         attention_cache_handle = int(attention_cache_handle.item())
         print('HANDLE:', attention_cache_handle)
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            print(inputs[0].shape, cache.shape)
             cache[...] += 1
-            return inputs[0] + cache
+            return (inputs[0] + cache,)
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool

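The backend change moves the cache handle to the front of the argument list and makes inference_step return a tuple, so the task pool gets one entry per output tensor. A self-contained toy sketch of that call convention (the dict below is a stand-in for MemoryCache, not the real implementation; dimensions are placeholders):

    import torch

    MAX_LENGTH, HIDDEN_SIZE = 2048, 4096               # placeholder dimensions

    # Toy stand-in for MemoryCache: integer handles map to preallocated tensors.
    fake_cache = {42: torch.zeros(1, MAX_LENGTH, HIDDEN_SIZE)}

    def inference_step(attention_cache_handle: torch.IntTensor, *inputs: torch.Tensor):
        handle = int(attention_cache_handle.item())    # handle arrives as a 1-element tensor
        cache = fake_cache[handle]
        cache[...] += 1                                 # same placeholder update as in the commit
        return (inputs[0] + cache,)                     # tuple: one entry per output tensor

    outputs = inference_step(torch.tensor([42]), torch.randn(1, MAX_LENGTH, HIDDEN_SIZE))
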
+ 7 - 4
src/server/handler.py

@@ -1,12 +1,12 @@
 from typing import AsyncIterator, Dict
 
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, nested_flatten
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
 
-from src.server.backend import TransformerBackend
+from src.server.backend import TransformerBackend, MAX_LENGTH
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -26,8 +26,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert isinstance(backend, TransformerBackend)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)) as handle:
-            inputs.append(torch.tensor([handle], dtype=torch.int64))
+
+        hidden_size = backend.module.hidden_size
+        cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
+        async with backend.memory_cache.allocate_cache(cache_descriptor) as handle:
+            inputs.insert(0, torch.tensor([handle], dtype=torch.int64))
             outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
 
         yield runtime_pb2.ExpertResponse(tensors=outputs)
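The handler change pairs with the new backend signature: the attention cache is now sized for a full MAX_LENGTH sequence, and the handle tensor is prepended so the request tensors line up with inference_step(attention_cache_handle, *inputs). A sketch of just that bookkeeping, assuming hidden_size = 4096 (the handler reads the real value from backend.module.hidden_size) and a placeholder handle instead of one obtained from memory_cache.allocate_cache:

    import torch
    from hivemind import TensorDescriptor

    MAX_LENGTH = 2048
    hidden_size = 4096   # placeholder; taken from backend.module.hidden_size in the handler

    # One cache slot per request, covering the maximum sequence length.
    cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)

    # Prepend the handle so it becomes the first argument of inference_step.
    handle = 0   # placeholder; the real handle comes from memory_cache.allocate_cache(...)
    inputs = [torch.randn(1, 8, hidden_size)]
    inputs.insert(0, torch.tensor([handle], dtype=torch.int64))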