@@ -1,7 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
-import contextlib
 import threading
-from typing import AsyncIterator, Tuple, List, Dict, Optional
+from typing import AsyncIterator, Tuple, Optional
 
 import torch
 from hivemind import P2PContext, DHT
@@ -13,6 +12,8 @@ from hivemind.moe.server.server import Server
 from hivemind.proto import runtime_pb2
 from torch import nn
 
+from src.node.cache import AttentionCache
+
 
 class BloomServer(Server):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
@@ -62,36 +63,7 @@ class _BloomBlockBackend(ExpertBackend):
         with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
 
-
-class AttentionCache:
-    lock: mp.Lock
-    data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
-    @contextlib.asynccontextmanager
-    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> int:
-        """
-        Allocate buffers for attention cache on the compute device, return a unique handle;
-        This function should be called by connection handler processes, may be called concurrently
-        """
-        try:
-            async with acquire_asynchronously(self.lock):
-                handle: int = generate_unique_handle()  # or just use counter mpvalue and increment each time
-                assert handle not in data
-                self.data[handle] = todo_allocate(self, size, dtype)
-            yield handle
-        finally:
-            todo_deallocate(self, handle)
-            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
-            self.data.pop(handle, None);
-
-    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
-        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
-        with self.lock:
-            yield self.data[handle]
-
-
-
 # later:
-# - if possible, do not change how DHTHandler handles for now
 # - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
 # - contiguous attention cache with max size
 # - select a subset of experts
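
Note on the new import: src/node/cache.py is not part of this diff, so the AttentionCache it provides is not shown. Below is a minimal, hedged sketch of how the removed pseudocode could be made runnable, not the actual module. It stays single-process: a threading.Lock and an itertools.count() handle counter stand in for the pseudocode's mp.Lock / mp.Value and generate_unique_handle(), plain torch.zeros buffers stand in for todo_allocate/todo_deallocate, and the _demo coroutine at the end is hypothetical usage, not API from the diff.

# sketch of a single-process AttentionCache (assumptions noted above)
import asyncio
import contextlib
import itertools
import threading
from typing import AsyncIterator, Dict, Iterator, Tuple

import torch


class AttentionCache:
    """Holds past attention keys/values per inference session, addressed by integer handles."""

    def __init__(self):
        # the pseudocode used mp.Lock / mp.Value; a thread lock suffices while everything
        # lives in one process (a sync lock in async code briefly blocks the event loop)
        self.lock = threading.Lock()
        self._handle_counter = itertools.count()  # "just use counter ... and increment each time"
        self.data: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}

    @contextlib.asynccontextmanager
    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> AsyncIterator[int]:
        """Allocate key/value buffers and yield a unique handle; free the buffers on exit.

        Called by connection handlers, possibly concurrently.
        """
        handle = None
        try:
            with self.lock:
                handle = next(self._handle_counter)
                assert handle not in self.data
                current_length = torch.zeros(1, dtype=torch.int64)  # stand-in for mp.Value
                self.data[handle] = (
                    current_length,
                    torch.zeros(size, dtype=dtype),  # cached keys
                    torch.zeros(size, dtype=dtype),  # cached values
                )
            yield handle
        finally:
            with self.lock:
                # dropping the dict entry releases the buffers; no data is moved here
                self.data.pop(handle, None)

    @contextlib.contextmanager
    def use_cache(self, handle: int) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Return (current_length, cached_keys, cached_values); called by the runtime process."""
        with self.lock:
            yield self.data[handle]


async def _demo():
    # hypothetical usage mirroring _BloomBlockBackend: allocate in a handler, use in the runtime
    cache = AttentionCache()
    async with cache.allocate_cache(torch.Size([1, 8, 64]), torch.float32) as handle:
        with cache.use_cache(handle) as (current_length, cached_keys, cached_values):
            cached_keys[0, 0].fill_(1.0)  # a real backend would write one step of k/v here
            current_length += 1
    assert handle not in cache.data  # buffers are released once the session ends


if __name__ == "__main__":
    asyncio.run(_demo())

Keeping allocate_cache asynchronous lets many connection handlers hold sessions open concurrently, while use_cache stays synchronous for the single runtime process, matching the split in the removed pseudocode. A true multi-process version would additionally need shared-memory tensors and an mp.Lock, which this sketch deliberately sidesteps.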