@@ -1,13 +1,9 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-<<<<<<< HEAD
-from typing import Sequence, Tuple, Dict, Any, Optional
-=======
-from typing import Sequence, Tuple, Dict, Any
->>>>>>> 79a9ff2b2ea0c2601e3670f9a28e84e8a511247d
+from typing import Any, Dict, Optional, Sequence, Tuple

 import torch
-from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
@@ -18,7 +14,6 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

-<<<<<<< HEAD

 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):
@@ -42,9 +37,6 @@ class InferenceTaskPool(TaskPool):
                     yield [task]
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
-=======
-MAX_LENGTH = 2048
->>>>>>> 79a9ff2b2ea0c2601e3670f9a28e84e8a511247d


 class InferenceTaskPool(TaskPool):
@@ -100,7 +92,11 @@ class TransformerBackend(ModuleBackend):
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states, hypo_ids, prompts = inputs  # todo: in future, it would be best to support attention mask here
+            (
+                hidden_states,
+                hypo_ids,
+                prompts,
+            ) = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
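
Note on the last hunk: inference_step unpacks a per-request attention-cache handle and a prefix length from the first row of cache_metadata. A minimal sketch of that packing convention, inferred from the two .item() reads above; the concrete values and the int64 dtype are illustrative assumptions, not part of the patch:

    import torch

    # Column 0: handle into the server-side MemoryCache; column 1: number of
    # tokens already processed for this request (the prefix length).
    attention_cache_handle = 42  # hypothetical handle issued by MemoryCache
    prefix_length = 7            # hypothetical count of already-cached tokens
    cache_metadata = torch.tensor(
        [[attention_cache_handle, prefix_length]], dtype=torch.int64
    )

    # Mirrors how inference_step reads the metadata back out:
    assert int(cache_metadata[0, 0].item()) == attention_cache_handle
    assert int(cache_metadata[0, 1].item()) == prefix_length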