@@ -2,14 +2,13 @@
 from typing import Tuple, Sequence
 
 import torch
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
-from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 
 
-class TransformerBlockBackend(ExpertBackend):
+class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -21,12 +20,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
     def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
-        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            return inputs[0] * 2
+        with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            cache[...] += 1
+            return inputs[0] + cache
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
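Note for reviewers: the new `inference_step` body (`cache[...] += 1; return inputs[0] + cache`) is placeholder arithmetic; what it actually exercises is the contract that `MemoryCache.use_cache(handle)` is a context manager yielding a mutable tensor whose in-place updates persist across calls. A minimal sketch of that contract is below, assuming a toy cache class; `ToyMemoryCache` and its `allocate` helper are hypothetical stand-ins, not the API of `src.server.cache.MemoryCache`.

```python
# Toy stand-in for the cache contract assumed by inference_step:
# use_cache(handle) yields a mutable tensor; in-place edits persist
# between calls. Allocation, eviction, and cross-process synchronization
# of the real MemoryCache are intentionally omitted.
from contextlib import contextmanager
from typing import Dict, Iterator, Tuple

import torch


class ToyMemoryCache:
    def __init__(self):
        self._tensors: Dict[int, torch.Tensor] = {}
        self._next_handle = 0

    def allocate(self, shape: Tuple[int, ...]) -> int:
        # Hypothetical helper: hand out an integer handle to a zero-filled buffer.
        handle = self._next_handle
        self._next_handle += 1
        self._tensors[handle] = torch.zeros(shape)
        return handle

    @contextmanager
    def use_cache(self, handle: int) -> Iterator[torch.Tensor]:
        # Yield the stored tensor; callers may modify it in place.
        yield self._tensors[handle]


if __name__ == "__main__":
    cache = ToyMemoryCache()
    handle = cache.allocate((1, 4))
    hidden_states = torch.ones(1, 4)

    # Mimic the placeholder arithmetic from inference_step: each call bumps
    # the cached tensor in place, so later calls see the accumulated state.
    for step in range(3):
        with cache.use_cache(handle) as buf:
            buf[...] += 1
            print(step, (hidden_states + buf).tolist())
```

Running the loop prints values that grow with each step, i.e. state carried across calls through the cache handle, which is the behaviour the attention cache will need once the placeholder is replaced by real incremental (token-by-token) forward logic.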