瀏覽代碼

Rename `src` => `petals`

Aleksandr Borzunov 2 年之前
父節點
當前提交
d1c35b4f5f
共有 77 個文件被更改,包括 5021 次插入26 次删除
  1. 1 1
      README.md
  2. 3 3
      cli/convert_model.py
  3. 3 3
      cli/inference_one_block.py
  4. 2 2
      cli/run_server.py
  5. 1 1
      examples/prompt-tuning-personachat.ipynb
  6. 1 1
      examples/prompt-tuning-sst2.ipynb
  7. 5 0
      petals/__init__.py
  8. 2 0
      petals/bloom/__init__.py
  9. 255 0
      petals/bloom/block.py
  10. 86 0
      petals/bloom/from_pretrained.py
  11. 583 0
      petals/bloom/model.py
  12. 0 0
      petals/bloom/ops.py
  13. 5 0
      petals/client/__init__.py
  14. 322 0
      petals/client/inference_session.py
  15. 154 0
      petals/client/remote_forward_backward.py
  16. 328 0
      petals/client/remote_generation.py
  17. 198 0
      petals/client/remote_model.py
  18. 103 0
      petals/client/remote_sequential.py
  19. 167 0
      petals/client/sequence_manager.py
  20. 236 0
      petals/client/sequential_autograd.py
  21. 0 0
      petals/client/spending_policy.py
  22. 0 0
      petals/constants.py
  23. 0 0
      petals/data_structures.py
  24. 180 0
      petals/dht_utils.py
  25. 0 0
      petals/server/__init__.py
  26. 87 0
      petals/server/backend.py
  27. 115 0
      petals/server/block_selection.py
  28. 0 0
      petals/server/cache.py
  29. 470 0
      petals/server/handler.py
  30. 0 0
      petals/server/runtime.py
  31. 499 0
      petals/server/server.py
  32. 0 0
      petals/server/task_pool.py
  33. 0 0
      petals/server/task_prioritizer.py
  34. 127 0
      petals/server/throughput.py
  35. 0 0
      petals/src/__init__.py
  36. 0 0
      petals/src/bloom/__init__.py
  37. 0 0
      petals/src/bloom/block.py
  38. 0 0
      petals/src/bloom/from_pretrained.py
  39. 0 0
      petals/src/bloom/model.py
  40. 246 0
      petals/src/bloom/ops.py
  41. 0 0
      petals/src/client/__init__.py
  42. 0 0
      petals/src/client/inference_session.py
  43. 0 0
      petals/src/client/remote_forward_backward.py
  44. 0 0
      petals/src/client/remote_generation.py
  45. 0 0
      petals/src/client/remote_model.py
  46. 0 0
      petals/src/client/remote_sequential.py
  47. 0 0
      petals/src/client/sequence_manager.py
  48. 0 0
      petals/src/client/sequential_autograd.py
  49. 14 0
      petals/src/client/spending_policy.py
  50. 8 0
      petals/src/constants.py
  51. 41 0
      petals/src/data_structures.py
  52. 0 0
      petals/src/dht_utils.py
  53. 0 0
      petals/src/server/__init__.py
  54. 0 0
      petals/src/server/backend.py
  55. 0 0
      petals/src/server/block_selection.py
  56. 148 0
      petals/src/server/cache.py
  57. 0 0
      petals/src/server/handler.py
  58. 198 0
      petals/src/server/runtime.py
  59. 0 0
      petals/src/server/server.py
  60. 178 0
      petals/src/server/task_pool.py
  61. 20 0
      petals/src/server/task_prioritizer.py
  62. 0 0
      petals/src/server/throughput.py
  63. 0 0
      petals/src/utils/__init__.py
  64. 0 0
      petals/src/utils/convert_8bit.py
  65. 0 0
      petals/src/utils/generation_algorithms.py
  66. 0 0
      petals/src/utils/generation_constraints.py
  67. 0 0
      petals/src/utils/misc.py
  68. 0 0
      petals/utils/__init__.py
  69. 41 0
      petals/utils/convert_8bit.py
  70. 121 0
      petals/utils/generation_algorithms.py
  71. 51 0
      petals/utils/generation_constraints.py
  72. 7 0
      petals/utils/misc.py
  73. 5 5
      tests/test_block_exact_match.py
  74. 3 3
      tests/test_chained_calls.py
  75. 2 2
      tests/test_full_model.py
  76. 2 2
      tests/test_priority_pool.py
  77. 3 3
      tests/test_remote_sequential.py

+ 1 - 1
README.md

