Forráskód Böngészése

Add local tensor-parallel fwd/bwd (#143)

This pull request adds an option to run Petals server on multiple local GPUs. It uses https://github.com/BlackSamorez/tensor_parallel

- 8bit approximation error same as in main (mean~=2% q0.9~=5%)
    - TP=1, 2, 3 (see screenshots above)
- forward, grad w.r.t. input and inference exact match with main with TP=1
- `>=`80% GPU utilization with 3x 1080ti, batch = 8 tokens
- throughput measured with and without TP
- TP on 1080Tis has near-linear speedup comparable to the benchmarks (see first message)


Co-authored-by: Iaroslav Lisniak <yalisnyak@nes.ru>
Co-authored-by: Andrei Panferov <andrei@blacksamorez.ru>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 éve
szülő
commit
ae9e71fe8e

+ 3 - 3
.github/workflows/run-tests.yaml

@@ -86,16 +86,16 @@ jobs:
 
           sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
           SERVER3_PID=$!
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
           SERVER4_PID=$!
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu  --torch_dtype float32 &> server5.log &
           SERVER5_PID=$!
 
           tail -n 100 -f server*.log &

+ 1 - 0
setup.cfg

@@ -39,6 +39,7 @@ install_requires =
     protobuf>=3.20.3,<4.0dev
     speedtest-cli==2.1.3
     hivemind==1.1.3
+    tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
 

+ 5 - 1
src/petals/cli/run_server.py

@@ -129,8 +129,12 @@ def main():
 
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
     parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. "
+                        help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
                              "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
 
     parser.add_argument("--skip_reachability_check", action='store_true',
                         help="Skip checking this server's reachability via health.petals.ml "

+ 12 - 1
src/petals/data_structures.py

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
+import dataclasses
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 from hivemind import PeerID
 
+from petals.server.memory_cache import Handle
+
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@@ -39,3 +44,9 @@ class RemoteSpanInfo:
 
 
 RPCInfo = Dict[str, Any]
+
+
+@dataclasses.dataclass(frozen=True)
+class InferenceMetadata:
+    prefix_length: int
+    cache_handles: Tuple[Handle, ...]

+ 69 - 37
src/petals/server/backend.py

@@ -1,12 +1,19 @@
 """Code for serving bloom blocks via hivemind-server"""
+from __future__ import annotations
+
+from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
 import torch
-from hivemind import BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
+from tensor_parallel import TensorParallel
+from tensor_parallel.tensor_parallel import PerDeviceTensors
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
 
-from petals.bloom.block import WrappedBloomBlock
+from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
@@ -17,9 +24,10 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
-        assert isinstance(self.module, WrappedBloomBlock)
+        assert isinstance(self.module, TensorParallel)
+        self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
+        device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
         assert backend_dtype is not None
         self.dtype = backend_dtype
+        self.shard_num_heads = []
+        for shard in self.module.module_shards:
+            for submodule in shard.modules():
+                if isinstance(submodule, BloomAttention):
+                    self.shard_num_heads.append(submodule.num_heads)
+        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+
         self.inference_schema = (
             (
                 *self.args_schema,
@@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
+    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
+        """Create tensor descriptors for attention cache tensors used during inference_step"""
+        head_dim = self.config.hidden_size // self.config.n_head
+        cache_tensors = []
+        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
+            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
+            cache_tensors.extend((keys, values))
+        return cache_tensors
+
     def inference_step(
-        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
-
-            with self.memory_cache.use_cache(cache_handle) as cache:
-                batch_size = cache.shape[2]
-                max_length = cache.shape[-1] // (head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
-                if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == batch_size
-                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
-
-                key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-                value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
-                logger.debug(
-                    f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
-                )
-                hidden_states, (new_key, new_value) = self.module.forward(
-                    hidden_states, layer_past=(key_past, value_past), use_cache=True
-                )
-                new_length = new_key.shape[-1]
-                assert new_length > prefix_length
-                assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
-                assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
-                new_key = new_key.view(batch_size, num_heads, head_dim, -1)
-                new_value = new_value.view(batch_size, num_heads, -1, head_dim)
-                key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
-                value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+                self._reorder_cache_inplace(cache_tensors, hypo_ids)
+                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
                 return (hidden_states,)
 
+    def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
+        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
+        if not is_dummy(hypo_ids):
+            for cache_tensor in cache_tensors:
+                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+
+    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
+        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
+        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
+        for i in range(len(key_cache)):
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+        layer_past = tuple(chain(*zip(key_cache, value_cache)))
+        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
+
+    def _update_cache_inplace(
+        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
+    ):
+        """Writes new key/value tensors back into cache, works in-place"""
+        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
+            new_key = new_key.view(*cache_key.shape[:3], new_length)
+            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
+        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
+            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
+            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 

+ 20 - 38
src/petals/server/handler.py

@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
+from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -8,10 +11,10 @@ from hivemind import (
     DHT,
     MSGPackSerializer,
     P2PContext,
-    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
+    nested_pack,
     serialize_torch_tensor,
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -21,8 +24,9 @@ from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.misc import DUMMY, is_dummy
@@ -122,17 +126,12 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-
-                cache_metadata = torch.tensor(
-                    [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
-                )  # [cache_handle, rel_index, prefix_length]
                 prefix_length = 0
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
+                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                    assert len(cache_handles) == len(requested_backends)
                     while request.tensors:  # iterate while user is willing to supply tensors
-                        hidden_states, prompts, hypo_ids = [
-                            deserialize_torch_tensor(tensor) for tensor in request.tensors
-                        ]
+                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -155,16 +154,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                             )
 
                         # run request tensors through all requested modules, update caches
-                        for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
+                        for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
                             if not is_dummy(prompt):
                                 hidden_states[:, : prompt.shape[1]] += prompt
                             if hidden_states.numel() == 0:
                                 continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
                                 # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
-                            cache_metadata[:] = torch.tensor(
-                                [cache_handle, rel_index, prefix_length], dtype=torch.int64
-                            )
+                            metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
                             assert isinstance(
                                 hidden_states, torch.Tensor
                             ), f"hidden states must be tensor, got {type(hidden_states)}"
@@ -175,7 +172,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 backend.inference_pool, PrioritizedTaskPool
                             ), "petals support only prioritized pools"
                             priority = self._prioritizer.prioritize(
-                                cache_metadata,
                                 hidden_states,
                                 hypo_ids,
                                 points=point_per_piece / len(requested_backends),
@@ -183,7 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 type="inference",
                             )
                             (hidden_states,) = await backend.inference_pool.submit_task(
-                                hidden_states, hypo_ids, cache_metadata, priority=priority
+                                hidden_states, hypo_ids, metadata, priority=priority
                             )
 
                         # serialize and send last layer outputs
@@ -355,28 +351,14 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
-    ) -> Sequence[int]:
-        """Allocate memory cache for all transformer blocks, return cache handle"""
-
-        n_blocks = len(backends)
-        backend = backends[0]
-        n_heads = backend.module.self_attention.num_heads
-        head_dim = backend.module.self_attention.head_dim
-        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-
-        gib = 1024**3
-        cur_size = backend.memory_cache.current_size_bytes
-        max_size = backend.memory_cache.max_size_bytes
-        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-        logger.info(
-            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
-            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-        )
-
-        async with backend.memory_cache.allocate_cache(descr) as handle:
-            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
-            yield handle
+    ) -> Sequence[Sequence[Handle, ...]]:
+        """
+        Allocate memory cache for all transformer blocks, return cache handle
+        :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
+        """
+        descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+            yield nested_pack(handles, descriptors)
 
     def _log_request(
         self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None

+ 57 - 36
src/petals/server/memory_cache.py

@@ -10,7 +10,7 @@ import ctypes
 import multiprocessing as mp
 import os
 import time
-from typing import AsyncContextManager, Dict, Optional, Union
+from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
 
 import hivemind
 import torch
@@ -26,10 +26,9 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.alloc_timeout = alloc_timeout
-        self.device = device
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -57,26 +56,48 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
-        :param descr: allocate a tensor of this size, dtype, etc
+        :param descriptors: one or more tensors tensor of this size, dtype, etc
+
+        :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
+          if not, it will count maximum tensor allocation across devices for the purposes of size limit
 
         :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
         Furthermore, it can be called concurrently with at most one use_cache call in runtime.
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
-        assert descr.device is None and descr
-
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
+        assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        max_alloc_size = self.get_allocation_size(*descriptors)
+
+        gib = 1024**3
+        cur_size, max_size = self.current_size_bytes, self.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
         try:
-            yield await shield_and_wait(alloc_task)
+            handles = await shield_and_wait(alloc_task)
+            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            yield handles
         finally:
-            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
-
-    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+
+    @staticmethod
+    def get_allocation_size(*descriptors: TensorDescriptor) -> int:
+        """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
+        alloc_size_by_device = {}
+        for descr in descriptors:
+            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
+        return max(alloc_size_by_device.values())
+
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -87,11 +108,11 @@ class MemoryCache:
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                handle = int(self.handle_counter)
+                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handle, descr))
-                return handle
+                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handles, descriptors))
+                return handles
 
     async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
         """
@@ -102,10 +123,10 @@ class MemoryCache:
 
         if alloc_task.exception() is not None:
             return
-        handle = alloc_task.result()
+        handles = alloc_task.result()
 
         async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
         self._memory_freed_event.set()
 
@@ -125,11 +146,11 @@ class MemoryCache:
             self._memory_freed_event.clear()
 
     @contextlib.contextmanager
-    def use_cache(self, handle: Handle) -> torch.Tensor:
+    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
         """
-        Return a tensor that was previously allocated with try_allocate_cache,
+        Return one or more tensors previously allocated with allocate_cache,
 
-        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
         However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
         """
         assert os.getpid() == self.runtime_pid
@@ -138,20 +159,20 @@ class MemoryCache:
         with self._lock_metadata:
             # read creation/deletion requests from connection handlers
             while self._pipe_recv.poll():
-                recv_handle, recv_data = self._pipe_recv.recv()
-                if isinstance(recv_data, TensorDescriptor):
-                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
-                elif recv_data is None:
-                    if recv_handle not in self._allocated_tensors:
-                        logger.warning(
-                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
-                        )
-                    self._allocated_tensors.pop(recv_handle, None)
-                else:
-                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
-
-        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-        yield self._allocated_tensors[handle]
+                recv_handles, recv_data = self._pipe_recv.recv()
+                if recv_data is not None:  # create new tensors
+                    assert len(recv_handles) == len(recv_data)
+                    for handle, descr in zip(recv_handles, recv_data):
+                        self._allocated_tensors[handle] = descr.make_zeros()
+                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+                else:  # delete tensors by handle
+                    for handle in recv_handles:
+                        if handle not in self._allocated_tensors:
+                            logger.warning(
+                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                            )
+                        self._allocated_tensors.pop(handle, None)
+        yield tuple(self._allocated_tensors[handle] for handle in handles)
 
 
 class AllocationFailed(Exception):

+ 42 - 15
src/petals/server/server.py

@@ -6,7 +6,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import psutil
@@ -29,7 +29,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__file__)
@@ -76,6 +76,7 @@ class Server:
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: Optional[bool] = None,
+        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         **kwargs,
     ):
@@ -128,6 +129,8 @@ class Server:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         device = torch.device(device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(device.type, index=0)
         self.device = device
 
         if isinstance(torch_dtype, str):
@@ -141,6 +144,13 @@ class Server:
             logger.info("Model weights will be loaded in 8-bit format")
         self.load_in_8bit = load_in_8bit
 
+        if tensor_parallel_devices is None:
+            tensor_parallel_devices = (device,)
+        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
+        if len(self.tensor_parallel_devices) > 1:
+            logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
+            check_device_balance(self.tensor_parallel_devices)
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -174,6 +184,7 @@ class Server:
                 device,
                 torch_dtype,
                 load_in_8bit=load_in_8bit,
+                tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
@@ -214,13 +225,28 @@ class Server:
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
         assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
+
+        if num_devices > 1:
+            memory_per_device = tuple(
+                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
+            )
+            total_memory = min(memory_per_device) * num_devices
+            if max(memory_per_device) / min(memory_per_device) > 1.5:
+                raise ValueError(
+                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
+                    "Please launch individual servers on each GPU or set --num_blocks manually to "
+                    "override this exception."
+                )
+        else:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
 
-        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
+        autograd_memory = 2 * gib * num_devices  # gpu memory used for intermediate tensors in rpc_backward
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
@@ -260,6 +286,7 @@ class Server:
                 sender_threads=self.sender_threads,
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
+                tensor_parallel_devices=self.tensor_parallel_devices,
                 start=True,
             )
             try:
@@ -352,6 +379,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
+        tensor_parallel_devices: Sequence[torch.device],
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -367,7 +395,9 @@ class ModuleContainer(threading.Thread):
         joining_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
-        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
+
+        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -380,18 +410,13 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
+                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
 
-                if load_in_8bit:
-                    block = replace_8bit_linear(block)
-
-                block = block.to(device)
-                for param in block.parameters():
-                    param.requires_grad = False
-
-                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
+                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    config=block_config,
                     memory_cache=memory_cache,
                     backend_dtype=backend_dtype,
                     args_schema=(
@@ -451,6 +476,7 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         session_timeout: float,
         step_timeout: float,
+        device: Union[str, torch.device],
         start: bool,
         **kwargs,
     ):
@@ -469,7 +495,8 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
+        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,

+ 18 - 10
src/petals/server/task_pool.py

@@ -5,7 +5,7 @@ import time
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind import get_logger
@@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     :param name: pool name, used for logging
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param device: if specified, input tensors will be moved to that device by default
     :param start: if True, start automatically at the end of __init__
     """
 
@@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         max_batch_size: int,
         name: str,
         min_batch_size=1,
+        device: Optional[torch.device] = None,
         daemon=True,
         start=False,
     ):
         super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
@@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
 
-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
@@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase):
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
     ) -> Tuple[Any, List[torch.Tensor]]:
         """receive next batch of arrays"""
+        device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [
-            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
-        ]
+        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
@@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [
-            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
-            for tensor in batch_outputs
-        ]
-
+        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
@@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         assert len(item) == 2
         self._priority.value = float(item[0])
         self._oldest_undispatched_timestamp.value = float(item[1])
+
+
+def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False):
+    if isinstance(arg, torch.Tensor):
+        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
+        # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor
+        # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)
+        if share_memory:
+            arg = arg.share_memory_()
+    return arg

+ 24 - 9
src/petals/server/throughput.py

@@ -2,9 +2,10 @@ import fcntl
 import json
 import os
 import time
+from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 
 import torch
 from hivemind.utils.logging import get_logger
@@ -12,7 +13,7 @@ from transformers import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__file__)
@@ -37,6 +38,7 @@ def get_host_throughput(
     dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> float:
@@ -57,6 +59,9 @@ def get_host_throughput(
         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
         cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+        if len(tensor_parallel_devices) > 1:
+            for i, device_i in enumerate(tensor_parallel_devices):
+                cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
 
         cache = {}
         try:
@@ -69,7 +74,9 @@ def get_host_throughput(
             cache = {}
 
         if cache_key not in cache:
-            cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
+            cache[cache_key] = measure_throughput_info(
+                config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+            )
 
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
@@ -87,6 +94,7 @@ def measure_throughput_info(
     dtype: torch.dtype,
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
 ) -> float:
     """Measure network and compute throughput in forward pass tokens per second"""
 
@@ -95,7 +103,9 @@ def measure_throughput_info(
     )
     return min(
         measure_network_rps(config),
-        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
+        measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        ),
     )
 
 
@@ -129,14 +139,15 @@ def measure_compute_rps(
     dtype: torch.dtype,
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     n_tokens: int = 16,
     n_steps: int = 500,
 ) -> float:
+    if not tensor_parallel_devices:
+        tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = WrappedBloomBlock(config).to(dtype)
-        if load_in_8bit:
-            block = replace_8bit_linear(block)
-        block = block.to(device)
+        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
 
         cache = None
         elapsed = 0
@@ -149,9 +160,13 @@ def measure_compute_rps(
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
+    devices_repr = get_device_name(device)
+    if len(tensor_parallel_devices) > 1:
+        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
+        devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
+
     logger.info(
-        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
-        f"{device_rps:.1f} RPS"
+        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
     )
     return device_rps
 

+ 0 - 39
src/petals/utils/convert_8bit.py

@@ -1,39 +0,0 @@
-import bitsandbytes as bnb
-import torch
-
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
-
-
-def replace_8bit_linear(model, threshold=6.0):
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
-    for n, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
-
-        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
-            model._modules[n] = CustomLinear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
-            model._modules[n].bias = module.bias
-    return model

+ 132 - 0
src/petals/utils/convert_block.py

@@ -0,0 +1,132 @@
+"""
+Tools for converting transformer blocks, applying quantization and/or tensor parallelism
+"""
+import re
+from typing import Sequence
+
+import bitsandbytes as bnb
+import tensor_parallel as tp
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tensor_parallel.slicing_configs import get_bloom_config
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+from petals.bloom.block import WrappedBloomBlock
+from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def convert_block(
+    block: WrappedBloomBlock,
+    config: BloomConfig,
+    tensor_parallel_devices: Sequence[torch.device],
+    output_device: torch.device,
+    load_in_8bit: bool,
+    threshold: float = 6.0,
+    freeze: bool = True,
+) -> tp.TensorParallel:
+    """
+    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
+
+    :note: some optimizations will modify the input block in-place!
+    :param block: a single transformer block, either pre-trained or newly initialized
+    :param config: HF transformers config for the full model
+    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
+    :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity)
+    :param output_device: if tensor_parallel_devices is True, output
+    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
+    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param freeze: if True (default), make all module parameters non-trainable
+    :return: a module that acts like the original block, but runs with all specified optimizations
+
+    """
+    if freeze:
+        for param in block.parameters():
+            param.requires_grad = False
+
+    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
+
+    if load_in_8bit:
+        block = replace_8bit_linear(block, threshold=threshold)
+
+    for shard, device in zip(block.module_shards, block.devices):
+        shard.to(device)
+
+    return block
+
+
+def replace_8bit_linear(model: nn.Module, threshold=6.0):
+    """
+    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
+    be kept as a `torch.nn.Linear` module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
+            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
+            model._modules[n] = CustomLinear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            )
+            model._modules[n].weight = bnb.nn.Int8Params(
+                module.weight.data, requires_grad=False, has_fp16_weights=False
+            ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
+    return model
+
+
+def make_tensor_parallel(
+    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
+):
+    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
+    tp_config = get_bloom_config(model_config, devices)
+    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
+    total_heads = 0
+    for tp_shard in tp_block.module_shards:
+        for submodule in tp_shard.modules():
+            if isinstance(submodule, BloomAttention):
+                total_heads += submodule.num_heads
+    assert total_heads == model_config.n_head
+    return tp_block
+
+
+def check_device_balance(devices: Sequence[torch.device]):
+    if not all(device.type == "cuda" for device in devices):
+        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
+        return
+    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
+    if len(unique_device_capabilities) > 1:
+        logger.warning(
+            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
+            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
+        )
+
+    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
+    used_memory = min(memory_per_device) * len(memory_per_device)
+    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
+    if wasted_memory_rate > 0.05:
+        logger.warning(
+            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
+            f"Consider running high-memory GPUs in a separate server."
+        )

+ 9 - 2
tests/test_aux_functions.py

@@ -7,10 +7,17 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps
 
 
 @pytest.mark.forked
-def test_throughput_basic():
+@pytest.mark.parametrize("tensor_parallel", [False, True])
+def test_throughput_basic(tensor_parallel: bool):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
-        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
+        config,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+        load_in_8bit=False,
+        tensor_parallel_devices=tensor_parallel_devices,
+        n_steps=10,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
     network_rps = measure_network_rps(config)

+ 1 - 1
tests/test_block_exact_match.py

@@ -13,7 +13,7 @@ from petals.dht_utils import get_remote_module
 
 
 @pytest.mark.forked
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
 

+ 11 - 10
tests/test_remote_sequential.py

@@ -1,6 +1,7 @@
 import pytest
 import torch
-from hivemind import DHT, BatchTensorDescriptor, get_logger
+import torch.nn.functional as F
+from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
 from hivemind.proto import runtime_pb2
 from test_utils import *
 
@@ -39,10 +40,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
 
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
 
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
@@ -58,7 +59,7 @@ def test_remote_sequential():
     assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
-    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
+    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
 
 
 class DummyCustomSequenceManager(RemoteSequenceManager):
@@ -87,9 +88,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     remote_sequential = RemoteSequential(config, dht)
 
-    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
-    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
-    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
+    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
+    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
     intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
 
     input_prompts = input_prompts.detach().requires_grad_(True)
@@ -117,10 +118,10 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)
 
-    assert torch.allclose(outputs_ref, outputs)
+    assert torch.allclose(outputs_ref, outputs, atol=1e-3)
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)

+ 46 - 0
tests/test_tensor_parallel.py

@@ -0,0 +1,46 @@
+import random
+
+import pytest
+import torch
+import transformers
+from tensor_parallel import TensorParallel
+from tensor_parallel.slicing_configs import get_bloom_config
+from test_utils import MODEL_NAME
+
+from petals.bloom.from_pretrained import load_pretrained_block
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("custom_config", [True, False])
+@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
+def test_tp_block(devices, custom_config):
+    block_index = random.randint(0, 10)
+    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
+
+    tp_config = None
+    if custom_config:
+        tp_config = get_bloom_config(model_config, devices)
+
+    batch_size = 2
+    prefix_length = 5
+
+    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
+    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
+    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
+    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
+    grad_proj = torch.rand_like(test_inputs1)
+
+    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
+    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
+    y_ref.backward(grad_proj)
+
+    block_tp = TensorParallel(block, devices, config=tp_config)
+    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
+    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
+    y_ours.backward(grad_proj)
+
+    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
+    assert torch.allclose(y_ours, y_ref, atol=1e-6)
+    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
+    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)