
Merge branch 'nasty_bug_with_extra_tensors' of github.com:learning-at-home/hivemind into power_ef_new

Artem Chumachenko, 3 years ago
commit 9b070e0a5a

+ 3 - 1
.github/workflows/check-style.yml

@@ -13,7 +13,7 @@ jobs:
       - uses: psf/black@stable
         with:
           options: "--check --diff"
-          version: "21.6b0"
+          version: "22.1.0"
   isort:
     runs-on: ubuntu-latest
     steps:
@@ -22,3 +22,5 @@ jobs:
         with:
           python-version: 3.8
       - uses: isort/isort-action@master
+        with:
+          isortVersion: "5.10.1"

+ 4 - 1
.github/workflows/run-tests.yml

@@ -35,6 +35,7 @@ jobs:
       - name: Test
         run: |
           cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest --durations=0 --durations-min=1.0 -v
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -61,6 +62,7 @@ jobs:
       - name: Test
         run: |
           cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest -k "p2p" -v
   codecov_in_develop_mode:
 
@@ -87,6 +89,7 @@ jobs:
           pip install -e . --no-use-pep517
       - name: Test
         run: |
-          pytest --cov=hivemind -v tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          pytest --cov hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

+ 4 - 3
examples/albert/README.md

@@ -55,9 +55,7 @@ To join the collaboration with a GPU trainer,
   (see [default paths](./arguments.py#L117-L134) for reference)
 - Run:
   ```bash
-  ./run_trainer.py \
-      --initial_peers ONE_OR_MORE_PEERS \
-      --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+  ./run_trainer.py --initial_peers ONE_OR_MORE_PEERS --per_device_train_batch_size BATCH_SIZE_FOR_YOUR_GPU
   ```
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -82,6 +80,9 @@ To join the collaboration with a GPU trainer,
   You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
   If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.
 
+  The `BATCH_SIZE_FOR_YOUR_GPU` should be tweaked so that the model fits into your GPU memory.
+  For 1080 Ti or 2080 Ti GPUs, a good initial value is 4. For 8 GB GPUs, try a batch size of 1-2.
+
 See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.
 
 As the peer begins training, it will periodically report training logs in the following form:
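
For illustration, here is what a full trainer invocation could look like after this change; the multiaddress and the batch size below are placeholders, not values taken from this commit:

```bash
# hypothetical example: substitute a real peer multiaddress for the placeholder below
./run_trainer.py \
    --initial_peers /ip4/203.0.113.5/tcp/31337/p2p/QmExamplePeerID \
    --per_device_train_batch_size 4  # a good starting point for an 11 GB GPU
```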

+ 9 - 4
examples/albert/arguments.py

@@ -6,7 +6,7 @@ from transformers import TrainingArguments
 
 @dataclass
 class BaseTrainingArguments:
-    experiment_prefix: str = field(
+    run_id: str = field(
         default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
     )
     initial_peers: List[str] = field(
@@ -127,7 +127,7 @@ class AlbertTrainingArguments(TrainingArguments):
     gradient_accumulation_steps: int = 2
     seq_length: int = 512
 
-    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
+    total_steps: int = 125_000  # please note: this only affects the learning rate schedule
     learning_rate: float = 0.00176
     warmup_steps: int = 5000
     adam_epsilon: float = 1e-6
@@ -138,9 +138,14 @@ class AlbertTrainingArguments(TrainingArguments):
     fp16: bool = True
     fp16_opt_level: str = "O2"
     do_train: bool = True
+    do_eval: bool = False
 
+    logging_dir: str = "logs"
+    output_dir: str = "outputs"
     logging_steps: int = 100
+    logging_first_step: bool = True
+    overwrite_output_dir: bool = True
+
     save_total_limit: int = 2
     save_steps: int = 500
-
-    output_dir: str = "outputs"
+    max_steps: int = 10**30  # meant as "peer should compute gradients forever"
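
The split between the two fields is easier to see in isolation: `total_steps` pins the length of the learning-rate schedule, while `max_steps` only keeps the Trainer loop from terminating on its own. A minimal standalone sketch (toy optimizer, not code from the repository):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# toy optimizer standing in for the LAMB optimizer used in run_trainer.py
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.00176)

# LR decay is pinned to total_steps=125_000, however long training actually runs
scheduler = get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000)

# max_steps=10**30 merely tells the HF Trainer "never stop on step count";
# unlike the old max_steps, it no longer affects the schedule above
```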

+ 1 - 1
examples/albert/requirements.txt

@@ -4,4 +4,4 @@ torch_optimizer==0.1.0
 wandb==0.10.26
 sentencepiece
 requests
-nltk==3.6.5
+nltk==3.6.7

+ 3 - 3
examples/albert/run_trainer.py

@@ -215,7 +215,7 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
+    validators, local_public_key = utils.make_validators(collaboration_args.run_id)
 
     dht = DHT(
         start=True,
@@ -260,12 +260,12 @@ def main():
     ]
 
     scheduler = lambda opt: get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
     )
 
     optimizer = Optimizer(
         dht=dht,
-        run_id=collaboration_args.experiment_prefix,
+        run_id=collaboration_args.run_id,
         target_batch_size=adjusted_target_batch_size,
         batch_size_per_step=total_batch_size_per_step,
         optimizer=opt,

+ 7 - 6
examples/albert/run_training_monitor.py

@@ -9,7 +9,7 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup
 
 import hivemind
 from hivemind.optim.state_averager import TrainingStateAverager
@@ -40,6 +40,7 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     wandb_project: Optional[str] = field(
         default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
     )
+    store_checkpoints: bool = field(default=True, metadata={"help": "If False, disables periodic checkpoint saving"})
     save_checkpoint_step_interval: int = field(
         default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
     )
@@ -56,7 +57,6 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     upload_interval: Optional[float] = field(
         default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
-    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
 class CheckpointHandler:
@@ -99,7 +99,8 @@ class CheckpointHandler:
         self.state_averager = TrainingStateAverager(
             dht=dht,
             optimizer=opt,
-            prefix=experiment_prefix,
+            scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
+            prefix=f"{run_id}_state_averager",
             state_compression=hivemind.Float16Compression(),
             bandwidth=optimizer_args.bandwidth,
             client_mode=optimizer_args.client_mode,
@@ -155,8 +156,8 @@ if __name__ == "__main__":
         version = ip_address(address).version
         monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
 
-    experiment_prefix = monitor_args.experiment_prefix
-    validators, local_public_key = utils.make_validators(experiment_prefix)
+    run_id = monitor_args.run_id
+    validators, local_public_key = utils.make_validators(run_id)
 
     dht = hivemind.DHT(
         start=True,
@@ -177,7 +178,7 @@ if __name__ == "__main__":
         checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
 
     while True:
-        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
+        metrics_dict = dht.get(run_id + "_metrics", latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
             metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]

+ 2 - 2
examples/albert/utils.py

@@ -24,9 +24,9 @@ class MetricSchema(BaseModel):
     metrics: Dict[BytesWithPublicKey, LocalMetrics]
 
 
-def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
-    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
+    validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
     return validators, signature_validator.local_public_key
 
 

+ 1 - 1
hivemind/averaging/load_balancing.py

@@ -65,7 +65,7 @@ def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int =
     # the constraints below are tuples (A, b) such that Ax <= b
     nonnegative_weights = -np.eye(group_size, num_variables, dtype=c.dtype), np.zeros(group_size, c.dtype)
     weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
-    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10 ** -LOAD_BALANCING_LP_DECIMALS)
+    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10**-LOAD_BALANCING_LP_DECIMALS)
     coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1), c.dtype)])
     xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / bandwidths[is_nonzero]
     force_max_weights = np.eye(group_size, M=num_variables, dtype=c.dtype), is_nonzero.astype(c.dtype)

+ 1 - 1
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 19
+DEFAULT_PART_SIZE_BYTES = 2**19
 logger = get_logger(__name__)
 
 

+ 2 - 2
hivemind/compression/quantization.py

@@ -48,7 +48,7 @@ class Quantization(CompressionBase, ABC):
 
     @property
     def n_bins(self):
-        return 2 ** self.n_bits
+        return 2**self.n_bits
 
 
 class Uniform8BitQuantization(Quantization):
@@ -94,7 +94,7 @@ def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
     return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
 
 
-def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
     """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
     if not array.data.c_contiguous and array.data.f_contiguous:
         array = array.T

+ 1 - 1
hivemind/moe/client/switch_moe.py

@@ -150,7 +150,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
         balancing_loss = torch.stack(
             [
-                torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
+                torch.mean(dim_softmax.mean(0) * dim_utilization) * dim_size**2
                 for dim_softmax, dim_utilization, dim_size in zip(
                     grid_softmax, self.grid_utilization, self.beam_search.grid_size
                 )

+ 13 - 4
hivemind/optim/optimizer.py

@@ -453,7 +453,8 @@ class Optimizer(torch.optim.Optimizer):
 
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 if not began_averaging_gradients:
-                    pass  # failed to start gradient averaging due to an internal error
+                    # failed to start gradient averaging due to an internal error
+                    self.grad_averager.load_accumulators_into_averager_()
                 elif self.delay_grad_averaging:
                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                     wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
@@ -540,6 +541,7 @@ class Optimizer(torch.optim.Optimizer):
                 self._tag_along_with_zero_weight(self.scheduled_grads)
             else:
                 logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self._load_local_gradients_into_optimizer()
                 self.scheduled_grads.cancel()
             self.scheduled_grads = None
         return began_averaging_gradients
@@ -608,9 +610,7 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
 
         if not averaged_gradients:
-            logger.log(self.status_loglevel, f"Proceeding with local gradients")
-            self.grad_averager.load_accumulators_into_averager_()
-            self._load_averaged_gradients_into_optimizer_()
+            self._load_local_gradients_into_optimizer()
 
     def _load_averaged_gradients_into_optimizer_(self):
         """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
@@ -629,6 +629,15 @@ class Optimizer(torch.optim.Optimizer):
 
         self.grad_averager.notify_used_averaged_gradients()
 
+    def _load_local_gradients_into_optimizer(self):
+        """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
+        logger.log(self.status_loglevel, f"Proceeding with local gradients")
+        self.grad_averager.load_accumulators_into_averager_()
+        # note: we load gradients into grad_averager even though there is only one peer, for two reasons:
+        # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
+        # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
+        self._load_averaged_gradients_into_optimizer_()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
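
A toy illustration of the un-scaling concern mentioned in the comment above (plain PyTorch, not hivemind code): gradient accumulation sums micro-batch gradients, so the fallback path must divide by the number of accumulation steps to recover the average.

```python
import torch

x = torch.nn.Parameter(torch.tensor(1.0))
accumulation_steps = 2
for _ in range(accumulation_steps):
    (x * 3.0).backward()  # each micro-batch contributes grad = 3.0

print(x.grad)                 # tensor(6.) -- summed, not averaged
x.grad /= accumulation_steps  # un-scale, as the grad averager does internally
print(x.grad)                 # tensor(3.)
```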

+ 1 - 1
hivemind/p2p/p2p_daemon.py

@@ -265,7 +265,7 @@ class P2P:
         return self._daemon_listen_maddr
 
     @staticmethod
-    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
+    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2**16) -> None:
         writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
         data = memoryview(data)
         for offset in range(0, len(data), chunk_size):

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -26,7 +26,7 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
-DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+DEFAULT_MAX_MSG_SIZE = 4 * 1024**2
 
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:

+ 1 - 1
hivemind/utils/grpc.py

@@ -175,7 +175,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+STREAMING_CHUNK_SIZE_BYTES = 2**16
 
 
 def split_for_streaming(

+ 1 - 1
hivemind/utils/limits.py

@@ -3,7 +3,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+def increase_file_limit(new_soft=2**15, new_hard=2**15):
     """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     try:
         import resource  # local import to avoid ImportError for Windows users

+ 2 - 0
hivemind/utils/mpfuture.py

@@ -18,6 +18,8 @@ from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
+torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))
+
 # flavour types
 ResultType = TypeVar("ResultType")
 PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
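
With this hook in place, the sharing strategy can be switched per environment without code changes; a minimal sketch of the mechanism (the two values shown are the strategies PyTorch supports on Linux):

```python
import os
import torch.multiprocessing

# mirrors the hunk above: default to "file_system", let CI override it
strategy = os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system")
torch.multiprocessing.set_sharing_strategy(strategy)

# the workflows in this commit export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
assert torch.multiprocessing.get_sharing_strategy() == strategy
```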

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.black]
 line-length = 119
-required-version = "21.6b0"
+required-version = "22.1.0"
 
 [tool.isort]
 profile = "black"
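
This version bump drives most of the mechanical source changes in this commit: black 22.1.0 removes the spaces around `**` whenever both operands are simple (names, literals, or attribute chains). For example:

```python
# black 21.6b0 kept spaces around the power operator:
chunk_size = 2 ** 16
# black 22.1.0 hugs simple operands:
chunk_size = 2**16
# parenthesized (non-simple) operands keep their spaces:
loss = weight * (dim_size + 1) ** 2
```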

+ 5 - 4
requirements-dev.txt

@@ -1,10 +1,11 @@
-pytest
+pytest==6.2.5  # see https://github.com/pytest-dev/pytest/issues/9621
 pytest-forked
-pytest-asyncio
+pytest-asyncio==0.16.0
 pytest-cov
+coverage==6.0.2  # see https://github.com/pytest-dev/pytest-cov/issues/520
 tqdm
 scikit-learn
 torchvision
-black==21.6b0
-isort
+black==22.1.0
+isort==5.10.1
 psutil

+ 9 - 9
tests/test_allreduce.py

@@ -33,7 +33,7 @@ async def test_partitioning():
 
     # note: this test does _not_ use parameterization to reuse sampled tensors
     for num_tensors in 1, 3, 5:
-        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+        for part_size_bytes in 31337, 2**20, 10**10:
             for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
                 tensors = random.choices(all_tensors, k=num_tensors)
                 partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
@@ -157,16 +157,16 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 @pytest.mark.parametrize(
     "peer_modes, averaging_weights, peer_fractions, part_size_bytes",
     [
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
-        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
-        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2**20),
+        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2**20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
+        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
+        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2**20),
+        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2**20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2**20),
         ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 256),
         ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 19),
-        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2 ** 20),
+        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2**20),
     ],
 )
 @pytest.mark.forked

+ 1 - 1
tests/test_allreduce_fault_tolerance.py

@@ -158,7 +158,7 @@ def test_fault_tolerance(fault0: Fault, fault1: Fault):
             min_matchmaking_time=1.0,
             next_chunk_timeout=0.5,
             allreduce_timeout=5,
-            part_size_bytes=2 ** 16,
+            part_size_bytes=2**16,
             client_mode=(i == 1),
             start=True,
             fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,

+ 1 - 1
tests/test_averaging.py

@@ -283,7 +283,7 @@ def test_load_balancing():
         load_balance_peers(100, (0, 0, 0))
 
     for i in range(10):
-        vector_size = np.random.randint(1, 1024 ** 3)
+        vector_size = np.random.randint(1, 1024**3)
         num_peers = np.random.randint(1, 256)
         scale = 1e-9 + np.random.rand() * 1e5
         bandwidths = np.random.rand(num_peers) * scale + 1e-6

+ 1 - 1
tests/test_compression.py

@@ -53,7 +53,7 @@ def test_serialize_tensor():
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
 
     tensor = torch.randn(512, 12288)
-    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
 
     _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)

+ 2 - 2
tests/test_dht.py

@@ -72,7 +72,7 @@ async def dummy_dht_coro_stateful(self, node):
 
 async def dummy_dht_coro_long(self, node):
     await asyncio.sleep(0.25)
-    return self._x_dummy ** 2
+    return self._x_dummy**2
 
 
 async def dummy_dht_coro_for_cancel(self, node):
@@ -94,7 +94,7 @@ def test_run_coroutine():
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
     assert not hasattr(dht, "_x_dummy")
-    assert bg_task.result() == 126 ** 2
+    assert bg_task.result() == 126**2
 
     future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
     time.sleep(0.25)

+ 19 - 3
tests/test_optimizer.py

@@ -149,34 +149,50 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
 
 
 @pytest.mark.forked
-def test_load_state_from_peers():
+@pytest.mark.parametrize("dpu", [True, False])
+def test_load_state_from_peers(dpu: bool):
     dht1 = hivemind.DHT(start=True)
     dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
 
     model1 = nn.Linear(2, 3)
     model2 = nn.Linear(2, 3)
 
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
     common_kwargs = dict(
         optimizer=partial(torch.optim.SGD, lr=0.1),
         scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        offload_optimizer=dpu,
+        reuse_tensors=dpu,
         target_group_size=2,
         prefix="my_exp",
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        extra_tensors=extras1,
+        **common_kwargs,
     )
 
-    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
+    avgr2 = TrainingStateAverager(
+        dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2, **common_kwargs
+    )
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
+    extras2[0][:] = 9999
     time.sleep(0.1)
 
     avgr1.load_state_from_peers()
     assert avgr1.local_epoch == 1337
     assert torch.all(model1.weight == 42).item()
     assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+    assert torch.all(extras1[0] == extras2[0]).item() and torch.all(extras1[1] == extras2[1]).item()
+    assert torch.all(extras1[0] == 9999).item()
 
 
 @pytest.mark.forked
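
For context on what the new parametrized test covers: `extra_tensors` attaches arbitrary user tensors to the averaged state, and `load_state_from_peers()` is expected to overwrite them in-place along with model parameters. A condensed sketch of the pattern exercised above, using the same API with toy values:

```python
from functools import partial

import torch
from torch import nn

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager

dht = hivemind.DHT(start=True)
buffer = torch.zeros(3)  # e.g. a running statistic that should travel with the state
avgr = TrainingStateAverager(
    dht=dht,
    params=nn.Linear(2, 3).parameters(),
    optimizer=partial(torch.optim.SGD, lr=0.1),
    extra_tensors=(buffer,),
    prefix="my_exp",
    start=True,
)
# given a second peer sharing its state, avgr.load_state_from_peers()
# overwrites `buffer` in-place with the donor peer's extra tensors
```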

+ 3 - 3
tests/test_p2p_daemon.py

@@ -89,7 +89,7 @@ async def test_unary_handler_edge_cases():
     p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
 
     async def square_handler(data: test_pb2.TestRequest, context):
-        return test_pb2.TestResponse(number=data.number ** 2)
+        return test_pb2.TestResponse(number=data.number**2)
 
     await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
 
@@ -202,7 +202,7 @@ async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.
             except asyncio.IncompleteReadError:
                 break
 
-            result = x ** 2
+            result = x**2
 
             await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
 
@@ -215,7 +215,7 @@ async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.S
             await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
             result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
 
-            assert result == x ** 2
+            assert result == x**2
 
 
 @pytest.mark.asyncio

+ 7 - 7
tests/test_p2p_daemon_bindings.py

@@ -38,15 +38,15 @@ PAIRS_INT_SERIALIZED_VALID = (
     (0, b"\x00"),
     (1, b"\x01"),
     (128, b"\x80\x01"),
-    (2 ** 32, b"\x80\x80\x80\x80\x10"),
-    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
+    (2**32, b"\x80\x80\x80\x80\x10"),
+    (2**64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
 )
 
 PAIRS_INT_SERIALIZED_OVERFLOW = (
-    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
-    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2**64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2**64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
     (
-        2 ** 128,
+        2**128,
         b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
     ),
 )
@@ -94,7 +94,7 @@ async def test_write_unsigned_varint_overflow(integer):
         await write_unsigned_varint(s, integer)
 
 
-@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
+@pytest.mark.parametrize("integer", (-1, -(2**32), -(2**64), -(2**128)))
 @pytest.mark.asyncio
 async def test_write_unsigned_varint_negative(integer):
     s = MockWriter()
@@ -125,7 +125,7 @@ async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
     Test edge cases with different `max_bits`
     """
     for i in range(-3, 0):
-        integer = i + (2 ** max_bits)
+        integer = i + 2**max_bits
         s = MockReaderWriter()
         await write_unsigned_varint(s, integer, max_bits=max_bits)
         s.seek(0, 0)

+ 3 - 3
tests/test_p2p_servicer.py

@@ -21,7 +21,7 @@ async def server_client():
 async def test_unary_unary(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
-            return test_pb2.TestResponse(number=request.number ** 2)
+            return test_pb2.TestResponse(number=request.number**2)
 
     server, client = server_client
     servicer = ExampleServicer()
@@ -83,8 +83,8 @@ async def test_stream_stream(server_client):
             self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
             async for item in stream:
-                yield test_pb2.TestResponse(number=item.number ** 2)
-                yield test_pb2.TestResponse(number=item.number ** 3)
+                yield test_pb2.TestResponse(number=item.number**2)
+                yield test_pb2.TestResponse(number=item.number**3)
 
     server, client = server_client
     servicer = ExampleServicer()

+ 6 - 3
tests/test_util_modules.py

@@ -313,6 +313,8 @@ def test_many_futures():
     p.start()
 
     some_fork_futures = receiver.recv()
+
+    time.sleep(0.1)  # giving enough time for the futures to be destroyed
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures:
@@ -323,6 +325,7 @@ def test_many_futures():
     evt.set()
     for future in main_futures:
         future.cancel()
+    time.sleep(0.1)  # giving enough time for the futures to be destroyed
     assert len(hivemind.MPFuture._active_futures) == 0
     p.join()
 
@@ -394,7 +397,7 @@ def test_split_parts():
     chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
     assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))
 
-    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
+    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
     assert len(chunks3) == 1
 
     compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
@@ -437,8 +440,8 @@ async def test_asyncio_utils():
     assert res == list(range(len(res)))
 
     num_steps = 0
-    async for elem in amap_in_executor(lambda x: x ** 2, as_aiter(*range(100)), max_prefetch=5):
-        assert elem == num_steps ** 2
+    async for elem in amap_in_executor(lambda x: x**2, as_aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps**2
         num_steps += 1
     assert num_steps == 100