@@ -140,7 +140,7 @@ Once your have enough servers, you can use them to train and/or inference the mo
 ```python
 import torch
 import torch.nn.functional as F
-import transformers
+from petals.ansformers
 from src import DistributedBloomForCausalLM
 
 initial_peers = [TODO_put_one_or_more_server_addresses_here]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]

+ 3 - 3
cli/convert_model.py

@@ -9,9 +9,9 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
-from src import BloomModel
-from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
-from src.client import DistributedBloomConfig
+from petals.import BloomModel
+from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from petals.client import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 3 - 3
cli/inference_one_block.py

@@ -4,9 +4,9 @@ import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tqdm.auto import trange
 
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 2 - 2
cli/run_server.py

@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from humanfriendly import parse_size
 
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.server.server import Server
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
examples/prompt-tuning-personachat.ipynb

@@ -74,7 +74,7 @@
     "from transformers import get_scheduler\n",
     "\n",
     "# Import a Petals model\n",
-    "from src.client.remote_model import DistributedBloomForCausalLM"
+    "from petals.client.remote_model import DistributedBloomForCausalLM"
    ]
   },
   {

+ 1 - 1
examples/prompt-tuning-sst2.ipynb

@@ -74,7 +74,7 @@
     "from transformers import get_scheduler\n",
     "\n",
     "# Import a Petals model\n",
-    "from src.client.remote_model import DistributedBloomForSequenceClassification"
+    "from petals.client.remote_model import DistributedBloomForSequenceClassification"
    ]
   },
   {

+ 5 - 0
petals/__init__.py

@@ -0,0 +1,5 @@
+from petals.bloom import *
+from petals.client import *
+from petals.dht_utils import declare_active_modules, get_remote_module
+
+__version__ = "1.0alpha1"

+ 2 - 0
petals/bloom/__init__.py

@@ -0,0 +1,2 @@
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

+ 255 - 0
petals/bloom/block.py

@@ -0,0 +1,255 @@
+"""
+Bloom intermediate layer
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.quantized.dynamic.modules.linear
+
+from petals.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
+
+
+class BloomAttention(nn.Module):
+    def __init__(self, config, layer_number=None):
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.n_head
+        self.head_dim = self.hidden_size // self.num_heads
+        self.split_size = self.hidden_size
+        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+        self.masked_softmax_fusion = config.masked_softmax_fusion
+        self.hidden_dropout = config.hidden_dropout
+
+        if self.head_dim * self.num_heads != self.hidden_size:
+            raise ValueError(
+                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        # Layer-wise attention scaling
+        self.layer_number = max(1, layer_number)
+        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
+
+        # Scaled Softmax
+        self.scale_mask_softmax = BloomScaledSoftmax(
+            self.masked_softmax_fusion,
+            attention_mask_func,
+            self.attention_softmax_in_fp32,
+            self.layer_number,
+        )
+
+        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        layer_past=None,
+        attention_mask=None,
+        alibi=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        if alibi is None:
+            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
+            alibi = build_alibi_tensor(
+                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
+            )
+
+        # hidden_states: [batch_size, seq_length, hidden_size]
+        # apply preprocessing if the input is padded
+        if attention_mask is not None:
+            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
+        # otherwise repeat alibi tensor with the batch size
+        else:
+            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
+
+        mixed_x_layer = self.query_key_value(hidden_states)
+
+        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3  [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
+            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
+        # [batch_size, head_dim, q_length, k_length]
+        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
+
+        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
+        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
+
+        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
+        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
+
+        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
+        beta = 1.0 / self.layer_number
+
+        matmul_result = torch.baddbmm(
+            alibi,
+            query_layer.transpose(1, 0),
+            key_layer.transpose(1, 0).transpose(1, 2),
+            beta=beta,
+            alpha=(1.0 / self.norm_factor),
+        )
+
+        # change view to [batch_size, num_heads, q_length, k_length]
+        attention_scores = matmul_result.view(*output_size)
+
+        # attention scores and attention mask [b, np, sq, sk]
+        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
+        attention_probs = self.attention_dropout(attention_probs)
+
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        # context layer shape: [batch_size, num_heads, q_length, head_dim]
+        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+        # change view [k_length, batch_size x num_heads, head_dim]
+        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
+
+        # change view [batch_size x num_heads, q_length, k_length]
+        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+        # matmul: [batch_size * num_heads, q_length, head_dim]
+        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+
+        # change view [batch_size, num_heads, q_length, head_dim]
+        context_layer = context_layer.view(*output_size)
+
+        # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # Output. [q_length, batch_size, hidden_size]
+
+        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+        output_tensor = self.dense(context_layer)
+        output = output_tensor.transpose(1, 0)
+
+        output = dropout_add(output, residual, self.hidden_dropout, self.training)
+
+        outputs = (output, present)
+        if output_attentions:
+            outputs += (attention_probs,)
+
+        return outputs
+
+
+class BloomMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
+        self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
+        self.hidden_dropout = config.hidden_dropout
+        self.gelu_impl = BloomGelu()
+
+    def forward(self, hidden_states, residual):
+        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+        intermediate_output = self.dense_4h_to_h(hidden_states)
+        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+        return output
+
+
+class BloomBlock(nn.Module):
+    def __init__(self, config, layer_number=None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
+        self.n_head = config.n_head
+        self.self_attention = BloomAttention(config, layer_number=layer_number)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.mlp = BloomMLP(config)
+
+        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+        self.hidden_dropout = config.hidden_dropout
+
+    def forward(
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        alibi=None,
+    ):
+        # hidden_states: [batch_size, seq_length, hidden_size]
+
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+
+        # Layer norm post the self attention.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        # Self attention.
+        attn_outputs = self.self_attention(
+            layernorm_output,
+            residual,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = attn_outputs[0]
+
+        outputs = attn_outputs[1:]
+
+        layernorm_output = self.post_attention_layernorm(attention_output)
+
+        # Get residual
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = attention_output
+
+        # MLP.
+        output = self.mlp(layernorm_output, residual)
+
+        if use_cache:
+            outputs = (output,) + outputs
+        else:
+            outputs = (output,) + outputs[1:]
+
+        return outputs  # hidden_states, present, attentions

+ 86 - 0
petals/bloom/from_pretrained.py

@@ -0,0 +1,86 @@
+"""
+Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+If necessary, one can rewrite this to implement a different behavior, such as:
+ - loading files from a local data source (e.g. S3)
+ - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
+ - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
+
+"""
+from __future__ import annotations
+
+from typing import Optional, OrderedDict, Union
+
+import torch
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from transformers.modeling_utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_path, hf_bucket_url
+
+from petals.bloom import BloomBlock, BloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+CLIENT_BRANCH = "main"
+BLOCK_BRANCH_PREFIX = "block_"
+USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
+FORCE_DOWNLOAD = False
+RESUME_DOWNLOAD = False
+LOCAL_FILES_ONLY = False
+
+
+def load_pretrained_block(
+    converted_model_name_or_path: str,
+    block_index: int,
+    config: Optional[BloomConfig] = None,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+) -> BloomBlock:
+    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
+    if config is None:
+        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+    block = BloomBlock(config, layer_number=block_index)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
+    block.load_state_dict(state_dict)
+
+    if torch_dtype == "auto":
+        with torch.no_grad():
+            for name, param in block.named_parameters():
+                assert name in state_dict, f"{name} not in state dict"
+                param.data = param.data.to(state_dict[name].dtype)
+    else:
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block = block.to(dtype=torch_dtype)
+
+    report = block.load_state_dict(state_dict, strict=True)
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    return block
+
+
+def _load_state_dict(
+    pretrained_model_name_or_path: str,
+    block_index: Optional[int] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+) -> OrderedDict[str, torch.Tensor]:
+    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
+    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
+
+    # Load from URL or cache if already cached
+    resolved_archive_file = cached_path(
+        archive_file,
+        cache_dir=cache_dir,
+        force_download=FORCE_DOWNLOAD,
+        proxies=None,
+        resume_download=RESUME_DOWNLOAD,
+        local_files_only=LOCAL_FILES_ONLY,
+        use_auth_token=use_auth_token,
+        user_agent=USER_AGENT,
+    )
+    state_dict = torch.load(resolved_archive_file, map_location="cpu")
+    return state_dict
+
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 583 - 0
petals/bloom/model.py

@@ -0,0 +1,583 @@
+"""
+PyTorch BLOOM model that implements several memory-efficient modes.
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from hivemind import use_hivemind_log_handler
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
+from transformers.utils import logging
+
+from petals.bloom.block import BloomBlock
+
+use_hivemind_log_handler("in_root_logger")
+logger = logging.get_logger(__file__)
+
+_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
+_CONFIG_FOR_DOC = "BloomConfig"
+_TOKENIZER_FOR_DOC = "BloomTokenizer"
+
+
+BLOOM_START_DOCSTRING = r"""
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BLOOM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+            sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
+            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
+            their past given to this model should not be passed as `input_ids` as they have already been computed.
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+
+            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
+            `past_key_values`).
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
+    BLOOM_START_DOCSTRING,
+)
+class BloomModel(BloomPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
+
+        self.embed_dim = config.hidden_size
+        self.n_head = config.n_head
+
+        # Embedding + LN Embedding
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+
+        # Final Layer Norm
+        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.word_embeddings
+
+    def set_input_embeddings(self, new_embeddings):
+        self.word_embeddings = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if position_ids is not None:
+            logger.warning("position_ids are ignored in this bloom implementation")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_head x N x N
+        # head_mask has shape n_layer x batch x n_head x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        # Compute alibi tensor: check build_alibi_tensor documentation
+        current_sequence_length = hidden_states.shape[1]
+        if past_key_values and past_key_values[0]:
+            current_sequence_length += past_key_values[0][0].shape[1]
+
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions, alibi=None)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=None,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        hidden_states = hidden_states.view(output_shape)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class BloomForCausalLM(BloomPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = BloomModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+        }
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+
+
+@add_start_docstrings(
+    """
+    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class LMHead(nn.Module):
+    def __init__(self, config, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings
+        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+    @property
+    def in_features(self) -> int:
+        return self.word_embeddings.num_embeddings
+
+    @property
+    def out_features(self) -> int:
+        return self.word_embeddings.embedding_dim
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    @property
+    def bias(self):
+        return None
+
+    def forward(self, hidden_states):
+        word_embeddings = self.word_embeddings.weight
+
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings).float()
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
+        chunk_size: provides trade-off between efficiency and extra memory consumption.
+        """
+        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+
+        word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
+
+        hidden_states = hidden_states.float()
+        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunk_size):
+            chunk = word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output
+
+
+@add_start_docstrings(
+    """
+    The Bloom Model transformer with a sequence classification head on top (linear layer).
+    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-1) do.
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class BloomForSequenceClassification(BloomPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = BloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )

+ 0 - 0
src/bloom/ops.py → petals/bloom/ops.py


+ 5 - 0
petals/client/__init__.py

@@ -0,0 +1,5 @@
+from petals.client.inference_session import InferenceSession
+from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 322 - 0
petals/client/inference_session.py

@@ -0,0 +1,322 @@
+from __future__ import annotations
+
+import asyncio
+import itertools
+import time
+from typing import AsyncIterator, List, Optional
+
+import torch
+from hivemind import (
+    P2P,
+    MSGPackSerializer,
+    anext,
+    deserialize_torch_tensor,
+    get_logger,
+    nested_flatten,
+    serialize_torch_tensor,
+)
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout
+
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
+
+logger = get_logger(__file__)
+
+
+class _ServerInferenceSession:
+    """
+    An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
+
+    :note: This class is *not* fault-tolerant out of the box.
+    """
+
+    def __init__(
+        self,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        inputs_queue: asyncio.Queue,
+        outputs_aiter: AsyncIterator,
+        *,
+        timeout: float,
+        max_length: int,
+        points: int = 0,
+    ):
+        self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.timeout = timeout
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
+        self.stepped = False
+        self.closed = False
+
+    @classmethod
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+    ) -> _ServerInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Inference step: send a chunk of input tesors and receive a chunk of outputs
+        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
+
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states, prompts, hypo_ids)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(
+            self._step(
+                runtime_pb2.ExpertRequest(
+                    uid=self.uid,
+                    tensors=[
+                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
+                    ],
+                    metadata=self._serialized_metadata if not self.stepped else None,
+                )
+            )
+        )
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
+        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()
+
+
+class InferenceSession:
+    """
+    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
+    """
+
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._max_length = max_length
+        self._metadata = metadata
+
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
+        server_sessions = []
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
+                )
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise
+
+    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.debug("Caught exception while closing connection to server:", exc_info=True)
+
+    def __enter__(self) -> "InferenceSession":
+        assert not self._closed and not self._chosen_spans
+        return self
+
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+        assert not self._closed
+        if torch.is_grad_enabled():
+            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
+        n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )
+
+        server_idx = 0
+        block_idx = 0
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if attempt_no >= 1:
+                        self._sequence_manager.update_()
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        # If there is a failed server session, this code closes it
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                        n_prev_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                        if attempt_no >= 1 and update_end > recovery_until:
+                            logger.info(
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
+                            )
+                        recovery_until = max(recovery_until, update_end)
+
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+                        updated_sessions = self._enter_server_sessions(updated_spans)
+                        logger.debug(
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
+                        )
+
+                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
+                            f"{len(self._server_inputs)} inputs"
+                        )
+
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]
+
+                    if self._server_inputs[server_idx] is None:
+                        self._server_inputs[server_idx] = inputs
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat(
+                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
+                        )
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
+                        f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
+
+                    if not session.stepped:
+                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
+                    else:
+                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+
+                    inputs = outputs
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self._sequence_manager.get_retry_delay(attempt_no)
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
+
+        self._position += n_input_tokens
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True
+
+    def __exit__(self, *exc_details):
+        self.close(*exc_details)
+
+    def __del__(self):
+        self.close()

+ 154 - 0
petals/client/remote_forward_backward.py

@@ -0,0 +1,154 @@
+"""
+Utility functions that call RPC forward or backward on a single remote server
+"""
+import asyncio
+from typing import Iterable, List, Sequence, Tuple
+
+import torch
+from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
+from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
+from hivemind.p2p import StubBase
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
+from hivemind.utils.streaming import split_for_streaming
+
+from petals.data_structures import ModuleUID, RPCInfo
+
+
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
+    outputs = aiter_with_timeout(outputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
+
+
+async def run_remote_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "rpc_forward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    # detach to avoid pickling the computation graph
+    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
+    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
+
+    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    forward_inputs = (inputs, kwargs)
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    forward_inputs = nested_flatten(forward_inputs)
+    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
+        )
+    )
+
+    # call RPC on remote server
+    size = sum(t.element_size() * t.nelement() for t in inputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
+    else:
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
+
+    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
+
+
+async def run_remote_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    inputs: torch.Tensor,
+    grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
+    timeout: float,
+    **kwargs,
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
+    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
+    else:
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
+
+    return deserialized_grad_inputs

+ 328 - 0
petals/client/remote_generation.py

@@ -0,0 +1,328 @@
+from typing import List, Optional
+
+import torch
+from hivemind.utils.logging import get_logger
+
+from petals.utils.generation_algorithms import (
+    BeamSearchAlgorithm,
+    DecodingAlgorithm,
+    GreedyAlgorithm,
+    NucleusAlgorithm,
+    TopKAlgorithm,
+)
+from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+
+logger = get_logger(__file__)
+
+
+class RemoteGenerationMixin:
+    """
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
+    The class exposes can be used for:
+        - *greedy decoding*.
+        - *multinomial sampling*.
+        - *beam-search decoding*
+
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences for remote usage.
+    """
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        num_beams: Optional[int] = 1,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        num_return_sequences: Optional[int] = None,
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head.
+
+        :param inputs: The input tokens to the model.
+        :param do_sample: Whether to sample from the model predictions or take the argmax.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of results to return.
+        :param top_p: The cumulative probability of results to return.
+        :param num_beams: The number of beams to use for beam search.
+        :param bos_token_id: The id of the beginning of sentence token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param pad_token_id: The id of the padding token.
+        :param max_new_tokens: The maximum number of tokens to generate.
+        :param decoding_algorithm: The decoding algorithm to use.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional arguments to pass to the model.
+        :param num_return_sequences: How many hypothesis from the beam will be in output.
+        """
+
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+        prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
+
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+
+        batch_size = inputs.size(0)
+
+        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
+        if max_length is not None and max_new_tokens is None:
+            max_new_tokens = max_length - prefix_length
+            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+        elif max_length is None and max_new_tokens is not None:
+            max_length = prefix_length + max_new_tokens
+
+        if inputs is None:
+            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+
+        if decoding_algorithm is None:
+            if do_sample:
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+            elif num_beams is not None and num_beams > 1:
+                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
+            else:
+                decoding_algorithm = GreedyAlgorithm()
+
+        if num_beams > 1:
+            inputs = torch.cat([inputs] * num_beams, dim=0)
+            if batch_size > 1:
+                # TODO: resolve padding problem
+                logger.warning(
+                    f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
+                )
+
+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            f"You want more sequences than the beam has."
+            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
+        )
+
+        constraints = self._get_constraints(
+            inputs=inputs,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            provided_constraints=provided_constraints,
+        )
+
+        with self.transformer.h.inference_session(max_length=max_length) as sess:
+            outputs = []
+            # Find samples with padded inputs.
+            # They will be changed before all of the samples have right length.
+            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            else:
+                outputs += [inputs]
+            last_token_id = None
+            seq_idx = outputs[0].size(1)
+            hypo_ids = torch.arange(outputs[0].size(0))
+            while True:
+                embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
+                embs = self.transformer.word_embeddings_layernorm(embs)
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                hidden_state = self.transformer.ln_f(hidden_state)
+                lm_logits = self.lm_head(hidden_state)
+
+                for constraint in constraints:
+                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
+                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+
+                # If some samples were padded, change only these samples
+                if seq_idx < inputs.size(1):
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id
+
+                # TODO: refactor outputs
+                if num_beams > 1:
+                    for i in range(len(outputs), 1, -1):
+                        outputs[i - 1] = outputs[i - 1][hypo_ids]
+
+                outputs.append(last_token_id)
+                seq_idx += 1
+                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                    break
+
+        outputs = torch.cat(outputs, dim=-1)
+
+        if num_beams > 1:
+            pre_return_idx = [
+                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
+            ]
+            return_idx = torch.cat(pre_return_idx, dim=0)
+            outputs = outputs[return_idx]
+
+        return outputs
+
+    def greedy_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
+
+        :param input_ids: The input tokens to the model.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        """
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=GreedyAlgorithm(),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def sample(
+        self,
+        input_ids: torch.LongTensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
+
+        :param: input_ids: The input tokens to the model.
+        :param: temperature: The temperature to use for sampling.
+        :param: top_k: The number of samples to use for top_k sampling.
+        :param: top_p: The probability of using top_p sampling.
+        :param: max_length: The maximum length of the sequence to generate.
+        :param: pad_token_id: The id of the padding token.
+        :param: eos_token_id: The id of the end of sentence token.
+        :param: provided_constraints: A list of constraints to use.
+        :param: model_kwargs: Additional kwargs to pass to the model.
+        """
+
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        num_beams: int = 1,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses beam search.
+
+        :param input_ids: The input tokens to the model.
+        :param num_beams: The number of beams to use.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param: model_kwargs: Additional kwargs to pass to the model.
+        """
+        decoding_algorithm = BeamSearchAlgorithm(
+            num_beams=num_beams,
+            batch_size=input_ids.size(0),
+        )
+        return self.generate(
+            inputs=input_ids,
+            num_beams=num_beams,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=decoding_algorithm,
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def beam_sample(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def group_beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def _choose_sample_algorithm(
+        self,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> DecodingAlgorithm:
+        if (top_k is not None) and (top_p is not None):
+            raise ValueError("You have to provide only top_k or top_p for sampling")
+        if top_k:
+            return TopKAlgorithm(top_k, temperature)
+        elif top_p:
+            return NucleusAlgorithm(top_p, temperature)
+
+    def _get_constraints(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+    ) -> List[ABCBloomConstraint]:
+        constraints = []
+        constraints.extend(provided_constraints)
+        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
+        return constraints

+ 198 - 0
petals/client/remote_model.py

@@ -0,0 +1,198 @@
+# this code is in active development, interfaces may change
+from typing import List, Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind import get_logger, use_hivemind_log_handler
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+from petals.bloom.model import (
+    BloomConfig,
+    BloomForCausalLM,
+    BloomForSequenceClassification,
+    BloomModel,
+    BloomPreTrainedModel,
+    LMHead,
+)
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.utils.misc import DUMMY
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class DistributedBloomConfig(BloomConfig):
+    """
+    A bloom config that contains information about DHT peers.
+    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
+    """
+
+    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
+    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
+    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+
+
+class DistributedBloomModel(BloomModel):
+    """BloomModel, but all transformer layers are hosted by the swarm"""
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
+        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+
+        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.n_layer = n_layer
+
+        dht = (
+            config.dht
+            if config.dht is not None
+            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+        )
+        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
+        self.h = RemoteSequential(config, dht, config.dht_prefix)
+
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHead(config, self.transformer.word_embeddings)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.transformer.word_embeddings
+
+    def get_output_embeddings(self):
+        if self.config.tie_word_embeddings:
+            return None
+        return self.lm_head
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
+        assert isinstance(new_embeddings, nn.Embedding)
+        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
+        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
+
+    def set_output_embeddings(self, new_lm_head: nn.Linear):
+        with torch.no_grad():
+            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
+            self.lm_head.bias[...] = new_lm_head.bias
+
+
+class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()

+ 103 - 0
petals/client/remote_sequential.py

@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from torch import nn
+
+import src
+from petals.client.inference_session import InferenceSession
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
+from petals.data_structures import UID_DELIMITER
+from petals.utils.misc import DUMMY
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequential(nn.Module):
+    """
+    A sequence of transformer blocks hosted by the swarm.
+    """
+
+    def __init__(
+        self,
+        config: src.DistributedBloomConfig,
+        dht: DHT,
+        dht_prefix: Optional[str] = None,
+        p2p: Optional[P2P] = None,
+        sequence_manager: Optional[RemoteSequenceManager] = None,
+    ):
+        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
+        super().__init__()
+        self.config = config
+        self.dht = dht
+        self.dht_prefix = dht_prefix or config.dht_prefix
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
+
+        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
+        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
+        if sequence_manager is None:
+            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
+            self.is_subsequence = False
+        else:
+            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            self.sequence_manager = sequence_manager
+            assert isinstance(sequence_manager.block_uids, list)
+            self.is_subsequence = self.sequence_manager.block_uids != block_uids
+
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        return outputs
+
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
+        assert isinstance(ix, (int, slice))
+        if isinstance(ix, int):
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
+        else:
+            return RemoteSequential(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
+
+    def __iter__(self):
+        for block_index in range(len(self)):
+            yield self[block_index]
+
+    def __len__(self):
+        return len(self.sequence_manager)
+
+    def inference_session(self, **kwargs) -> InferenceSession:
+        self.sequence_manager.update_()
+        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
+
+    def extra_repr(self) -> str:
+        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """Single transformer block hosted by swarm
+
+    This class is deprecated and kept for backward compatibility.
+    It will be removed soon in favor of using ``RemoteSequential`` directly.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self) == 1, "Remote Block is a sequence size 1"
+
+    def extra_repr(self):
+        return f"{self.sequence_manager.block_uids[0]}"

+ 167 - 0
petals/client/sequence_manager.py

@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import random
+import threading
+from typing import List, Optional, Sequence, Tuple, Union
+
+from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from petals.client.spending_policy import NoSpendingPolicy
+from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.dht_utils import get_remote_module_infos
+from petals.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequenceManager:
+    """
+    Keeps and updates the meta-information about which peers host which blocks.
+    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        p2p: P2P,
+        max_retries: int = 3,
+        timeout: float = 5,
+        min_backoff: float = 1,
+    ):
+        assert len(block_uids) > 0, "Sequences must contain at least one block"
+        self.dht, self.p2p = dht, p2p
+        self.block_uids: List[ModuleUID] = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.last_update_time: DHTExpiration = -float("inf")
+        self.max_retries = max_retries
+        self.timeout, self.min_backoff = timeout, min_backoff
+        self._rpc_info = None
+        self.lock_changes = threading.Lock()
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        """
+        end_index = end_index if end_index is not None else len(self.block_uids)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
+            current_index = chosen_span.end
+
+        return span_sequence
+
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
+        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
+        assert isinstance(ix, (int, slice))
+        if not isinstance(ix, slice):
+            ix = slice(int(ix), int(ix) + 1, 1)
+        with self.lock_changes:
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
+            subseq.block_infos = self.block_infos[ix]
+            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq.last_update_time = self.last_update_time
+        return subseq
+
+    def update_(self):
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+                continue
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+            if not info.servers:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
+
+            for peer_id in list(active_spans.keys()):
+                if (
+                    info is None
+                    or peer_id not in info.servers
+                    or info.servers[peer_id].state != ServerState.ONLINE
+                    or block_index == len(block_infos) - 1
+                ):
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans, f"spans: {active_spans}"
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
+
+    def __len__(self):
+        return len(self.block_uids)
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            retries = 0
+            for i in range(self.max_retries):
+                try:
+                    self.update_()
+                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    break
+                except Exception as e:
+                    retries += 1
+                    if retries >= self.max_retries:
+                        raise e
+                    else:
+                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
+        return self._rpc_info
+
+    def get_retry_delay(self, attempt_no: int) -> float:
+        if attempt_no == 0:
+            return 0
+        return self.min_backoff * 2 ** (attempt_no - 1)

+ 236 - 0
petals/client/sequential_autograd.py

@@ -0,0 +1,236 @@
+"""
+A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
+"""
+import asyncio
+import itertools
+from collections import deque
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils.logging import get_logger
+
+from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
+
+logger = get_logger(__file__)
+
+MAX_TOKENS_IN_BATCH = 1024
+
+
+async def sequential_forward(
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    Constructs a routing path from <start_index> to <end_index>.
+    Performs chained forward for each subsequence of blocks on the path.
+    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+    """
+
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
+
+    sequences = deque()
+    intermediate_inputs = []
+    done_sequences = []
+    outputs = inputs
+
+    block_idx = start_index
+    while block_idx < end_index:
+        for attempt_no in itertools.count():
+            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
+            try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                if not sequences or attempt_no >= 1:
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    # make_sequence() could return a longer sequence
+                    sequences[-1].end = min(sequences[-1].end, end_index)
+                    logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
+
+                span = sequences.popleft()
+
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                (outputs,) = await run_remote_forward(
+                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
+                )
+
+                assert isinstance(outputs, torch.Tensor)
+                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
+
+                # Save intermediate inputs and subsequences if the forward is already done for them
+                intermediate_inputs.append(inputs)
+                done_sequences.append(span)
+
+                inputs = outputs
+                block_idx = span.end
+                break
+            except Exception as e:
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running forward from block {block_idx} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
+
+    return outputs, intermediate_inputs, done_sequences
+
+
+async def sequential_backward(
+    grad_outputs: Sequence[torch.Tensor],
+    intermediate_inputs: List[torch.Tensor],
+    prompts: torch.Tensor,
+    forward_sequences: List[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
+) -> Sequence[torch.Tensor]:
+    """
+    Performs chained backward for each forward subsequence.
+    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
+    """
+    assert len(intermediate_inputs) == len(forward_sequences)
+
+    grad_prompts_reversed = []
+    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        inputs = intermediate_inputs.pop()
+        span = forward_sequences.pop()
+        for attempt_no in itertools.count():
+            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
+            try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                    _, backup_inputs, backup_sequences = await sequential_forward(
+                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                    )
+                    assert len(backup_inputs) == len(backup_sequences)
+                    assert backup_sequences[0].start == span.start
+                    assert backup_sequences[-1].end == span.end
+
+                    intermediate_inputs.extend(backup_inputs)
+                    forward_sequences.extend(backup_sequences)
+                    inputs = intermediate_inputs.pop()
+                    span = forward_sequences.pop()
+
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                grad_outputs, *span_grad_prompts = await run_remote_backward(
+                    span_uids,
+                    stub,
+                    sequence_manager.rpc_info,
+                    inputs,
+                    grad_outputs,
+                    prompts[span.start : span.end],
+                    timeout=sequence_manager.timeout,
+                )
+                grad_outputs = [grad_outputs]
+                grad_prompts_reversed.extend(span_grad_prompts)
+                break
+            except Exception as e:
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
+
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+    return grad_outputs, grad_prompts
+
+
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )
+
+
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
+    return await asyncio.gather(
+        *[
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
+            )
+        ]
+    )
+
+
+class _RemoteSequentialAutogradFunction(torch.autograd.Function):
+    """
+    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    This function splits input data into batches with <MAX_TOKENS_IN_BATCH> and performs efficient parallel processing.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
+
+        sequence_manager.rpc_info  # lazy init
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
+        assert len(outputs) == len(input_batches)
+
+        output_batches = [output[0] for output in outputs]
+        intemediate_input_batches = [output[1] for output in outputs]
+        sequences_for_batches = [output[2] for output in outputs]
+
+        ctx.prompt_batches = prompt_batches
+        ctx.sequence_manager = sequence_manager
+        ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.sequences_for_batches = sequences_for_batches
+        return torch.cat(output_batches, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_outputs: torch.Tensor):
+        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
+        ctx.sequence_manager.rpc_info  # lazy init
+
+        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
+        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
+        )
+        grad_input_batches = [output[0][0] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)

+ 0 - 0
src/client/spending_policy.py → petals/client/spending_policy.py


+ 0 - 0
src/constants.py → petals/constants.py


+ 0 - 0
src/data_structures.py → petals/data_structures.py


+ 180 - 0
petals/dht_utils.py

@@ -0,0 +1,180 @@
+"""
+Utilities for declaring and retrieving active model layers using a shared DHT.
+"""
+from __future__ import annotations
+
+import math
+from functools import partial
+from typing import Dict, List, Optional, Sequence, Union
+
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
+
+import src
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def declare_active_modules(
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: DHTExpiration,
+    state: ServerState,
+    throughput: float,
+    wait: bool = True,
+) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
+    """
+    Declare that your node serves the specified modules; update timestamps if declared previously
+
+    :param uids: a list of module ids to declare
+    :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param throughput: specify your performance in terms of compute throughput
+    :param expiration_time: declated modules will be visible for this many seconds
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    if isinstance(uids, str):
+        uids = [uids]
+    if not isinstance(uids, list):
+        uids = list(uids)
+    for uid in uids:
+        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+    return dht.run_coroutine(
+        partial(
+            _declare_active_modules,
+            uids=uids,
+            expiration_time=expiration_time,
+            state=state,
+            throughput=throughput,
+        ),
+        return_future=not wait,
+    )
+
+
+async def _declare_active_modules(
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    expiration_time: DHTExpiration,
+    state: ServerState,
+    throughput: float,
+) -> Dict[ModuleUID, bool]:
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    return await node.store_many(
+        keys=uids,
+        subkeys=[dht.peer_id.to_base58()] * len(uids),
+        values=[(state.value, throughput)] * len(uids),
+        expiration_time=expiration_time,
+        num_workers=num_workers,
+    )
+
+
+def get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[src.RemoteSequential, MPFuture]:
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> src.RemoteSequential:
+    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
+    p2p = await dht.replicate_p2p()
+    manager = src.RemoteSequenceManager(dht, uids, p2p)
+    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+
+
+def get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+    """
+    :param uid_or_uids: find one or more modules with these ids from across the DHT
+    :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock]
+    """
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    p2p = await dht.replicate_p2p()
+    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    modules = [
+        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+    ]
+    return modules[0] if single_uid else modules
+
+
+def get_remote_module_infos(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+) -> List[Optional[RemoteModuleInfo]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
+    )
+    return infos[0] if single_uid else infos
+
+
+async def _get_remote_module_infos(
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteModuleInfo]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        metadata = found[uid]
+        if metadata is None or not isinstance(metadata.value, dict):
+            if metadata is not None:
+                logger.error(f"Incorrect metadata for {uid}: {metadata}")
+            continue
+        servers = {}
+        for peer_id, server_info in metadata.value.items():
+            try:
+                peer_id = PeerID.from_base58(peer_id)
+                state, throughput = server_info.value
+                if not (
+                    isinstance(state, int)
+                    and isinstance(throughput, float)
+                    and math.isfinite(throughput)
+                    and throughput >= 0.0
+                ):
+                    raise ValueError(f"Invalid server info: {server_info}")
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except (TypeError, ValueError) as e:
+                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
+    return modules

+ 0 - 0
src/server/__init__.py → petals/server/__init__.py


+ 87 - 0
petals/server/backend.py

@@ -0,0 +1,87 @@
+"""Code for serving bloom blocks via hivemind-server"""
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+import torch
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+
+from petals.bloom.from_pretrained import BloomBlock
+from petals.server.cache import MemoryCache
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.utils.misc import is_dummy
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class TransformerBackend(ModuleBackend):
+    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert isinstance(self.module, BloomBlock)
+        self.memory_cache = memory_cache
+        for name, param in self.module.named_parameters():
+            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+        for name, buf in self.module.named_buffers():
+            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+
+        max_batch_size = self.forward_pool.max_batch_size
+        self.inference_pool = PrioritizedTaskPool(
+            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+        )
+        self.forward_pool = PrioritizedTaskPool(
+            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+        )
+        self.backward_pool = PrioritizedTaskPool(
+            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+        )
+
+        assert backend_dtype is not None
+        self.dtype = backend_dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
+
+    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        with torch.inference_mode():
+            attention_cache_handle = int(cache_metadata[0, 0].item())
+            prefix_length = int(cache_metadata[0, 1].item())
+            (hidden_states, hypo_ids) = inputs
+            assert (
+                hidden_states.ndim == 3
+            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+
+            with self.memory_cache.use_cache(attention_cache_handle) as cache:
+                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    assert hypo_ids.shape[0] == cache.shape[1]
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
+                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
+                hidden_states, (new_k, new_v) = self.module.forward(
+                    hidden_states, layer_past=layer_past, use_cache=True
+                )
+
+                # todo remove these asserts once we pass all tests
+                new_length = new_v.shape[1]
+                assert new_length > prefix_length
+                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                return (hidden_states,)
+
+    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
+        return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)

+ 115 - 0
petals/server/block_selection.py

@@ -0,0 +1,115 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from hivemind import PeerID, get_logger
+
+from petals.data_structures import RemoteModuleInfo, ServerState
+
+__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+
+logger = get_logger(__file__)
+
+
+@dataclass
+class Span:
+    start: int
+    end: int
+    throughput: float
+
+    @property
+    def length(self):
+        return self.end - self.start
+
+    def move_to(self, new_start: int) -> None:
+        self.start, self.end = new_start, new_start + self.length
+
+
+def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
+    spans = {}
+    throughputs = np.zeros(len(module_infos))
+    for block, module in enumerate(module_infos):
+        if module is None:
+            continue
+
+        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+        # If the order were not defined, we would get slightly different values due to floating point errors,
+        # which may cause excess block replacements.
+        for peer_id, server in sorted(module.servers.items()):
+            if server.state == ServerState.OFFLINE:
+                continue
+
+            if peer_id in spans:
+                spans[peer_id].start = min(spans[peer_id].start, block)
+                spans[peer_id].end = max(spans[peer_id].start, block + 1)
+            else:
+                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
+
+            throughputs[block] += server.throughput
+
+    return spans, throughputs
+
+
+def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
+    options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
+    return min(options)[-1]
+
+
+def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    _, throughputs = _compute_spans(module_infos)
+    start = _choose_best_start(throughputs, num_blocks)
+    return list(range(start, start + num_blocks))
+
+
+def should_choose_other_blocks(
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+) -> bool:
+    if balance_quality > 1.0:
+        return True  # Forces rebalancing on each check (may be used for debugging purposes)
+
+    spans, throughputs = _compute_spans(module_infos)
+    initial_throughput = throughputs.min()
+    eps = 1e-3
+
+    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
+    local_span = spans[local_peer_id]
+    throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + eps)
+    # Without (1 + eps) here, we would sometimes subtract a value slightly less than local_span.throughput
+    # due to the floating point error, which would cause excess block replacements.
+    # Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer
+    # the previous server position in case of other things being almost equal.
+
+    new_start = _choose_best_start(throughputs, local_span.length)
+    if local_span.start == new_start:
+        return False  # This server is on its best place already
+
+    throughputs[local_span.start : local_span.end] += local_span.throughput * eps
+    local_span.move_to(new_start)
+    throughputs[local_span.start : local_span.end] += local_span.throughput
+
+    moved = True
+    while moved:
+        servers = list(spans.keys())
+        np.random.shuffle(servers)
+
+        moved = False
+        for peer_id in servers:
+            span = spans[peer_id]
+            throughputs[span.start : span.end] -= span.throughput * (1 + eps)
+
+            new_start = _choose_best_start(throughputs, span.length)
+
+            throughputs[span.start : span.end] += span.throughput * eps
+            if span.start != new_start:
+                span.move_to(new_start)
+                moved = True
+            throughputs[span.start : span.end] += span.throughput
+
+    new_throughput = throughputs.min()
+    if new_throughput < initial_throughput or new_throughput < eps:
+        return False
+
+    actual_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
+
+    return actual_quality < balance_quality - eps

+ 0 - 0
src/server/cache.py → petals/server/cache.py


+ 470 - 0
petals/server/handler.py

@@ -0,0 +1,470 @@
+import asyncio
+import contextlib
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+
+import torch
+from async_timeout import timeout
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_tensor_stream,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
+from hivemind.utils.streaming import split_for_streaming
+
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.server.backend import TransformerBackend
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
+
+logger = get_logger(__file__)
+
+
+class TransformerConnectionHandler(ConnectionHandler):
+    """Handles three request types: forward, backward and forward-incremental (inference)"""
+
+    module_backends: Dict[ModuleUID, TransformerBackend]
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        *,
+        inference_max_length: int,
+        request_timeout: float,
+        session_timeout: float,
+        step_timeout: float,
+        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+    ):
+        super().__init__(dht, module_backends)
+        for module_backend in self.module_backends.values():
+            assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
+        self.request_timeout = request_timeout
+        self.session_timeout, self.step_timeout = session_timeout, step_timeout
+        self._prioritizer = task_prioritizer
+
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor], Dict]:
+        block_uid, metadata = None, None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal block_uid, metadata
+
+            if block_uid is None:
+                block_uid = req.uid
+            elif block_uid != req.uid:
+                raise ValueError("Block uids differ in one request")
+
+            if metadata is None:
+                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        assert isinstance(block_uid, str) and isinstance(metadata, dict)
+        return block_uid, inputs, metadata
+
+    async def rpc_inference(
+        self,
+        requests: AsyncIterator[runtime_pb2.ExpertRequest],
+        context: P2PContext,
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        """Compute a single step of inference using attention cache; update attention cache accordingly."""
+
+        async with timeout(self.session_timeout):
+            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_inference.open", requested_uids, context)
+            try:
+                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+                max_length = metadata.get("max_length")
+                points = metadata.get("points", 0)
+
+                if not requested_uids:
+                    raise ValueError("User must specify at least one block for inference, but got none")
+                assert isinstance(
+                    max_length, int
+                ), f"rpc_inference metadata must contain int max_length, got {max_length}"
+                assert isinstance(
+                    points, (float, int)
+                ), f"rpc_inference should have number of points as a number or None, got {points}"
+                if not 0 <= max_length <= self.inference_max_length:
+                    raise ValueError(
+                        f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
+                    )
+
+                point_per_piece = points / max_length if max_length > 0 else 0.0
+                batch_size = request.tensors[0].size[0] if request.tensors else 1
+
+                cache_metadata = torch.tensor(
+                    [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+                )  # [cache_handle, prefix_length]
+                prefix_length = 0
+
+                async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
+                    assert len(cache_handles) == len(requested_backends)
+                    while request.tensors:  # iterate while user is willing to supply tensors
+                        hidden_states, prompts, hypo_ids = [
+                            deserialize_torch_tensor(tensor) for tensor in request.tensors
+                        ]
+
+                        # Cast inputs to backend dtype
+                        hidden_states = hidden_states.to(requested_backends[0].dtype)
+                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+                        # parse deep prompts (optional argument)
+                        if prompts is None or is_dummy(prompts) or is_dummy(prompts):
+                            prompts = [DUMMY] * len(requested_backends)
+                        else:
+                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+                        if not (len(requested_backends) == len(prompts)):
+                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
+                        if prefix_length + length_increment > max_length:
+                            raise ValueError(
+                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                                f" exceeds pre-allocated maximum {max_length}"
+                            )
+
+                        # run request tensors through all requested modules, update caches
+                        for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                            if not is_dummy(prompt):
+                                hidden_states[:, : prompt.shape[1]] += prompt
+
+                            cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                            assert isinstance(
+                                hidden_states, torch.Tensor
+                            ), f"hidden states must be tensor, got {type(hidden_states)}"
+                            assert (
+                                hidden_states.ndim == 3
+                            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                            assert isinstance(
+                                backend.inference_pool, PrioritizedTaskPool
+                            ), "petals support only prioritized pools"
+                            priority = self._prioritizer.prioritize(
+                                cache_metadata,
+                                hidden_states,
+                                hypo_ids,
+                                points=point_per_piece / len(requested_backends),
+                                backend=backend,
+                                type="inference",
+                            )
+                            (hidden_states,) = await backend.inference_pool.submit_task(
+                                cache_metadata, hidden_states, hypo_ids, priority=priority
+                            )
+
+                        # serialize and send last layer outputs
+                        yield runtime_pb2.ExpertResponse(
+                            tensors=[
+                                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                                for result, proto in zip(
+                                    (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+                                )
+                            ]
+                        )
+
+                        # prepare for next step
+                        prefix_length += hidden_states.shape[1]
+                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            finally:
+                self._log_request("rpc_inference.close", requested_uids, context)
+
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        async with timeout(self.request_timeout):
+            # Parse request and prepare backends
+            flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_forward", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward should have number of points as number or None, got {points}"
+
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+
+            # Serialize output and respond to client
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+                ]
+            )
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        async with timeout(self.request_timeout):
+            # Parse requests and prepare backends
+            uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uid_str)
+            self._log_request("rpc_forward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward_stream should have number of points as number or None, got {points}"
+
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert (
+                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+            ), "hidden_states must be a 3d tensor"
+
+            # Serialize the overall output
+            serialized_output = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+
+            # Split the serialized_output for streaming and respond to client
+            output_split = [
+                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
+            async for part in as_aiter(*output_split):
+                yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        async with timeout(self.request_timeout):
+            # Parse requests and prepare backends
+            flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_backward", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward should have number of points as number or None, got {points}"
+
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            )  # TODO generalize
+
+            # Serialize the overall grad_input and respond
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+                ]
+            )
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        async with timeout(self.request_timeout):
+            uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uids_header)
+            self._log_request("rpc_backward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward_stream should have number of points as number or None, got {points}"
+
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            )  # TODO generalize
+
+            # Serialize the overall grad_inputs
+            serialized_grad_inputs = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+            ]
+            # Split the serialized_grad_inputs for streaming and respond
+            output_split = [
+                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
+
+            async for part in as_aiter(*output_split):
+                yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+        """Check that the first request to rpc_inference is valid"""
+        uids = (uids or "").split(CHAIN_DELIMITER)
+        if not uids:
+            raise RuntimeError("User did not provide any uids")
+        for uid in uids:
+            if uid not in self.module_backends:
+                raise RuntimeError(f"Remote peer does not serve {uid}")
+        return tuple(uids)
+
+    @contextlib.asynccontextmanager
+    async def _allocate_caches(
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+    ) -> Sequence[int]:
+        """Allocate memory caches for each transformer block, return cache handles"""
+        async with contextlib.AsyncExitStack() as stack:
+            handles = []
+            total_size = 0
+            backend = None
+            for backend in backends:
+                num_heads = backend.module.self_attention.num_heads
+                head_dim = backend.module.self_attention.head_dim
+
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
+                # [key_or_value, batch_size, max_length, num_heads, head_dim]
+
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
+                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+            gib = 1024**3
+            if backend is not None:
+                cur_size = backend.memory_cache.current_size_bytes
+                max_size = backend.memory_cache.max_size_bytes
+                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+            else:
+                cache_stats = f"cache stats n/a"
+            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
+
+            yield handles
+
+    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+
+        friendly_remote_id = "..." + str(context.remote_id)[-6:]
+
+        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+
+
+async def _rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            priority=priority,
+        )
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    # Serialize the overall output
+    return hidden_states
+
+
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape

+ 0 - 0
src/server/runtime.py → petals/server/runtime.py


+ 499 - 0
petals/server/server.py

@@ -0,0 +1,499 @@
+from __future__ import annotations
+
+import gc
+import multiprocessing as mp
+import random
+import threading
+import time
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import psutil
+import torch
+from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
+from hivemind.moe.server.layers import add_custom_models_from_file
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from petals.import BloomConfig, declare_active_modules
+from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.dht_utils import get_remote_module_infos
+from petals.server import block_selection
+from petals.server.backend import TransformerBackend
+from petals.server.cache import MemoryCache
+from petals.server.handler import TransformerConnectionHandler
+from petals.server.throughput import get_host_throughput
+from petals.utils.convert_8bit import replace_8bit_linear
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class Server:
+    """
+    Runs ModuleContainer, periodically checks that the network is balanced,
+    restarts the ModuleContainer with other layers if the imbalance is significant
+    """
+
+    def __init__(
+        self,
+        *,
+        initial_peers: List[str],
+        prefix: Optional[str],
+        converted_model_name_or_path: str,
+        throughput: Union[float, str],
+        num_blocks: Optional[int] = None,
+        block_indices: Optional[str] = None,
+        num_handlers: int = 8,
+        min_batch_size: int = 1,
+        max_batch_size: int = 2048,
+        inference_max_length: int = 2048,
+        torch_dtype: str = "auto",
+        revision: str = "main",
+        cache_dir: Optional[str] = None,
+        attn_cache_size: Optional[int] = None,
+        alloc_timeout: float = 60,
+        device: Optional[Union[str, torch.device]] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        request_timeout: float = 3 * 60,
+        session_timeout: float = 30 * 60,
+        step_timeout: float = 5 * 60,
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
+        balance_quality: float = 0.75,
+        mean_balance_check_period: float = 60,
+        mean_block_selection_delay: float = 0.5,
+        use_auth_token: Optional[str] = None,
+        load_in_8bit: bool = False,
+        **kwargs,
+    ):
+        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
+
+        self.converted_model_name_or_path = converted_model_name_or_path
+        self.num_handlers = num_handlers
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.cache_dir = cache_dir
+        self.attn_cache_size = attn_cache_size
+        self.compression = compression
+        self.stats_report_interval, self.update_period = stats_report_interval, update_period
+        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
+        self.use_auth_token = use_auth_token
+        self.load_in_8bit = load_in_8bit
+
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+
+        if prefix is None:
+            prefix = converted_model_name_or_path
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
+                f"Please specify --prefix manually when starting a server"
+            )
+            logger.info(f"Automatic dht prefix: {prefix}")
+        self.prefix = prefix
+
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.expiration = expiration
+
+        self.request_timeout = request_timeout
+        self.session_timeout, self.step_timeout = session_timeout, step_timeout
+
+        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
+        if initial_peers == PUBLIC_INITIAL_PEERS:
+            logger.info("Connecting to the public Petals swarm")
+        else:
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+
+        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
+        self.throughput = throughput
+
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
+        self.block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        if block_indices is not None:
+            try:
+                first_block_index, last_block_index = block_indices.split(":")
+                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+            except Exception as e:
+                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
+                raise
+            block_indices = range(first_block_index, last_block_index)
+        self.strict_block_indices, self.num_blocks = block_indices, num_blocks
+        self.balance_quality = balance_quality
+        self.mean_balance_check_period = mean_balance_check_period
+        self.mean_block_selection_delay = mean_block_selection_delay
+
+        self.stop = threading.Event()
+
+    def run(self):
+        while True:
+            block_indices = self._choose_blocks()
+            self.module_container = ModuleContainer.create(
+                dht=self.dht,
+                prefix=self.prefix,
+                converted_model_name_or_path=self.converted_model_name_or_path,
+                block_config=self.block_config,
+                memory_cache=self.memory_cache,
+                throughput=self.throughput,
+                block_indices=block_indices,
+                num_handlers=self.num_handlers,
+                min_batch_size=self.min_batch_size,
+                max_batch_size=self.max_batch_size,
+                inference_max_length=self.inference_max_length,
+                torch_dtype=self.torch_dtype,
+                cache_dir=self.cache_dir,
+                device=self.device,
+                compression=self.compression,
+                stats_report_interval=self.stats_report_interval,
+                update_period=self.update_period,
+                expiration=self.expiration,
+                request_timeout=self.request_timeout,
+                session_timeout=self.session_timeout,
+                step_timeout=self.step_timeout,
+                prefetch_batches=self.prefetch_batches,
+                sender_threads=self.sender_threads,
+                use_auth_token=self.use_auth_token,
+                load_in_8bit=self.load_in_8bit,
+                start=True,
+            )
+            try:
+                self.module_container.ready.wait()
+
+                while True:
+                    timeout = random.random() * 2 * self.mean_balance_check_period
+                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
+                    if self.stop.wait(timeout):
+                        return
+
+                    if self._should_choose_other_blocks():
+                        logger.info("Swarm is imbalanced, server will load other blocks")
+                        break  # Stop serving this set of modules
+            finally:
+                self.module_container.shutdown()
+
+            self._clean_memory_and_fds()
+
+    def _clean_memory_and_fds(self):
+        del self.module_container
+        gc.collect()  # In particular, this closes unused file descriptors
+
+        cur_proc = psutil.Process()
+        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
+        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+
+    def _choose_blocks(self) -> List[int]:
+        if self.strict_block_indices is not None:
+            return self.strict_block_indices
+        assert self.num_blocks is not None
+
+        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
+        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
+        time.sleep(random.random() * 2 * self.mean_block_selection_delay)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.choose_best_blocks(self.num_blocks, module_infos)
+
+    def _should_choose_other_blocks(self) -> bool:
+        if self.strict_block_indices is not None:
+            return False
+
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
+
+    def shutdown(self):
+        self.stop.set()
+
+        self.dht.shutdown()
+        self.dht.join()
+
+
+class ModuleContainer(threading.Thread):
+    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        update_period: float,
+        expiration: Optional[float],
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        **kwargs,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=backend_dtype,
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            device=device,
+            update_period=update_period,
+            expiration=expiration,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        *,
+        inference_max_length: int,
+        num_handlers: int,
+        throughput: float,
+        update_period: float,
+        expiration: Optional[float] = None,
+        request_timeout: float,
+        session_timeout: float,
+        step_timeout: float,
+        start: bool,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.conn_handlers = [
+            TransformerConnectionHandler(
+                dht,
+                self.module_backends,
+                inference_max_length=inference_max_length,
+                request_timeout=request_timeout,
+                session_timeout=session_timeout,
+                step_timeout=step_timeout,
+            )
+            for _ in range(num_handlers)
+        ]
+        self.runtime = Runtime(self.module_backends, **kwargs)
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
+            dht,
+            ServerState.ONLINE,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    def run(self):
+        """
+        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        self.online_announcer.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for handler in self.conn_handlers:
+            handler.run_in_background()
+
+        self.runtime.run()
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the container is ready to process requests.
+
+        Example
+        =======
+        >>> container.start()
+        >>> container.ready.wait(timeout=10)
+        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the container, process-safe.
+        Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.online_announcer.stop.set()
+        self.online_announcer.join()
+
+        declare_active_modules(
+            self.dht,
+            self.module_backends.keys(),
+            expiration_time=get_dht_time() + self.expiration,
+            state=ServerState.OFFLINE,
+            throughput=self.throughput,
+        )
+        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
+        self.ready.clear()
+
+        for handler in self.conn_handlers:
+            handler.shutdown()
+        logger.debug("Connection handlers terminated")
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        logger.debug(f"Shutting down pools")
+        for pool in self.runtime.pools:
+            if pool.is_alive():
+                pool.shutdown()
+
+        logger.debug(f"Shutting down runtime")
+        self.runtime.shutdown()
+
+        logger.info("Module container shut down succesfully")
+
+
+class ModuleAnnouncerThread(threading.Thread):
+    """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""
+
+    def __init__(
+        self,
+        module_uids: List[str],
+        dht: DHT,
+        state: ServerState,
+        *,
+        throughput: float,
+        update_period: float = 30,
+        expiration: float,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.module_uids = module_uids
+        self.dht = dht
+        self.state = state
+        self.throughput = throughput
+        self.update_period = update_period
+        self.expiration = expiration
+        self.stop = threading.Event()
+
+    def run(self) -> None:
+        while True:
+            declare_active_modules(
+                self.dht,
+                self.module_uids,
+                expiration_time=get_dht_time() + self.expiration,
+                state=self.state,
+                throughput=self.throughput,
+            )
+            if self.stop.wait(self.update_period):
+                break

+ 0 - 0
src/server/task_pool.py → petals/server/task_pool.py


+ 0 - 0
src/server/task_prioritizer.py → petals/server/task_prioritizer.py


+ 127 - 0
petals/server/throughput.py

@@ -0,0 +1,127 @@
+import fcntl
+import json
+import os
+import subprocess
+import tempfile
+import time
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Dict, Union
+
+import torch
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from petals.import project_name
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
+
+SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
+
+
+@dataclass
+class ThroughputInfo:
+    network_rps: float
+    device_rps: Dict[str, float]
+
+
+def get_host_throughput(
+    device: Union[str, torch.device],
+    force_eval: bool = False,
+    cache_path: str = DEFAULT_CACHE_PATH,
+    lock_path: str = DEFAULT_LOCK_PATH,
+) -> float:
+    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
+    device = torch.device(device).type
+
+    # We use the system-wide lock since only one process at a time can measure the host throughput
+    os.makedirs(lock_path.parent, exist_ok=True)
+    with open(lock_path, "wb") as lock_fd:
+        logger.info("Loading throughput info")
+        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
+        # The OS will release the lock when lock_fd is closed or the process is killed
+
+        info = None
+        try:
+            if not force_eval and os.path.exists(cache_path):
+                with open(cache_path) as cache_fd:
+                    info = ThroughputInfo(**json.load(cache_fd))
+                if device not in info.device_rps:
+                    force_eval = True
+        except Exception:
+            logger.exception(f"Failed to read throughput info from {cache_path}")
+            force_eval = True
+
+        if force_eval or info is None:
+            info = measure_throughput_info()
+            try:
+                os.makedirs(cache_path.parent, exist_ok=True)
+                with open(cache_path, "w") as cache_fd:
+                    json.dump(asdict(info), cache_fd)
+            except Exception:
+                logger.exception(f"Failed to save throughput info in {cache_path}")
+
+    throughput = min(info.network_rps, info.device_rps[device])
+    return throughput
+
+
+def measure_throughput_info() -> ThroughputInfo:
+    logger.info(
+        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
+    )
+
+    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
+    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
+
+    network_rps = measure_network_rps(config)
+
+    device_rps = {"cpu": measure_device_rps("cpu", config)}
+    if torch.cuda.is_available():
+        device_rps["cuda"] = measure_device_rps("cuda", config)
+
+    return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
+
+
+def measure_network_rps(config: BloomConfig) -> float:
+    proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
+    if proc.returncode != 0:
+        raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
+    network_info = json.loads(proc.stdout)
+
+    bits_per_request = config.hidden_size * 32
+    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
+
+    logger.info(
+        f"Network throughput: "
+        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
+        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
+        f"{network_rps:.2f} RPS"
+    )
+    return network_rps
+
+
+def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
+    with torch.inference_mode():
+        block = BloomBlock(config, layer_index).to(device)
+        cache = None
+        elapsed = 0
+        for i in range(n_steps):
+            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
+            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)
+
+            start_time = time.perf_counter()
+            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
+            elapsed += time.perf_counter() - start_time
+        device_rps = n_steps / elapsed
+
+    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
+    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
+
+    return device_rps

+ 0 - 0
src/__init__.py → petals/src/__init__.py


+ 0 - 0
src/bloom/__init__.py → petals/src/bloom/__init__.py


+ 0 - 0
src/bloom/block.py → petals/src/bloom/block.py


+ 0 - 0
src/bloom/from_pretrained.py → petals/src/bloom/from_pretrained.py


+ 0 - 0
src/bloom/model.py → petals/src/bloom/model.py


+ 246 - 0
petals/src/bloom/ops.py

@@ -0,0 +1,246 @@
+"""
+Utility operations used in the the BLOOM model
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+import math
+
+import torch
+import torch.autograd
+import torch.nn.functional as F
+from torch import nn
+
+
+def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+    """Split a tensor along its last dimension.
+
+    Args:
+        tensor: ([`torch.tensor`], *required*):
+            input tensor to split
+        num_partitions ([`int`], *required*):
+            number of partitions to split the tensor
+        contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
+            If True, make each chunk contiguous in memory.
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    numerator, denominator = tensor.size()[last_dim], num_partitions
+    if not (numerator % denominator == 0):
+        raise ValueError(f"{numerator} is not divisible by {denominator}")
+    last_dim_size = numerator // denominator
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # Note: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+def attention_mask_func(attention_scores, attention_mask, causal_mask):
+    if attention_mask.dtype == torch.bool:
+        attention_mask_bool = ~attention_mask
+    else:
+        attention_mask_bool = (1 - attention_mask).bool()
+
+    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
+    padded_causal_mask = (
+        attention_mask_bool[:, None, key_length - query_length : key_length, None]
+        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
+    ).bool()
+    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
+    # Make use of floats
+    return (
+        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
+        padded_causal_mask,
+    )
+
+
+def build_alibi_tensor(
+    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
+) -> torch.Tensor:
+    """
+    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+    `softmax(l+a) = softmax(l)`. Based on
+    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+    Args:
+    Returns tensor shaped (n_head, 1, max_seq_len)
+        max_seq_len: (`int`, *required*):
+            max sequence length
+        n_head: (`int`, *required*):
+            number of heads
+        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+            dtype of the output tensor
+        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
+            device of the output alibi tensor
+    """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
+    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
+
+
+def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
+    """
+    Args:
+    Pre-process the alibi tensor for padding.
+        alibi: ([`torch.tensor`], *required*):
+            alibi tensor to pre-process
+        attention_mask: ([`torch.tensor`], *required*):
+            attention mask to pre-process
+    """
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
+    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
+    # ^-- [batch, max_len], values correspond to element indices after removing padding
+    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
+    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
+    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
+
+
+def dropout_add(x, residual, prob, training):
+    """
+    Dropout add function
+
+    Args:
+        x (`torch.tensor`, *required*):
+            input tensor
+        residual (`torch.tensor`, *rquired*):
+            esidual tensor
+        prob (`float`, *required*):
+            dropout probability
+        training (`bool`, *required*):
+            training mode
+    """
+    out = nn.functional.dropout(x, p=prob, training=training)
+    out = residual + out
+    return out
+
+
+def bloom_gelu_forward(x):
+    """
+    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
+    make the model jitable.
+
+    Args:
+        x (`torch.tensor`, *required*):
+            input hidden states
+    """
+    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+def bloom_gelu_back(g, x):
+    """
+    gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
+    0.3989423 * x * torch.exp(-0.5 * x * x)
+
+    Args:
+        g (`torch.tensor`, *required*):
+            gradient output tensor
+        x (`torch.tensor`, *required*):
+            input tensor
+    """
+    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    return ff * g
+
+
+class GeLUFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        return bloom_gelu_forward(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors
+        tmp = bloom_gelu_back(grad_output, input)
+        return tmp
+
+
+class BloomGelu(nn.Module):
+    """
+    BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
+    torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
+    copied from Megatron-DeepSpeed code and adapted for our needs
+
+    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        if self.training:
+            return GeLUFunction.apply(x)
+        else:
+            return bloom_gelu_forward(x)
+
+
+class BloomScaledSoftmax(nn.Module):
+    """
+    fused operation: scaling + mask + softmax
+
+    Args:
+        input_in_fp16 (`bool`, *required*):
+            flag to indicate if input in fp16 data format.
+        input_in_bf16 (`bool`, *required*):
+            flag to indicate if input in bf16 data format.
+        scaled_masked_softmax_fusion (`bool`, *required*):
+            flag to indicate user want to use softmax fusion
+        mask_func (`function`, *required*):
+            mask function to be applied.
+        softmax_in_fp32 (`bool`, *required*):
+            if true, softmax in performed at fp32 precision.
+        scale (`float`, *required*):
+            scaling factor used in input tensor scaling.
+    """
+
+    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
+        super().__init__()
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
+        self.mask_func = mask_func
+        self.softmax_in_fp32 = softmax_in_fp32
+        self.scale = scale
+
+        if not (self.scale is None or softmax_in_fp32):
+            raise ValueError("softmax should be in fp32 when scaled")
+
+    def forward(self, input, mask, max_positions):
+        input_dtype = input.dtype
+        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
+        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
+
+        if self.scale is not None:
+            input = input * self.scale
+
+        if mask is None:
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
+
+        if input_in_16bit and self.softmax_in_fp32:
+            probs = probs.to(dtype=input_dtype)
+
+        return probs

+ 0 - 0
src/client/__init__.py → petals/src/client/__init__.py


+ 0 - 0
src/client/inference_session.py → petals/src/client/inference_session.py


+ 0 - 0
src/client/remote_forward_backward.py → petals/src/client/remote_forward_backward.py


+ 0 - 0
src/client/remote_generation.py → petals/src/client/remote_generation.py


+ 0 - 0
src/client/remote_model.py → petals/src/client/remote_model.py


+ 0 - 0
src/client/remote_sequential.py → petals/src/client/remote_sequential.py


+ 0 - 0
src/client/sequence_manager.py → petals/src/client/sequence_manager.py


+ 0 - 0
src/client/sequential_autograd.py → petals/src/client/sequential_autograd.py


+ 14 - 0
petals/src/client/spending_policy.py

@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from hivemind.proto.runtime_pb2 import ExpertRequest
+
+
+class SpendingPolicyBase(ABC):
+    @abstractmethod
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        pass
+
+
+class NoSpendingPolicy(SpendingPolicyBase):
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        return 0.0

+ 8 - 0
petals/src/constants.py

@@ -0,0 +1,8 @@
+PUBLIC_INITIAL_PEERS = [
+    "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
+    "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
+    "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
+    "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
+    "/dns/bootstrap3.petals.ml/tcp/31339/p2p/QmX82nfE57CSkNgyEC7pPMPBzjcFLLJXdHhvp1AXKVPvJD",
+    "/dns6/bootstrap3.petals.ml/tcp/31339/p2p/QmX82nfE57CSkNgyEC7pPMPBzjcFLLJXdHhvp1AXKVPvJD",
+]

+ 41 - 0
petals/src/data_structures.py

@@ -0,0 +1,41 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict
+
+from hivemind import PeerID
+
+ModuleUID = str
+UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
+CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
+
+
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+@dataclass
+class ServerInfo:
+    state: ServerState
+    throughput: float
+
+
+@dataclass
+class RemoteModuleInfo:
+    """A remote module that is served by one or more servers"""
+
+    uid: ModuleUID
+    servers: Dict[PeerID, ServerInfo]
+
+
+@dataclass
+class RemoteSpanInfo:
+    """A chain of remote blocks served by one specific remote peer"""
+
+    start: int
+    end: int
+    peer_id: PeerID
+
+
+RPCInfo = Dict[str, Any]

+ 0 - 0
src/dht_utils.py → petals/src/dht_utils.py


+ 0 - 0
src/utils/__init__.py → petals/src/server/__init__.py


+ 0 - 0
src/server/backend.py → petals/src/server/backend.py


+ 0 - 0
src/server/block_selection.py → petals/src/server/block_selection.py


+ 148 - 0
petals/src/server/cache.py

@@ -0,0 +1,148 @@
+"""
+A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
+
+For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
+
+"""
+import asyncio
+import contextlib
+import ctypes
+import multiprocessing as mp
+import os
+import time
+from typing import AsyncContextManager, Dict, Optional, Union
+
+import hivemind
+import torch
+from hivemind import use_hivemind_log_handler
+from hivemind.utils import TensorDescriptor, get_logger
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+Handle = int
+
+
+class MemoryCache:
+    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
+
+    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+        self.alloc_timeout = alloc_timeout
+        self.device = device
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
+        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+        self.runtime_pid = os.getpid()
+
+        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
+        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
+
+    @property
+    def current_size_bytes(self) -> int:
+        return self._current_size.value
+
+    @current_size_bytes.setter
+    def current_size_bytes(self, value: int):
+        self._current_size.value = value
+
+    @property
+    def handle_counter(self) -> int:
+        return self._handle_counter.value
+
+    @handle_counter.setter
+    def handle_counter(self, value: int):
+        self._handle_counter.value = value
+
+    @contextlib.asynccontextmanager
+    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
+        """
+        Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
+
+        :param descr: allocate a tensor of this size, dtype, etc
+
+        :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
+        Furthermore, it can be called concurrently with at most one use_cache call in runtime.
+        """
+        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
+        assert descr.device is None and descr
+        allocated_handle = None
+        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
+        try:
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+                    await loop.run_in_executor(
+                        None, self._wait_until_available, allocated_size_bytes, timeout=self.alloc_timeout
+                    )
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
+
+            yield allocated_handle
+        finally:
+            if allocated_handle is not None:
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
+                    self.current_size_bytes -= allocated_size_bytes
+                self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(
+                    f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
+                )
+            self._memory_freed_event.clear()
+
+    @contextlib.contextmanager
+    def use_cache(self, handle: Handle) -> torch.Tensor:
+        """
+        Return a tensor that was previously allocated with try_allocate_cache,
+
+        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
+        """
+        assert os.getpid() == self.runtime_pid
+        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
+
+        with self._lock_metadata:
+            if self._allocated_tensors is None:
+                self._allocated_tensors = {}
+
+            # read creation/deletion requests from connection handlers
+            for i in range(int(self._pending_messages.value)):
+                recv_handle, recv_data = self._pipe_recv.recv()
+                self._pending_messages.value -= 1
+                if isinstance(recv_data, TensorDescriptor):
+                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
+                elif recv_data is None:
+                    if recv_handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(recv_handle, None)
+                else:
+                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
+
+        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+        yield self._allocated_tensors[handle]
+
+
+class AllocationFailed(Exception):
+    pass

+ 0 - 0
src/server/handler.py → petals/src/server/handler.py


+ 198 - 0
petals/src/server/runtime.py

@@ -0,0 +1,198 @@
+import multiprocessing as mp
+import multiprocessing.pool
+import threading
+from collections import defaultdict
+from itertools import chain
+from queue import SimpleQueue
+from selectors import EVENT_READ, DefaultSelector
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple, Optional
+
+import torch
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+from prefetch_generator import BackgroundGenerator
+
+logger = get_logger(__name__)
+
+
+class Runtime(threading.Thread):
+    """
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
+    Runtime is usually created and managed by Server, humans need not apply.
+
+    For debugging, you can start runtime manually with .start() or .run()
+
+    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
+    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
+    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
+    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
+    >>> print("Returned:", future.result())
+    >>> runtime.shutdown()
+
+    :param module_backends: a dict [block uid -> ModuleBackend]
+    :param prefetch_batches: form up to this many batches in advance
+    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
+    :param device: if specified, moves all blocks and data to this device via .to(device=device).
+      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
+
+    :param stats_report_interval: interval to collect and log statistics about runtime performance
+    """
+
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
+
+    def __init__(
+        self,
+        module_backends: Dict[str, ModuleBackend],
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
+        device: torch.device = None,
+        stats_report_interval: Optional[int] = None,
+    ):
+        super().__init__()
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
+        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.shutdown_trigger = mp.Event()
+        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+
+        self.stats_report_interval = stats_report_interval
+        if self.stats_report_interval is not None:
+            self.stats_reporter = StatsReporter(self.stats_report_interval)
+
+    def run(self):
+        for pool in self.pools:
+            if not pool.is_alive():
+                pool.start()
+        if self.device is not None:
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)
+
+        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
+            try:
+                self.ready.set()
+                if self.stats_report_interval is not None:
+                    self.stats_reporter.start()
+                logger.info("Started")
+
+                batch_iterator = self.iterate_minibatches_from_pools()
+                if self.prefetch_batches > 0:
+                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
+
+                for pool, batch_index, batch in batch_iterator:
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
+
+                    start = time()
+                    try:
+                        outputs = pool.process_func(*batch)
+                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+
+                        batch_processing_time = time() - start
+
+                        batch_size = outputs[0].size(0)
+                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+
+                        if self.stats_report_interval is not None:
+                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+                    except KeyboardInterrupt:
+                        raise
+                    except BaseException as exception:
+                        logger.exception(f"Caught {exception}, attempting to recover")
+                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
+
+            finally:
+                if not self.shutdown_trigger.is_set():
+                    self.shutdown()
+
+    def shutdown(self):
+        """Gracefully terminate a running runtime."""
+        logger.info("Shutting down")
+        self.ready.clear()
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        logger.debug("Terminating pools")
+        for pool in self.pools:
+            if pool.is_alive():
+                pool.shutdown()
+        logger.debug("Pools terminated")
+
+        # trigger background thread to shutdown
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+        self.shutdown_trigger.set()
+
+    def iterate_minibatches_from_pools(self, timeout=None):
+        """
+        Chooses pool according to priority, then copies exposed batch and frees the buffer
+        """
+        with DefaultSelector() as selector:
+            for pool in self.pools:
+                selector.register(pool.batch_receiver, EVENT_READ, pool)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+
+            while True:
+                # wait until at least one batch_receiver becomes available
+                logger.debug("Waiting for inputs from task pools")
+                ready_fds = selector.select()
+                ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
+
+                logger.debug("Choosing the pool with first priority")
+
+                pool = min(ready_objects, key=lambda pool: pool.priority)
+
+                logger.debug(f"Loading batch from {pool.name}")
+                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                logger.debug(f"Loaded batch from {pool.name}")
+                yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
+
+
+class StatsReporter(threading.Thread):
+    def __init__(self, report_interval: int):
+        super().__init__()
+        self.report_interval = report_interval
+        self.stop = threading.Event()
+        self.stats_queue = SimpleQueue()
+
+    def run(self):
+        while not self.stop.wait(self.report_interval):
+            pool_batch_stats = defaultdict(list)
+            while not self.stats_queue.empty():
+                pool_uid, batch_stats = self.stats_queue.get()
+                pool_batch_stats[pool_uid].append(batch_stats)
+
+            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
+            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
+            for pool_uid, pool_stats in pool_batch_stats.items():
+                total_batches = len(pool_stats)
+                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
+                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
+                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
+                batches_to_time = total_batches / total_time
+                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
+
+                examples_to_time = total_examples / total_time
+                example_performance = f"{examples_to_time:.2f} " + (
+                    "examples/s" if examples_to_time > 1 else "s/example"
+                )
+
+                logger.info(
+                    f"{pool_uid}: "
+                    f"{total_batches} batches ({batch_performance}), "
+                    f"{total_examples} examples ({example_performance}), "
+                    f"avg batch size {avg_batch_size:.2f}"
+                )
+
+    def report_stats(self, pool_uid, batch_size, processing_time):
+        batch_stats = BatchStats(batch_size, processing_time)
+        self.stats_queue.put_nowait((pool_uid, batch_stats))

+ 0 - 0
src/server/server.py → petals/src/server/server.py


+ 178 - 0
petals/src/server/task_pool.py

@@ -0,0 +1,178 @@
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+from dataclasses import dataclass, field
+from queue import PriorityQueue
+from typing import Any, Generator, List, Optional, Sequence, Tuple
+
+import torch
+from hivemind import MPFuture, get_logger, use_hivemind_log_handler
+from hivemind.moe.server.task_pool import TaskPoolBase
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@dataclass(order=True, frozen=True)
+class Task:
+    priority: float
+    time_submitted: float
+    future: MPFuture = field(compare=False)
+    args: Sequence[torch.Tensor] = field(compare=False)
+
+    @property
+    def uid(self) -> int:
+        return self.future._uid
+
+
+class PrioritizedTaskPool(TaskPoolBase):
+    """
+    Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
+    returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
+    A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)
+
+    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
+      This would require grouping requests of different length.
+
+    :param process_func: function to be applied to every formed batch; called by Runtime
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
+    :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
+         Measured in the total number of tokens (i.e. batch size * sequence length)
+
+    :param name: pool name, used for logging
+    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param start: if True, start automatically at the end of __init__
+    """
+
+    def __init__(
+        self,
+        process_func: callable,
+        max_batch_size: int,
+        name: str,
+        min_batch_size=1,
+        daemon=True,
+        start=False,
+    ):
+        super().__init__(process_func, daemon=daemon, name=name)
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+
+        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
+        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
+
+        self._prioritizer_thread = threading.Thread(
+            name=self.name + "_prioritizer",
+            target=self._prioritize_tasks,
+            args=[self.submitted_tasks, self._ordered_tasks],
+            daemon=True,
+        )
+        self._dispatched_tasks = {}
+        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
+        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
+        self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+
+        self._stop = mp.Event()
+        if start:
+            self.start()
+
+    @staticmethod
+    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+        """Read tasks from incoming queue and put them into a local priority queue"""
+        while True:
+            task = submitted_tasks.get()
+            if task is None:
+                logger.debug("Shutting down prioritizer thread")
+                break
+
+            ordered_tasks.put(task, block=True)
+
+    def start(self):
+        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
+        self._prioritizer_thread.start()
+        super().start()
+
+    def shutdown(self, timeout: float = 3):
+        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
+        self._stop.set()
+
+        self.join(timeout)
+        if self.is_alive():
+            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+            self.terminate()
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+        """Add task to this pool's queue, return Future for its output"""
+        task = Task(priority, time.monotonic(), MPFuture(), args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
+            task.future.set_exception(exc)
+        else:
+            self.submitted_tasks.put(task)
+            self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
+            if (task.priority, task.time_submitted) < self.priority:
+                self.priority = (task.priority, task.time_submitted)
+        return task.future
+
+    def get_task_size(self, task: Task) -> int:
+        """compute task processing complexity; defaults to the total number of tokens"""
+        if task.args and task.args[0].ndim >= 2:
+            return task.args[0].shape[0] * task.args[0].shape[1]
+        return 1
+
+    def load_batch_to_runtime(
+        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
+    ) -> Tuple[Any, List[torch.Tensor]]:
+        """receive next batch of arrays"""
+        task = self._ordered_tasks.get(block=True, timeout=timeout)
+        batch_inputs = [
+            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
+        ]
+        self._dispatched_tasks[task.uid] = task
+        self.batch_receiver.recv()  # reduce the number of active batches
+        if not self._ordered_tasks.empty():
+            first_remaining_task: Task = self._ordered_tasks.queue[0]
+            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
+        return task.uid, batch_inputs
+
+    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
+        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
+        batch_outputs = [
+            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+            for tensor in batch_outputs
+        ]
+
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+            )
+        else:
+            task.future.set_result(batch_outputs)
+
+    def send_exception_from_runtime(self, uid: int, exception: BaseException):
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; "
+                f"Could not set exception {exception}"
+            )
+        else:
+            task.future.set_exception(exception)
+
+    def run(self, *args, **kwargs):
+        self._stop.wait()
+
+    @property
+    def empty(self):
+        return not self.batch_receiver.poll()
+
+    @property
+    def priority(self) -> Tuple[float, float]:
+        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
+        return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)
+
+    @priority.setter
+    def priority(self, item: Tuple[float, float]):
+        assert len(item) == 2
+        self._priority.value = float(item[0])
+        self._oldest_undispatched_timestamp.value = float(item[1])

+ 20 - 0
petals/src/server/task_prioritizer.py

@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+import torch
+from hivemind.moe.server.task_pool import Task
+
+
+class TaskPrioritizerBase(ABC):
+    """Abstract class for TaskPrioritizer whose reponsibility is to evaluate task priority"""
+
+    @abstractmethod
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        """Evaluates task value by the amout of points given, task input and additional kwargs. Lower priority is better"""
+        pass
+
+
+class DummyTaskPrioritizer(TaskPrioritizerBase):
+    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
+
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        return 0.0

+ 0 - 0
src/server/throughput.py → petals/src/server/throughput.py


+ 0 - 0
petals/src/utils/__init__.py


+ 0 - 0
src/utils/convert_8bit.py → petals/src/utils/convert_8bit.py


+ 0 - 0
src/utils/generation_algorithms.py → petals/src/utils/generation_algorithms.py


+ 0 - 0
src/utils/generation_constraints.py → petals/src/utils/generation_constraints.py


+ 0 - 0
src/utils/misc.py → petals/src/utils/misc.py


+ 0 - 0
petals/utils/__init__.py


+ 41 - 0
petals/utils/convert_8bit.py

@@ -0,0 +1,41 @@
+import os
+
+import bitsandbytes as bnb
+import torch
+
+PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 1)))
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
+    be kept as a `torch.nn.Linear` module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
+            model._modules[n] = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+                memory_efficient_backward=PETALS_8BIT_BACKWARD,
+            )
+            model._modules[n].weight = bnb.nn.Int8Params(
+                module.weight.data, requires_grad=False, has_fp16_weights=False
+            ).to(module.weight.dtype)
+    return model

+ 121 - 0
petals/utils/generation_algorithms.py

@@ -0,0 +1,121 @@
+from abc import ABC
+from typing import Tuple
+
+import torch
+
+TokenIds = torch.Tensor
+HypoIds = torch.Tensor
+
+
+class DecodingAlgorithm(ABC):
+    """
+    An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        """
+        pass
+
+
+class GreedyAlgorithm(DecodingAlgorithm):
+    """
+    The simpliest algorithm for decoding. It selects the most probable token.
+    """
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
+        """
+        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
+
+
+class SamplingAlgorithm(DecodingAlgorithm):
+    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
+        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
+        """
+        logits[indices_to_remove] = -float("Inf")
+        probs = torch.softmax(logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+class TopKAlgorithm(SamplingAlgorithm):
+    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        self.top_k = top_k
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
+        return self.sample(logits, indices_to_remove)
+
+
+class NucleusAlgorithm(SamplingAlgorithm):
+    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        self.top_p = top_p
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.softmax(sorted_logits / self.temperature, -1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+        sorted_indices_to_remove = cumulative_probs > self.top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = False
+        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+        return self.sample(logits, indices_to_remove)
+
+
+class BeamSearchAlgorithm(DecodingAlgorithm):
+    def __init__(self, num_beams: int, batch_size: int) -> None:
+        self.num_beams = num_beams
+        self._cur_num_beams = 1
+        self.batch_size = batch_size
+
+        self._batch_beams = [list() for _ in range(batch_size)]
+
+    def __call__(self, logits: torch.Tensor):
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.log_softmax(sorted_logits, -1)
+
+        if len(self._batch_beams[0]) > 0:
+            for batch_idx in range(self.batch_size):
+                new_beams = []
+                cur_beams = self._batch_beams[batch_idx]
+                for beam_idx in range(len(cur_beams)):
+                    probs_idx = batch_idx + beam_idx * self.batch_size
+                    new_beam = cur_beams[beam_idx]
+                    for hypo_idx in range(self.num_beams):
+                        new_beams.append(
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
+                        )
+                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
+        else:
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
+
+        return_hypos = []
+        return_tokens = []
+        for batch_idx in range(self.batch_size):
+            cur_beam = self._batch_beams[batch_idx]
+            return_hypos.append(list())
+            return_tokens.append(list())
+            for beam in cur_beam:
+                beam_idx = beam[1] // self.num_beams
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                token_idx = beam[1] % self.num_beams
+                return_hypos[-1].append(hypo_idx)
+                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
+        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
+        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
+
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)

+ 51 - 0
petals/utils/generation_constraints.py

@@ -0,0 +1,51 @@
+from abc import ABC
+
+import torch
+
+
+class ABCBloomConstraint(ABC):
+    """
+    Base class of all kind of decoding constraints. It can be used to implement a new constraint.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        """
+        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
+        :param tokens_id: The token id of the last choosen token.
+        :param logits: The logits from the Bloom model.
+        :param hypo_ids: The hypothesis ids of the last tokens.
+        """
+        pass
+
+
+class EosConstraint(ABCBloomConstraint):
+    """
+    This constrained repeats EOS token if it was generated on the previous step.
+    Args:
+        prefix: The prefix of the sequence.
+        eos_token_id: The id of the end of sentence token.
+        pad_token_id: The id of the padding token.
+        min_logits: The minimum logits that can be generated. Default: -1e6.
+    """
+
+    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
+        self.eos_token_id = eos_token_id
+        self.min_logits = min_logits
+        self.past_tokens = None
+
+        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        if self.past_tokens is not None:
+            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
+            logits += self.min_logits * mask
+            logits[mask[:, 0], self.eos_token_id] = 0
+
+        if tokens_id is not None:
+            self.past_tokens = tokens_id
+            self.wait_until_starting -= 1
+
+        return logits

+ 7 - 0
petals/utils/misc.py

@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+    return tensor.numel() == 0

+ 5 - 5
tests/test_block_exact_match.py

@@ -8,11 +8,11 @@ from hivemind import P2PHandlerError
 from test_utils import *
 
 import src
-from src import DistributedBloomConfig
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from petals.import DistributedBloomConfig
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.data_structures import UID_DELIMITER
+from petals.dht_utils import get_remote_module
 
 
 @pytest.mark.forked

+ 3 - 3
tests/test_chained_calls.py

@@ -10,9 +10,9 @@ import torch
 from test_utils import *
 
 import src
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
-from src.dht_utils import get_remote_sequence
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client.remote_sequential import RemoteSequential
+from petals.dht_utils import get_remote_sequence
 
 
 @pytest.mark.forked

+ 2 - 2
tests/test_full_model.py

@@ -5,8 +5,8 @@ from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
 from transformers.generation_utils import BeamSearchScorer
 
-from src.bloom.model import BloomForCausalLM
-from src.client.remote_model import DistributedBloomForCausalLM
+from petals.bloom.model import BloomForCausalLM
+from petals.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 2 - 2
tests/test_priority_pool.py

@@ -4,8 +4,8 @@ import time
 import pytest
 import torch
 
-from src.server.runtime import Runtime
-from src.server.task_pool import PrioritizedTaskPool
+from petals.server.runtime import Runtime
+from petals.server.task_pool import PrioritizedTaskPool
 
 
 @pytest.mark.forked

+ 3 - 3
tests/test_remote_sequential.py

@@ -3,9 +3,9 @@ import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
-from src import RemoteSequential
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_model import DistributedBloomConfig
+from petals.import RemoteSequential
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)