
actually merge

Artem Chumachenko 3 years ago
parent
commit
85c2ec7b06
2 changed files with 8 additions and 12 deletions
  1. +1 -1  src/client/inference_session.py
  2. +7 -11  src/server/backend.py

+ 1 - 1
src/client/inference_session.py

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
                 break  # this message means "done sending"
 
     def step(self, new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        """Inference step: send a chunk of input tesors and receive a chunk of outputs"""
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue

+ 7 - 11
src/server/backend.py

@@ -1,13 +1,9 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-<<<<<<< HEAD
-from typing import Sequence, Tuple, Dict, Any, Optional
-=======
-from typing import Sequence, Tuple, Dict, Any
->>>>>>> 79a9ff2b2ea0c2601e3670f9a28e84e8a511247d
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
@@ -18,7 +14,6 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-<<<<<<< HEAD
 
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):
@@ -42,9 +37,6 @@ class InferenceTaskPool(TaskPool):
                     yield [task]
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
-=======
-MAX_LENGTH = 2048
->>>>>>> 79a9ff2b2ea0c2601e3670f9a28e84e8a511247d
 
 
 class InferenceTaskPool(TaskPool):
@@ -100,7 +92,11 @@ class TransformerBackend(ModuleBackend):
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states, hypo_ids, prompts = inputs  # todo: in future, it would be best to support attention mask here
+            (
+                hidden_states,
+                hypo_ids,
+                prompts,
+            ) = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"