Artem Chumachenko, 3 years ago
parent commit d351431e95

+ 4 - 4
src/client/remote_generation.py

@@ -1,10 +1,10 @@
+from typing import List, Optional
+
 import torch
 import torch.nn.functional as F
 
-from typing import List, Optional
-
-from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, TopKAlgorithm, NucleusAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint, MaxNewTokensConstraint, EosConstraint
+from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
 
 
 class RemoteGenerationMixin:
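
The import reordering here (and in the files below) follows the usual standard-library → third-party → first-party grouping, alphabetized within each group, as a formatter such as isort would produce (whether isort was actually run is an assumption, not stated in the commit). The resulting header of remote_generation.py, for illustration:

    # 1) standard library
    from typing import List, Optional

    # 2) third-party packages
    import torch
    import torch.nn.functional as F

    # 3) first-party (project) modules
    from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
    from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint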

+ 1 - 3
src/client/remote_model.py

@@ -1,9 +1,7 @@
 # this code is in active development, interfaces may change
 import os
-import torch
 from typing import List, Optional, Tuple, Union
 
-import torch
 import hivemind
 import torch
 import torch.nn as nn
@@ -17,8 +15,8 @@ from src.bloom.model import (
     BloomPreTrainedModel,
     LMHead,
 )
-from src.client.remote_sequential import RemoteSequential
 from src.client.remote_generation import RemoteGenerationMixin
+from src.client.remote_sequential import RemoteSequential
 from src.utils.generation_algorithms import DecodingAlgorithm
 from src.utils.generation_constraints import ABCBloomConstraint
 

+ 1 - 1
src/server/backend.py

@@ -23,7 +23,7 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=4096, name=f"{self.name}_inference")
+        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
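
This is the one behavioral change in the commit: the inference TaskPool now admits a single request per batch. The likely rationale (an assumption, it is not stated in the commit) is that inference requests carry per-session attention caches addressed via cache_metadata, so merging requests from different sessions into one padded batch would be incorrect. The new line, restated with comments:

    self.inference_pool = TaskPool(
        self.inference_step,               # each request runs under torch.inference_mode()
        max_batch_size=1,                  # was 4096; one stateful inference request per step (assumed rationale)
        name=f"{self.name}_inference",
    )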

+ 1 - 1
src/server/cache.py

@@ -25,7 +25,7 @@ class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
-        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
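
The cache.py hunk is formatting only: black drops the spaces around ** in simple power expressions, so the default limit stays 2**64 - 1 bytes (effectively unlimited) when max_size_bytes is None. A minimal usage sketch of the constructor shown above, assuming the rest of __init__ needs no extra setup:

    # no explicit limit given, so the cache falls back to the 2**64 - 1 default
    cache = MemoryCache(device="cuda", max_size_bytes=None)
    assert cache.max_size_bytes == 2**64 - 1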

+ 2 - 2
src/utils/generation_algorithms.py

@@ -1,8 +1,8 @@
-import torch
-
 from abc import ABC
 from typing import Tuple
 
+import torch
+
 TokenIds = torch.Tensor
 HypoIds = torch.Tensor
 
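Same import reordering as above; the file also defines the TokenIds and HypoIds aliases, which are plain torch.Tensor names used for readability. A hypothetical illustration of how they might annotate a greedy decoding step (the actual DecodingAlgorithm interface is not shown in this diff, so the signature below is an assumption):

    from typing import Tuple

    import torch

    TokenIds = torch.Tensor  # ids of the tokens chosen at this step
    HypoIds = torch.Tensor   # ids of the hypotheses those tokens extend

    def greedy_step(logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        # pick the highest-scoring next token for each hypothesis in the batch
        token_ids = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        hypo_ids = torch.arange(token_ids.shape[0])
        return token_ids, hypo_ids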

+ 2 - 2
src/utils/generation_constraints.py

@@ -1,7 +1,7 @@
-import torch
-
 from abc import ABC
 
+import torch
+
 
 class ABCBloomConstraint(ABC):
     """