
expel all bloom-specific files to src.bloom

justheuristic 3 years ago
parent
commit
e5e8c9ed12

+ 1 - 0
src/__init__.py

@@ -0,0 +1 @@
+from .bloom import *

+ 1 - 0
src/bloom/__init__.py

@@ -0,0 +1 @@
+from src.bloom.model import BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig

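The two one-line __init__ files above re-export the bloom package from the repository root, so code that imports from src keeps working after the move. A hypothetical check (not part of the commit):

from src import BloomForCausalLM
from src.bloom.model import BloomForCausalLM as SameClass

assert BloomForCausalLM is SameClass  # the root package simply re-exports src.bloom
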
+ 1 - 1
src/block.py → src/bloom/block.py

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.ops import (
+from src.bloom.ops import (
     BloomGelu,
     BloomScaledSoftmax,
     attention_mask_func,

+ 2 - 2
src/model.py → src/bloom/model.py

@@ -20,8 +20,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
 from transformers.utils import logging
 
-from src.block import BloomBlock
-from src.ops import build_alibi_tensor
+from src.bloom.block import BloomBlock
+from src.bloom.ops import build_alibi_tensor
 
 logger = logging.get_logger(__name__)
 

+ 0 - 0
src/ops.py → src/bloom/ops.py


+ 0 - 0
src/node/__init__.py


+ 3 - 31
src/backend.py → src/node/backend.py

@@ -1,7 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
-import contextlib
 import threading
-from typing import AsyncIterator, Tuple, List, Dict, Optional
+from typing import AsyncIterator, Tuple, Optional
 
 import torch
 from hivemind import P2PContext, DHT
@@ -13,6 +12,8 @@ from hivemind.moe.server.server import Server
 from hivemind.proto import runtime_pb2
 from torch import nn
 
+from src.node.cache import AttentionCache
+
 
 class BloomServer(Server):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
@@ -62,36 +63,7 @@ class _BloomBlockBackend(ExpertBackend):
         with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
 
-
-class AttentionCache:
-    lock: mp.Lock
-    data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
-    @contextlib.asynccontextmanager
-    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> int:
-        """
-        Allocate buffers for attention cache on the compute device, return a unique handle;
-        This function should be called by connection handler processes, may be called concurrently
-        """
-        try:
-            async with acquire_asynchronously(self.lock):
-                handle: int = generate_unique_handle() # or just use  counter mpvalue and increment each time
-                assert handle not in data
-                self.data[handle] = todo_allocate(self, size, dtype)
-            yield handle
-        finally:
-            todo_deallocate(self, handle)
-            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
-            self.data.pop(handle, None);
-
-    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
-        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
-        with self.lock:
-            yield self.data[handle]
-
-
-
 # later:
-# - if possible, do not change how DHTHandler handles for now
 # - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
 # - contiguous attention cache with max size
 # - select a subset of experts

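The backend above stops at a NotImplementedError inside the use_cache block. As orientation only, here is a sketch of how a runtime-side backend could consume the handle-based cache; inference_step, the layer_past keyword, the (output, new_keys, new_values) return shape, and the [batch, max_length, ...] cache layout are assumptions for illustration, not the real BloomBlock interface:

import torch

class BloomBlockBackendSketch:
    """Illustrative only: consuming a previously allocated attention-cache handle inside the runtime process."""

    def __init__(self, block: torch.nn.Module, attention_cache):
        self.block, self.attention_cache = block, attention_cache

    def inference_step(self, hidden_states: torch.Tensor, attention_cache_handle: int) -> torch.Tensor:
        # the handle was allocated earlier by a connection handler; use_cache yields
        # (current_length: mp.Value, cached_keys, cached_values) as in the diff above
        with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
            past = (cached_keys[:, : current_length.value], cached_values[:, : current_length.value])
            # hypothetical signature: the real BloomBlock also needs alibi and attention masks
            output, new_keys, new_values = self.block(hidden_states, layer_past=past)
            new_length = current_length.value + hidden_states.shape[1]
            cached_keys[:, current_length.value : new_length] = new_keys
            cached_values[:, current_length.value : new_length] = new_values
            current_length.value = new_length
        return output
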
+ 44 - 0
src/node/cache.py

@@ -0,0 +1,44 @@
+import contextlib
+import ctypes
+import multiprocessing as mp
+import os
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+class MemoryCache:
+    lock: mp.Lock
+    runtime_pid: int
+    handle_counter: mp.Value[ctypes.c_uint64]
+    current_size: mp.Value[ctypes.c_uint64]
+    _runtime_data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
+
+    @contextlib.asynccontextmanager
+    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> Optional[int]:
+        """
+        Allocate buffers for attention cache on the compute device, return a unique handle;
+        This function should be called by connection handler processes, may be called concurrently
+        """
+        assert os.getpid() != self.runtime_pid
+        try:
+            async with acquire_asynchronously(self.lock):
+                check_and_update_size(current_size, size, dtype)
+                if enough_space:
+                    self.handle_counter.value += 1
+                    handle = int(self.handle_counter.value)
+                    # note: you cannot allocate data here because this is not the runtime process
+                    TODO_SOMEHOW_COMMUNICATE_WITH_RUNTIME_TO_CREATE_THE_RIGHT_DATA
+            yield handle
+        finally:
+            todo_deallocate(self, handle)
+            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
+            self._runtime_data.pop(handle, None)
+
+    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
+        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
+        assert os.getpid() == self.runtime_pid
+        with self.lock:
+            if first_time:
+                allocate_stuff(self._runtime_data)
+            yield self._runtime_data[handle]
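
The file above is deliberately rough: placeholder calls and a TODO for the handler-to-runtime handshake. For orientation, here is a self-contained sketch of the same handle-based design under simplifying assumptions: handler processes inherit the cache by fork, allocate_cache is synchronous instead of async, allocation requests travel through a multiprocessing.Manager dict, and key/value buffers are created lazily on the first use_cache call in the runtime. SimpleMemoryCache and its attributes are illustrative names, not the repository's API:

import contextlib
import ctypes
import multiprocessing as mp
import os
from typing import Dict

import torch


class SimpleMemoryCache:
    """Illustrative single-machine variant of the handle-based cache sketched above (CPU only, no size limit)."""

    def __init__(self):
        self.lock = mp.Lock()
        self.runtime_pid = os.getpid()               # assumes the cache is constructed inside the runtime process
        self.handle_counter = mp.Value(ctypes.c_uint64, 0)
        self._manager = mp.Manager()
        self._descriptors = self._manager.dict()     # handle -> (shape, dtype), readable from handler processes
        self._runtime_data: Dict[int, tuple] = {}    # handle -> (current_length, keys, values), runtime-only

    @contextlib.contextmanager
    def allocate_cache(self, size: torch.Size, dtype: torch.dtype):
        """Reserve a handle from a connection handler; tensors are created lazily inside the runtime."""
        with self.lock:
            self.handle_counter.value += 1
            handle = int(self.handle_counter.value)
            self._descriptors[handle] = (tuple(size), dtype)
        try:
            yield handle
        finally:
            # only the descriptor is removed; no tensor data is moved here
            # (this toy version never frees runtime buffers)
            self._descriptors.pop(handle, None)

    @contextlib.contextmanager
    def use_cache(self, handle: int):
        """Yield (current_length, cached_keys, cached_values) for a handle; runtime process only."""
        assert os.getpid() == self.runtime_pid
        with self.lock:
            if handle not in self._runtime_data:     # first use: allocate buffers of the requested shape
                shape, dtype = self._descriptors[handle]
                keys, values = torch.zeros(shape, dtype=dtype), torch.zeros(shape, dtype=dtype)
                self._runtime_data[handle] = (mp.Value(ctypes.c_uint64, 0), keys, values)
            yield self._runtime_data[handle]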