
Add local tensor-parallel fwd/bwd (#143)

This pull request adds an option to run a Petals server across multiple local GPUs using tensor parallelism, built on https://github.com/BlackSamorez/tensor_parallel (a minimal sketch of the wrapper follows the list below).

- 8-bit approximation error is the same as in main (mean ≈ 2%, q0.9 ≈ 5%)
    - verified for TP=1, 2, 3 (see screenshots above)
- forward, gradients w.r.t. inputs, and inference match main exactly at TP=1
- ≥80% GPU utilization with 3x 1080 Ti, batch = 8 tokens
- throughput measured with and without TP
- TP on the 1080 Tis gives near-linear speedup, comparable to the library's benchmarks (see first message)
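
For reference, a minimal sketch of what the `tensor_parallel` dependency does to a single module (illustrative only; it assumes the library's `TensorParallel` wrapper and its `device_ids` argument, and uses a toy module instead of the real `WrappedBloomBlock`):

```python
import torch
import torch.nn as nn
from tensor_parallel import TensorParallel

# toy stand-in for one transformer block: each nn.Linear is sliced so that
# every device holds a portion of every weight matrix
block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
tp_block = TensorParallel(block, device_ids=["cpu", "cpu"])  # e.g. ["cuda:0", "cuda:1"] on GPUs

x = torch.randn(8, 1024)
y = tp_block(x)      # runs on all shards and gathers the result on the output device
print(y.shape)       # torch.Size([8, 1024])
```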


Co-authored-by: Iaroslav Lisniak <yalisnyak@nes.ru>
Co-authored-by: Andrei Panferov <andrei@blacksamorez.ru>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic, 2 years ago
commit ae9e71fe8e

+ 3 - 3
.github/workflows/run-tests.yaml

@@ -86,16 +86,16 @@ jobs:
 
 
           sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
 
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
           SERVER3_PID=$!
 
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
           SERVER4_PID=$!
 
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu  --torch_dtype float32 &> server5.log &
           SERVER5_PID=$!
 
 
           tail -n 100 -f server*.log &

+ 1 - 0
setup.cfg

@@ -39,6 +39,7 @@ install_requires =
     protobuf>=3.20.3,<4.0dev
     speedtest-cli==2.1.3
     hivemind==1.1.3
+    tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
 
 

+ 5 - 1
src/petals/cli/run_server.py

@@ -129,8 +129,12 @@ def main():
 
 
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
     parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. "
+                        help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
                              "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
                              "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
 
 
     parser.add_argument("--skip_reachability_check", action='store_true',
                         help="Skip checking this server's reachability via health.petals.ml "

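The flag above accepts plain strings (one per device); a hedged sketch of how they end up as `torch.device` objects inside `Server.__init__` (see the server.py hunk further down), with hypothetical values:

```python
import torch

# e.g. launched with: --tensor_parallel_devices cuda:0 cuda:1
tensor_parallel_devices = ["cuda:0", "cuda:1"]
devices = tuple(map(torch.device, tensor_parallel_devices))
print(devices)  # (device(type='cuda', index=0), device(type='cuda', index=1))
```
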
+ 12 - 1
src/petals/data_structures.py

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
+import dataclasses
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 
 from hivemind import PeerID
 
 
+from petals.server.memory_cache import Handle
+
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@@ -39,3 +44,9 @@ class RemoteSpanInfo:
 
 
 
 
 RPCInfo = Dict[str, Any]
+
+
+@dataclasses.dataclass(frozen=True)
+class InferenceMetadata:
+    prefix_length: int
+    cache_handles: Tuple[Handle, ...]

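A toy illustration of the new `InferenceMetadata` record that replaces the old `cache_metadata` tensor (handle values are made up; a backend sharded across two devices owns one keys handle and one values handle per shard):

```python
from petals.data_structures import InferenceMetadata

metadata = InferenceMetadata(prefix_length=0, cache_handles=(0, 1, 2, 3))
print(metadata.prefix_length, metadata.cache_handles)  # 0 (0, 1, 2, 3)
```
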
+ 69 - 37
src/petals/server/backend.py

@@ -1,12 +1,19 @@
 """Code for serving bloom blocks via hivemind-server"""
 """Code for serving bloom blocks via hivemind-server"""
+from __future__ import annotations
+
+from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
 
 import torch
-from hivemind import BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
+from tensor_parallel import TensorParallel
+from tensor_parallel.tensor_parallel import PerDeviceTensors
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
 
 
-from petals.bloom.block import WrappedBloomBlock
+from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
@@ -17,9 +24,10 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
     """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
 
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
-        assert isinstance(self.module, WrappedBloomBlock)
+        assert isinstance(self.module, TensorParallel)
+        self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
 
         max_batch_size = self.forward_pool.max_batch_size
+        device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
 
         assert backend_dtype is not None
         self.dtype = backend_dtype
+        self.shard_num_heads = []
+        for shard in self.module.module_shards:
+            for submodule in shard.modules():
+                if isinstance(submodule, BloomAttention):
+                    self.shard_num_heads.append(submodule.num_heads)
+        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+
         self.inference_schema = (
             (
                 *self.args_schema,
@@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
 
+    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
+        """Create tensor descriptors for attention cache tensors used during inference_step"""
+        head_dim = self.config.hidden_size // self.config.n_head
+        cache_tensors = []
+        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
+            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
+            cache_tensors.extend((keys, values))
+        return cache_tensors
+
     def inference_step(
-        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
-
-            with self.memory_cache.use_cache(cache_handle) as cache:
-                batch_size = cache.shape[2]
-                max_length = cache.shape[-1] // (head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
-                if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == batch_size
-                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
-
-                key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-                value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
-                logger.debug(
-                    f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
-                )
-                hidden_states, (new_key, new_value) = self.module.forward(
-                    hidden_states, layer_past=(key_past, value_past), use_cache=True
-                )
-                new_length = new_key.shape[-1]
-                assert new_length > prefix_length
-                assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
-                assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
-                new_key = new_key.view(batch_size, num_heads, head_dim, -1)
-                new_value = new_value.view(batch_size, num_heads, -1, head_dim)
-                key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
-                value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+                self._reorder_cache_inplace(cache_tensors, hypo_ids)
+                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
                 return (hidden_states,)
 
 
+    def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
+        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
+        if not is_dummy(hypo_ids):
+            for cache_tensor in cache_tensors:
+                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+
+    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
+        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
+        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
+        for i in range(len(key_cache)):
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+        layer_past = tuple(chain(*zip(key_cache, value_cache)))
+        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
+
+    def _update_cache_inplace(
+        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
+    ):
+        """Writes new key/value tensors back into cache, works in-place"""
+        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
+            new_key = new_key.view(*cache_key.shape[:3], new_length)
+            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
+        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
+            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
+            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 
 

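The per-shard cache layout introduced by `get_inference_cache_descriptors` can be illustrated as follows (a sketch with assumed BLOOM-176B-like sizes and two equal shards; the real values come from the block config and `shard_num_heads`):

```python
import torch
from hivemind import TensorDescriptor

batch_size, max_length = 1, 2048
hidden_size, n_head = 14336, 112          # assumed BLOOM-176B numbers
head_dim = hidden_size // n_head          # 128
shard_num_heads = [56, 56]                # heads split evenly across two shards
devices = [torch.device("cpu"), torch.device("cpu")]  # e.g. cuda:0 / cuda:1 on GPUs

cache_tensors = []
for device, num_heads in zip(devices, shard_num_heads):
    keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=torch.float16, device=device)
    values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=torch.float16, device=device)
    cache_tensors.extend((keys, values))

print([descr.size for descr in cache_tensors])
# [(1, 56, 128, 2048), (1, 56, 2048, 128), (1, 56, 128, 2048), (1, 56, 2048, 128)]
```
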
+ 20 - 38
src/petals/server/handler.py

@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
+from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 
 import torch
@@ -8,10 +11,10 @@ from hivemind import (
     DHT,
     MSGPackSerializer,
     P2PContext,
-    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
+    nested_pack,
     serialize_torch_tensor,
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -21,8 +24,9 @@ from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.misc import DUMMY, is_dummy
@@ -122,17 +126,12 @@ class TransformerConnectionHandler(ConnectionHandler):
 
 
                 point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-
-                cache_metadata = torch.tensor(
-                    [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
-                )  # [cache_handle, rel_index, prefix_length]
                 prefix_length = 0
 
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
+                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                    assert len(cache_handles) == len(requested_backends)
                     while request.tensors:  # iterate while user is willing to supply tensors
-                        hidden_states, prompts, hypo_ids = [
-                            deserialize_torch_tensor(tensor) for tensor in request.tensors
-                        ]
+                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -155,16 +154,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                             )
 
 
                         # run request tensors through all requested modules, update caches
-                        for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
+                        for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
                             if not is_dummy(prompt):
                                 hidden_states[:, : prompt.shape[1]] += prompt
                             if hidden_states.numel() == 0:
                                 continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
                                 # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
 
-                            cache_metadata[:] = torch.tensor(
-                                [cache_handle, rel_index, prefix_length], dtype=torch.int64
-                            )
+                            metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
                             assert isinstance(
                                 hidden_states, torch.Tensor
                             ), f"hidden states must be tensor, got {type(hidden_states)}"
@@ -175,7 +172,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 backend.inference_pool, PrioritizedTaskPool
                             ), "petals support only prioritized pools"
                             priority = self._prioritizer.prioritize(
-                                cache_metadata,
                                 hidden_states,
                                 hypo_ids,
                                 points=point_per_piece / len(requested_backends),
@@ -183,7 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 type="inference",
                             )
                             (hidden_states,) = await backend.inference_pool.submit_task(
-                                hidden_states, hypo_ids, cache_metadata, priority=priority
+                                hidden_states, hypo_ids, metadata, priority=priority
                             )
 
 
                         # serialize and send last layer outputs
@@ -355,28 +351,14 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
-    ) -> Sequence[int]:
-        """Allocate memory cache for all transformer blocks, return cache handle"""
-
-        n_blocks = len(backends)
-        backend = backends[0]
-        n_heads = backend.module.self_attention.num_heads
-        head_dim = backend.module.self_attention.head_dim
-        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-
-        gib = 1024**3
-        cur_size = backend.memory_cache.current_size_bytes
-        max_size = backend.memory_cache.max_size_bytes
-        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-        logger.info(
-            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
-            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-        )
-
-        async with backend.memory_cache.allocate_cache(descr) as handle:
-            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
-            yield handle
+    ) -> Sequence[Sequence[Handle, ...]]:
+        """
+        Allocate memory cache for all transformer blocks, return cache handle
+        :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
+        """
+        descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+            yield nested_pack(handles, descriptors)
 
 
     def _log_request(
         self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None

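`_allocate_cache` now returns one tuple of handles per backend; hivemind's `nested_pack` regroups the flat handle list produced by `MemoryCache.allocate_cache` back into the nested structure of the descriptors. A rough illustration with made-up values:

```python
from hivemind import nested_pack

# two requested backends, each with two cache descriptors (keys + values of a single shard);
# strings stand in for TensorDescriptor objects, ints for allocated handles
descriptors = [("keys_0", "values_0"), ("keys_1", "values_1")]
flat_handles = (10, 11, 12, 13)

cache_handles = nested_pack(flat_handles, descriptors)
print(cache_handles)  # expected: [(10, 11), (12, 13)]
```
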
+ 57 - 36
src/petals/server/memory_cache.py

@@ -10,7 +10,7 @@ import ctypes
 import multiprocessing as mp
 import os
 import time
-from typing import AsyncContextManager, Dict, Optional, Union
+from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
 
 
 import hivemind
 import torch
@@ -26,10 +26,9 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
 
-    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.alloc_timeout = alloc_timeout
-        self.device = device
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -57,26 +56,48 @@ class MemoryCache:
         self._handle_counter.value = value
 
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
         """
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
 
-        :param descr: allocate a tensor of this size, dtype, etc
+        :param descriptors: one or more tensors tensor of this size, dtype, etc
+
+        :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
+          if not, it will count maximum tensor allocation across devices for the purposes of size limit
 
 
         :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
         Furthermore, it can be called concurrently with at most one use_cache call in runtime.
         """
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
-        assert descr.device is None and descr
-
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
+        assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        max_alloc_size = self.get_allocation_size(*descriptors)
+
+        gib = 1024**3
+        cur_size, max_size = self.current_size_bytes, self.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
         try:
-            yield await shield_and_wait(alloc_task)
+            handles = await shield_and_wait(alloc_task)
+            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            yield handles
         finally:
-            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
-
-    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+
+    @staticmethod
+    def get_allocation_size(*descriptors: TensorDescriptor) -> int:
+        """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
+        alloc_size_by_device = {}
+        for descr in descriptors:
+            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
+        return max(alloc_size_by_device.values())
+
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
         """
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -87,11 +108,11 @@ class MemoryCache:
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                handle = int(self.handle_counter)
+                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handle, descr))
-                return handle
+                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handles, descriptors))
+                return handles
 
 
     async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
         """
         """
@@ -102,10 +123,10 @@ class MemoryCache:
 
 
         if alloc_task.exception() is not None:
             return
-        handle = alloc_task.result()
+        handles = alloc_task.result()
 
 
         async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
         self._memory_freed_event.set()
 
 
@@ -125,11 +146,11 @@ class MemoryCache:
             self._memory_freed_event.clear()
 
 
     @contextlib.contextmanager
-    def use_cache(self, handle: Handle) -> torch.Tensor:
+    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
         """
         """
-        Return a tensor that was previously allocated with try_allocate_cache,
+        Return one or more tensors previously allocated with allocate_cache,
 
 
-        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
         However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
         """
         """
         assert os.getpid() == self.runtime_pid
@@ -138,20 +159,20 @@ class MemoryCache:
         with self._lock_metadata:
             # read creation/deletion requests from connection handlers
             while self._pipe_recv.poll():
-                recv_handle, recv_data = self._pipe_recv.recv()
-                if isinstance(recv_data, TensorDescriptor):
-                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
-                elif recv_data is None:
-                    if recv_handle not in self._allocated_tensors:
-                        logger.warning(
-                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
-                        )
-                    self._allocated_tensors.pop(recv_handle, None)
-                else:
-                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
-
-        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-        yield self._allocated_tensors[handle]
+                recv_handles, recv_data = self._pipe_recv.recv()
+                if recv_data is not None:  # create new tensors
+                    assert len(recv_handles) == len(recv_data)
+                    for handle, descr in zip(recv_handles, recv_data):
+                        self._allocated_tensors[handle] = descr.make_zeros()
+                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+                else:  # delete tensors by handle
+                    for handle in recv_handles:
+                        if handle not in self._allocated_tensors:
+                            logger.warning(
+                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                            )
+                        self._allocated_tensors.pop(handle, None)
+        yield tuple(self._allocated_tensors[handle] for handle in handles)
 
 
 
 
 class AllocationFailed(Exception):

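A worked example of the new `get_allocation_size` accounting (made-up shapes; assumes petals with this change installed). Only the largest per-device total is charged against the cache limit, which is why the descriptors are expected to be balanced across devices:

```python
import torch
from hivemind import TensorDescriptor
from petals.server.memory_cache import MemoryCache

# two balanced shards: each device gets one keys and one values tensor of equal size
keys_0 = TensorDescriptor((1, 56, 128, 2048), dtype=torch.float16, device=torch.device("cuda:0"))
values_0 = TensorDescriptor((1, 56, 2048, 128), dtype=torch.float16, device=torch.device("cuda:0"))
keys_1 = TensorDescriptor((1, 56, 128, 2048), dtype=torch.float16, device=torch.device("cuda:1"))
values_1 = TensorDescriptor((1, 56, 2048, 128), dtype=torch.float16, device=torch.device("cuda:1"))

# each tensor: 1 * 56 * 128 * 2048 elements * 2 bytes = 28 MiB, so ~56 MiB per device;
# the grand total is ~112 MiB, but only the per-device maximum counts toward max_size_bytes
print(MemoryCache.get_allocation_size(keys_0, values_0, keys_1, values_1) / 1024**2)  # 56.0
```
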
+ 42 - 15
src/petals/server/server.py

@@ -6,7 +6,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 
 import numpy as np
 import psutil
@@ -29,7 +29,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 
 logger = get_logger(__file__)
@@ -76,6 +76,7 @@ class Server:
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: Optional[bool] = None,
+        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         **kwargs,
     ):
@@ -128,6 +129,8 @@ class Server:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         device = torch.device(device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(device.type, index=0)
         self.device = device
 
 
         if isinstance(torch_dtype, str):
@@ -141,6 +144,13 @@ class Server:
             logger.info("Model weights will be loaded in 8-bit format")
         self.load_in_8bit = load_in_8bit
 
 
+        if tensor_parallel_devices is None:
+            tensor_parallel_devices = (device,)
+        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
+        if len(self.tensor_parallel_devices) > 1:
+            logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
+            check_device_balance(self.tensor_parallel_devices)
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -174,6 +184,7 @@ class Server:
                 device,
                 torch_dtype,
                 load_in_8bit=load_in_8bit,
+                tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
@@ -214,13 +225,28 @@ class Server:
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
         assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
+
+        if num_devices > 1:
+            memory_per_device = tuple(
+                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
+            )
+            total_memory = min(memory_per_device) * num_devices
+            if max(memory_per_device) / min(memory_per_device) > 1.5:
+                raise ValueError(
+                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
+                    "Please launch individual servers on each GPU or set --num_blocks manually to "
+                    "override this exception."
+                )
+        else:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
 
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         gib = 1024**3
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
 
 
-        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
+        autograd_memory = 2 * gib * num_devices  # gpu memory used for intermediate tensors in rpc_backward
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
 
         logger.info(
         logger.info(
@@ -260,6 +286,7 @@ class Server:
                 sender_threads=self.sender_threads,
                 sender_threads=self.sender_threads,
                 use_auth_token=self.use_auth_token,
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 load_in_8bit=self.load_in_8bit,
+                tensor_parallel_devices=self.tensor_parallel_devices,
                 start=True,
                 start=True,
             )
             )
             try:
             try:
@@ -352,6 +379,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         expiration: Optional[float],
         use_auth_token: Optional[str],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         load_in_8bit: bool,
+        tensor_parallel_devices: Sequence[torch.device],
         **kwargs,
         **kwargs,
     ) -> ModuleContainer:
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -367,7 +395,9 @@ class ModuleContainer(threading.Thread):
         joining_announcer.start()
         joining_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
         logger.info(f"Announced that blocks {block_indices} are joining")
 
 
-        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
+
+        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
         blocks = {}
         blocks = {}
         try:
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -380,18 +410,13 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                     max_disk_space=max_disk_space,
                 )
                 )
+                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
 
 
-                if load_in_8bit:
-                    block = replace_8bit_linear(block)
-
-                block = block.to(device)
-                for param in block.parameters():
-                    param.requires_grad = False
-
-                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
+                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     module_uid,
                     block,
                     block,
+                    config=block_config,
                     memory_cache=memory_cache,
                     memory_cache=memory_cache,
                     backend_dtype=backend_dtype,
                     backend_dtype=backend_dtype,
                     args_schema=(
                     args_schema=(
@@ -451,6 +476,7 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         request_timeout: float,
         session_timeout: float,
         session_timeout: float,
         step_timeout: float,
         step_timeout: float,
+        device: Union[str, torch.device],
         start: bool,
         start: bool,
         **kwargs,
         **kwargs,
     ):
     ):
@@ -469,7 +495,8 @@ class ModuleContainer(threading.Thread):
             )
             )
             for _ in range(num_handlers)
             for _ in range(num_handlers)
         ]
         ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
+        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             list(self.module_backends.keys()),
             dht,
             dht,

+ 18 - 10
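The updated `_choose_num_blocks` heuristic, with assumed numbers plugged in (the real `block_size` comes from `get_block_size`; the 0.5 GiB attention-cache and 2 GiB autograd reserves are now multiplied by the number of tensor-parallel devices):

```python
import math

gib = 1024**3
num_devices = 2
memory_per_device = (24 * gib, 24 * gib)              # e.g. two identical 24 GiB GPUs
total_memory = min(memory_per_device) * num_devices   # uneven GPUs would trip the 1.5x check above

block_size = 3.5 * gib                                # assumed 8-bit footprint of one block
attn_cache_per_block = 0.5 * gib * num_devices
autograd_memory = 2 * gib * num_devices               # reserve for rpc_backward intermediates

num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
print(num_blocks)  # 9 with these assumptions
```
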
src/petals/server/task_pool.py

@@ -5,7 +5,7 @@ import time
 from concurrent.futures._base import PENDING
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
 from queue import PriorityQueue
 from queue import PriorityQueue
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
 from hivemind import get_logger
 from hivemind import get_logger
@@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
 
     :param name: pool name, used for logging
     :param name: pool name, used for logging
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param device: if specified, input tensors will be moved to that device by default
     :param start: if True, start automatically at the end of __init__
     :param start: if True, start automatically at the end of __init__
     """
     """
 
 
@@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         max_batch_size: int,
         max_batch_size: int,
         name: str,
         name: str,
         min_batch_size=1,
         min_batch_size=1,
+        device: Optional[torch.device] = None,
         daemon=True,
         daemon=True,
         start=False,
         start=False,
     ):
     ):
         super().__init__(process_func, daemon=daemon, name=name)
         super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.device = device
 
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
@@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
             self.terminate()
 
 
-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
         # Remove shmem from MPFuture. This disables the .cancel() feature but
@@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase):
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
     ) -> Tuple[Any, List[torch.Tensor]]:
     ) -> Tuple[Any, List[torch.Tensor]]:
         """receive next batch of arrays"""
         """receive next batch of arrays"""
+        device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [
-            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
-        ]
+        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
         self._dispatched_tasks[task.uid] = task
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
         if not self._ordered_tasks.empty():
@@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [
-            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
-            for tensor in batch_outputs
-        ]
-
+        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
         task = self._dispatched_tasks.pop(uid, None)
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
         if task is None:
             logger.error(
             logger.error(
@@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         assert len(item) == 2
         assert len(item) == 2
         self._priority.value = float(item[0])
         self._priority.value = float(item[0])
         self._oldest_undispatched_timestamp.value = float(item[1])
         self._oldest_undispatched_timestamp.value = float(item[1])
+
+
+def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False):
+    if isinstance(arg, torch.Tensor):
+        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
+        # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor
+        # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)
+        if share_memory:
+            arg = arg.share_memory_()
+    return arg

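A small usage sketch of the new `_move_to_device_if_tensor` helper (assumes this change is installed; the helper is private): tensors are detached and moved, while non-tensor arguments such as `InferenceMetadata` pass through the pool unchanged.

```python
import torch
from petals.server.task_pool import _move_to_device_if_tensor

x = torch.randn(2, 3, requires_grad=True)
moved = _move_to_device_if_tensor(x, device="cpu", share_memory=True)
print(moved.requires_grad, moved.is_shared())  # True True

meta = {"prefix_length": 0}  # stands in for InferenceMetadata
assert _move_to_device_if_tensor(meta, device="cpu") is meta  # returned as-is
```
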
+ 24 - 9
src/petals/server/throughput.py

@@ -2,9 +2,10 @@ import fcntl
 import json
 import os
 import time
+from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 
 
 import torch
 from hivemind.utils.logging import get_logger
@@ -12,7 +13,7 @@ from transformers import BloomConfig
 
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 
 logger = get_logger(__file__)
@@ -37,6 +38,7 @@ def get_host_throughput(
     dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> float:
@@ -57,6 +59,9 @@ def get_host_throughput(
         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
         cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
         cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+        if len(tensor_parallel_devices) > 1:
+            for i, device_i in enumerate(tensor_parallel_devices):
+                cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
 
 
         cache = {}
         cache = {}
         try:
         try:
@@ -69,7 +74,9 @@ def get_host_throughput(
             cache = {}
             cache = {}
 
 
         if cache_key not in cache:
         if cache_key not in cache:
-            cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
+            cache[cache_key] = measure_throughput_info(
+                config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+            )
 
 
             try:
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
                 os.makedirs(cache_path.parent, exist_ok=True)
@@ -87,6 +94,7 @@ def measure_throughput_info(
     dtype: torch.dtype,
     dtype: torch.dtype,
     *,
     *,
     load_in_8bit: bool,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
 ) -> float:
 ) -> float:
     """Measure network and compute throughput in forward pass tokens per second"""
     """Measure network and compute throughput in forward pass tokens per second"""
 
 
@@ -95,7 +103,9 @@ def measure_throughput_info(
     )
     )
     return min(
     return min(
         measure_network_rps(config),
         measure_network_rps(config),
-        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
+        measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        ),
     )
     )
 
 
 
 
@@ -129,14 +139,15 @@ def measure_compute_rps(
     dtype: torch.dtype,
     dtype: torch.dtype,
     *,
     *,
     load_in_8bit: bool,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     n_tokens: int = 16,
     n_tokens: int = 16,
     n_steps: int = 500,
     n_steps: int = 500,
 ) -> float:
 ) -> float:
+    if not tensor_parallel_devices:
+        tensor_parallel_devices = (device,)
     with torch.inference_mode():
     with torch.inference_mode():
         block = WrappedBloomBlock(config).to(dtype)
         block = WrappedBloomBlock(config).to(dtype)
-        if load_in_8bit:
-            block = replace_8bit_linear(block)
-        block = block.to(device)
+        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
 
 
         cache = None
         cache = None
         elapsed = 0
         elapsed = 0
@@ -149,9 +160,13 @@ def measure_compute_rps(
                 elapsed += time.perf_counter() - start_time
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
         device_rps = n_steps * n_tokens / elapsed
 
 
+    devices_repr = get_device_name(device)
+    if len(tensor_parallel_devices) > 1:
+        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
+        devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
+
     logger.info(
     logger.info(
-        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
-        f"{device_rps:.1f} RPS"
+        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
     )
     )
     return device_rps
     return device_rps
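
To illustrate the new device summary in the log message above: with several identical GPUs, the Counter-based formatting collapses them into a single "Nx <device name>" entry (a standalone sketch, device name hypothetical):

    from collections import Counter

    device_names = ("NVIDIA GeForce GTX 1080 Ti",) * 3  # hypothetical tensor-parallel devices
    devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
    print(devices_repr)  # -> 3x NVIDIA GeForce GTX 1080 Ti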
 
 

+ 0 - 39
src/petals/utils/convert_8bit.py

@@ -1,39 +0,0 @@
-import bitsandbytes as bnb
-import torch
-
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
-
-
-def replace_8bit_linear(model, threshold=6.0):
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
-    for n, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
-
-        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
-            model._modules[n] = CustomLinear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
-            model._modules[n].bias = module.bias
-    return model

+ 132 - 0
src/petals/utils/convert_block.py

@@ -0,0 +1,132 @@
+"""
+Tools for converting transformer blocks, applying quantization and/or tensor parallelism
+"""
+import re
+from typing import Sequence
+
+import bitsandbytes as bnb
+import tensor_parallel as tp
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tensor_parallel.slicing_configs import get_bloom_config
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+from petals.bloom.block import WrappedBloomBlock
+from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def convert_block(
+    block: WrappedBloomBlock,
+    config: BloomConfig,
+    tensor_parallel_devices: Sequence[torch.device],
+    output_device: torch.device,
+    load_in_8bit: bool,
+    threshold: float = 6.0,
+    freeze: bool = True,
+) -> tp.TensorParallel:
+    """
+    Optimize a transformer block for use in a Petals server, applying tensor parallelism and/or LLM.int8() quantization
+
+    :note: some optimizations will modify the input block in-place!
+    :param block: a single transformer block, either pre-trained or newly initialized
+    :param config: HF transformers config for the full model
+    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
+    :note: if there is only a single device, the block will still be wrapped with TensorParallel (for uniformity)
+    :param output_device: the device onto which the block's outputs will be gathered
+    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
+    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param freeze: if True (default), make all module parameters non-trainable
+    :return: a module that acts like the original block, but runs with all specified optimizations
+
+    """
+    if freeze:
+        for param in block.parameters():
+            param.requires_grad = False
+
+    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
+
+    if load_in_8bit:
+        block = replace_8bit_linear(block, threshold=threshold)
+
+    for shard, device in zip(block.module_shards, block.devices):
+        shard.to(device)
+
+    return block
+
+
+def replace_8bit_linear(model: nn.Module, threshold=6.0):
+    """
+    A helper function that converts all `torch.nn.Linear` modules to 8-bit `CustomLinear8bitLt` modules (a patched
+    `bnb.nn.Linear8bitLt` from the `bitsandbytes` library). This enables running models in mixed int8 precision as
+    described in the paper `GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale`. Make sure a
+    `bitsandbytes` build compiled for your CUDA version is installed before running this function:
+    `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX`, where `XXX` is your CUDA version
+    (e.g., 116 for CUDA 11.6).
+    The function runs recursively and replaces all `torch.nn.Linear` modules except `lm_head` and `score`, which are
+    kept as `torch.nn.Linear` modules.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module`, as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
+            `6.0` as described in the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
+            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
+            model._modules[n] = CustomLinear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            )
+            model._modules[n].weight = bnb.nn.Int8Params(
+                module.weight.data, requires_grad=False, has_fp16_weights=False
+            ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
+    return model
+
+
+def make_tensor_parallel(
+    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
+):
+    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
+    tp_config = get_bloom_config(model_config, devices)
+    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
+    total_heads = 0
+    for tp_shard in tp_block.module_shards:
+        for submodule in tp_shard.modules():
+            if isinstance(submodule, BloomAttention):
+                total_heads += submodule.num_heads
+    assert total_heads == model_config.n_head
+    return tp_block
+
+
+def check_device_balance(devices: Sequence[torch.device]):
+    if not all(device.type == "cuda" for device in devices):
+        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
+        return
+    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
+    if len(unique_device_capabilities) > 1:
+        logger.warning(
+            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
+            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
+        )
+
+    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
+    used_memory = min(memory_per_device) * len(memory_per_device)
+    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
+    if wasted_memory_rate > 0.05:
+        logger.warning(
+            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
+            f"Consider running high-memory GPUs in a separate server."
+        )
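
To show how the pieces of the new module fit together, here is a minimal end-to-end sketch (assumptions: the checkpoint name is a placeholder, load_pretrained_block is reused from petals.bloom.from_pretrained as in the tests below, and CPU devices stand in for local GPUs):

    import torch
    from transformers import BloomConfig
    from petals.bloom.from_pretrained import load_pretrained_block
    from petals.utils.convert_block import check_device_balance, convert_block

    devices = (torch.device("cpu"), torch.device("cpu"))  # two local GPUs in a real setup
    check_device_balance(devices)  # warns about mismatched GPUs (or about running TP on non-GPU devices)

    model_name = "bigscience/bloom-560m"  # placeholder checkpoint
    config = BloomConfig.from_pretrained(model_name)
    block = load_pretrained_block(model_name, block_index=0, torch_dtype=torch.float32)
    block = convert_block(block, config, devices, output_device=devices[0], load_in_8bit=False, freeze=True)

    with torch.inference_mode():
        hidden_states = torch.randn(1, 8, config.hidden_size)
        outputs = block(hidden_states)  # the block now runs as a tp.TensorParallel wrapper with the same output structure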

+ 9 - 2
tests/test_aux_functions.py

@@ -7,10 +7,17 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_throughput_basic():
+@pytest.mark.parametrize("tensor_parallel", [False, True])
+def test_throughput_basic(tensor_parallel: bool):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
     compute_rps = measure_compute_rps(
-        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
+        config,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+        load_in_8bit=False,
+        tensor_parallel_devices=tensor_parallel_devices,
+        n_steps=10,
     )
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
     assert isinstance(compute_rps, float) and compute_rps > 0
     network_rps = measure_network_rps(config)
     network_rps = measure_network_rps(config)

+ 1 - 1
tests/test_block_exact_match.py

@@ -13,7 +13,7 @@ from petals.dht_utils import get_remote_module
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
 
 

+ 11 - 10
tests/test_remote_sequential.py

@@ -1,6 +1,7 @@
 import pytest
 import pytest
 import torch
 import torch
-from hivemind import DHT, BatchTensorDescriptor, get_logger
+import torch.nn.functional as F
+from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from test_utils import *
 from test_utils import *
 
 
@@ -39,10 +40,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
 
 
     (second_half_outputs * grad_proj).sum().backward()
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
 
 
     # test RemoteSequential with lossy compression
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
@@ -58,7 +59,7 @@ def test_remote_sequential():
     assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
     assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
     absmax = abs(full_grad).max()
-    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
+    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
 
 
 
 
 class DummyCustomSequenceManager(RemoteSequenceManager):
 class DummyCustomSequenceManager(RemoteSequenceManager):
@@ -87,9 +88,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     remote_sequential = RemoteSequential(config, dht)
     remote_sequential = RemoteSequential(config, dht)
 
 
-    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
-    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
-    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
+    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
+    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
     intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
     intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
 
 
     input_prompts = input_prompts.detach().requires_grad_(True)
     input_prompts = input_prompts.detach().requires_grad_(True)
@@ -117,10 +118,10 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)
         (outputs_ref,) = block(outputs_ref)
 
 
-    assert torch.allclose(outputs_ref, outputs)
+    assert torch.allclose(outputs_ref, outputs, atol=1e-3)
 
 
     (outputs_ref * output_proj).sum().backward()
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
     assert intermediate_prompts_ref.grad is not None
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
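
A note on the switch to F.normalize above: presumably the unit-norm inputs keep activation magnitudes bounded, so the relaxed absolute tolerances stay meaningful; a tiny sketch of the effect (standalone, not part of the tests):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 5, 1024)
    print(x.norm(dim=-1).mean())                       # roughly sqrt(1024) ~= 32 for standard-normal inputs
    print(F.normalize(x, dim=-1).norm(dim=-1).mean())  # ~= 1 after normalization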

+ 46 - 0
tests/test_tensor_parallel.py

@@ -0,0 +1,46 @@
+import random
+
+import pytest
+import torch
+import transformers
+from tensor_parallel import TensorParallel
+from tensor_parallel.slicing_configs import get_bloom_config
+from test_utils import MODEL_NAME
+
+from petals.bloom.from_pretrained import load_pretrained_block
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("custom_config", [True, False])
+@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
+def test_tp_block(devices, custom_config):
+    block_index = random.randint(0, 10)
+    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
+
+    tp_config = None
+    if custom_config:
+        tp_config = get_bloom_config(model_config, devices)
+
+    batch_size = 2
+    prefix_length = 5
+
+    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
+    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
+    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
+    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
+    grad_proj = torch.rand_like(test_inputs1)
+
+    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
+    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
+    y_ref.backward(grad_proj)
+
+    block_tp = TensorParallel(block, devices, config=tp_config)
+    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
+    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
+    y_ours.backward(grad_proj)
+
+    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
+    assert torch.allclose(y_ours, y_ref, atol=1e-6)
+    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
+    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)