
swap to int64 (rationale: pytorch does not support uint64)

justheuristic committed 3 years ago · commit 8092bd31ff
5 changed files with 29 additions and 11 deletions
  1. src/client/inference_chain.py (+1 -2)
  2. src/client/remote_block.py (+18 -5)
  3. src/server/backend.py (+6 -1)
  4. src/server/cache.py (+2 -2)
  5. src/server/handler.py (+2 -1)
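
A minimal sketch of the constraint named in the commit title, assuming a PyTorch release from around the time of this commit (no usable unsigned 64-bit tensor dtype): cache handles must fit into a signed 64-bit tensor so they can travel alongside the other request tensors.

    import torch

    # int64 is the widest integer dtype that round-trips cleanly through torch
    # tensors here, hence the switch of the cache handle counters away from uint64.
    handle = 123
    handle_tensor = torch.tensor([handle], dtype=torch.int64)
    assert int(handle_tensor.item()) == handle

    # The signed range is still enormous for a monotonically increasing handle
    # counter: 2**63 - 1 values before overflow.
    print(torch.iinfo(torch.int64).max)  # 9223372036854775807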

+ 1 - 2
src/client/inference_chain.py

@@ -6,8 +6,7 @@ from hivemind import DHT
 from torch import nn
 
 from src import DistributedBloomConfig
-
-MAX_LENGTH = 128  #TODO un-hardcode
+from src.server.backend import MAX_LENGTH
 
 
 class RemoteInferenceChain(nn.Module):
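
With the hard-coded value gone, client and server read the same constant; a tiny hedged check, assuming the repository layout shown in this diff:

    # Both sides now share the single definition from src/server/backend.py,
    # which this commit sets to 2048.
    from src.server.backend import MAX_LENGTH

    print(MAX_LENGTH)  # 2048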

+ 18 - 5
src/client/remote_block.py

@@ -2,6 +2,7 @@ from concurrent.futures import Future
 from functools import partial
 from typing import List, Optional, Union, Sequence
 
+import torch
 from hivemind.moe.client import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertUID
@@ -10,18 +11,30 @@ from hivemind.p2p import StubBase, P2P
 from hivemind.proto.runtime_pb2 import ExpertInfo
 from hivemind.dht import DHT
 from hivemind.utils import MPFuture, DHTExpiration
+
+from src import DistributedBloomConfig
+from src.server.backend import MAX_LENGTH
 from src.server.handler import TransformerConnectionHandler
 
 
-class RemoteTransformerBlock(RemoteExpert):
+class RemoteTransformerBlockSession(RemoteExpert):
+    """A class that interacts with a specific remote server for forward/backward or inference"""
+
+    def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
+        super().__init__(info, p2p)
+        self._config = config
+        self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
+        self._active_stream: Optional[RemoteTransformerStream] = None
+
     @property
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
 
+
 def get_remote_module(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
+) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -35,7 +48,7 @@ def get_remote_module(
 
 def create_remote_module(
     infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
+) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
     if return_future:
 
         async def _unpack(infos_future: MPFuture, dht: DHT):
@@ -48,10 +61,10 @@ def create_remote_module(
 
 
 def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
-    experts: List[Optional[RemoteTransformerBlock]] = []
+    experts: List[Optional[RemoteTransformerBlockSession]] = []
     for info in infos:
         if info is not None:
-            experts.append(RemoteTransformerBlock(info, p2p))
+            experts.append(RemoteTransformerBlockSession(info, p2p))
         else:
             experts.append(None)
     return experts
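
A toy version of the buffer the new session class pre-allocates, with made-up stand-ins for config.hidden_size and config.dtype: the cache is sized once for MAX_LENGTH tokens and written into incrementally rather than re-allocated per inference step.

    import torch

    MAX_LENGTH = 2048       # mirrors src/server/backend.py
    HIDDEN_SIZE = 64        # stand-in for config.hidden_size
    DTYPE = torch.float32   # stand-in for config.dtype

    # Allocate the inputs cache once, as RemoteTransformerBlockSession.__init__ does,
    # then write each new token's hidden states into the next free slot.
    inputs_cache = torch.empty(1, MAX_LENGTH, HIDDEN_SIZE, dtype=DTYPE)

    position = 0
    new_token_hidden = torch.randn(1, 1, HIDDEN_SIZE, dtype=DTYPE)
    inputs_cache[:, position : position + 1] = new_token_hidden
    position += 1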

+ 6 - 1
src/server/backend.py

@@ -7,6 +7,8 @@ from hivemind.moe.server.task_pool import TaskPool
 
 from src.server.cache import MemoryCache
 
+MAX_LENGTH = 2048
+
 
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
@@ -22,7 +24,10 @@ class TransformerBackend(ModuleBackend):
 
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
-    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: torch.IntTensor) -> Tuple[torch.Tensor, ...]:
+
+        attention_cache_handle = int(attention_cache_handle.item())
+        print('HANDLE:', attention_cache_handle)
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
             cache[...] += 1
             return inputs[0] + cache
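
The server end of the handle hand-off, sketched against a plain dict instead of the real MemoryCache: the one-element int64 tensor has to be converted back into a Python int before it can be used as a cache key.

    import torch

    fake_cache = {0: torch.zeros(1, 2, 3)}   # stand-in for MemoryCache.use_cache

    attention_cache_handle = torch.tensor([0], dtype=torch.int64)  # as appended by the handler
    handle = int(attention_cache_handle.item())  # same conversion as in inference_step

    fake_cache[handle] += 1  # mirrors `cache[...] += 1` above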

+ 2 - 2
src/server/cache.py

@@ -35,8 +35,8 @@ class MemoryCache:
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
-        self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
-        self._handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)
+        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
         self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
         self.runtime_pid = os.getpid()
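
The switched counters in isolation, as a minimal sketch using the same ctypes/multiprocessing pattern as the diff:

    import ctypes
    import multiprocessing as mp

    # Lock-free shared counters now use signed 64-bit storage so their values
    # can later be wrapped in torch.int64 tensors without conversion issues.
    current_size = mp.Value(ctypes.c_int64, 0, lock=False)
    handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)

    handle_counter.value += 1
    assert handle_counter.value == 1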

+ 2 - 1
src/server/handler.py

@@ -26,7 +26,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert isinstance(backend, TransformerBackend)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)):
+        async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)) as handle:
+            inputs.append(torch.tensor([handle], dtype=torch.int64))
             outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
 
         yield runtime_pb2.ExpertResponse(tensors=outputs)
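
On the handler side, the freshly allocated handle simply rides along with the regular input tensors; a sketch with a made-up handle value:

    import torch

    inputs = [torch.randn(1, 2, 3)]  # stand-in for the deserialized request tensors
    handle = 0                       # stand-in for the value yielded by allocate_cache

    # Append the handle as a one-element int64 tensor; the backend converts it
    # back with .item() inside inference_step (see src/server/backend.py above).
    inputs.append(torch.tensor([handle], dtype=torch.int64))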