Browse Source

Merge branch 'master' into simplify-running-loop

Alexander Borzunov 2 years ago
parent
commit
e61544376f

+ 1 - 1
.github/workflows/run-tests.yml

@@ -47,7 +47,7 @@ jobs:
      - uses: actions/checkout@v2
      - uses: actions/setup-go@v3
        with:
-          go-version: '1.16'
+          go-version: '1.18.8'
          check-latest: true
      - name: Set up Python
        uses: actions/setup-python@v2

+ 3 - 0
Dockerfile

@@ -9,11 +9,14 @@ RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment
# Install packages
RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \
  build-essential \
+  curl \
  wget \
  git \
  vim \
  && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/*

+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
+ENV PATH="/root/.cargo/bin:${PATH}"
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \
  bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
ENV PATH="/opt/conda/bin:${PATH}"

+ 1 - 1
README.md

@@ -24,7 +24,7 @@ large model on hundreds of computers from different universities, companies, and
  Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).

To learn more about the ideas behind this library,
-see the [full list](https://github.com/learning-at-home/hivemind/tree/refer-to-discord-in-docs#citation) of our papers below.
+see the [full list](#citation) of our papers below.

## Example Use Cases

+ 1 - 1
docs/user/contributing.md

@@ -2,7 +2,7 @@
 
 
This section describes the ways to contribute to the hivemind library. For technical details of developing this library
and getting towards merging your code in the master branch, read
-the [guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md) in our GitHub repository. In
+the [guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md#) in our GitHub repository. In
any case, please follow the [Contributor Covenant](https://www.contributor-covenant.org/version/2/0/code_of_conduct/)
code of conduct when discussing the library and the changes with other community members.

+ 2 - 0
hivemind/averaging/averager.py

@@ -8,6 +8,7 @@ import ctypes
import multiprocessing as mp
import os
import random
+import signal
import threading
import weakref
from dataclasses import asdict
@@ -326,6 +327,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
        Starts averager in a background process. if await_ready, this method will wait until background dht
        is ready to process incoming requests or for :timeout: seconds max.
        """
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
        self.start()
        if await_ready:
            self.wait_until_ready(timeout)
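A forked child inherits its parent's signal dispositions, so setting SIGINT to SIG_IGN right before `self.start()` means a Ctrl+C in the terminal interrupts only the parent, not the background averager process. Below is a minimal standalone sketch of that mechanism, assuming the default POSIX start method; restoring the handler in the parent is a common companion step, not something this diff shows.

import multiprocessing as mp
import os
import signal
import time


def child():
    time.sleep(1.0)  # the inherited SIG_IGN disposition makes SIGINT a no-op here
    print(f"child {os.getpid()} survived SIGINT")


if __name__ == "__main__":
    previous = signal.signal(signal.SIGINT, signal.SIG_IGN)  # ignore SIGINT before forking
    proc = mp.Process(target=child)
    proc.start()
    signal.signal(signal.SIGINT, previous)  # parent restores its own handler
    os.kill(proc.pid, signal.SIGINT)  # delivered to the child, which ignores it
    proc.join()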

+ 15 - 6
hivemind/compression/base.py

@@ -80,18 +80,27 @@ class NoCompression(CompressionBase):
    compression_type = runtime_pb2.CompressionType.NONE

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.detach().numpy()
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
        return runtime_pb2.Tensor(
            compression=self.compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
            requires_grad=tensor.requires_grad,
        )

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+        if serialized_tensor.dtype == "bfloat16":
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+        else:
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+            tensor = torch.as_tensor(array)
+        return tensor.reshape(tuple(serialized_tensor.size))

    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
        return 1.0
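numpy has no bfloat16 dtype, so the new code ships bfloat16 tensors as float32 bytes and records the original dtype name, casting back on the receiving side; the widening cast is exact, so NONE compression stays lossless. A self-contained sketch of the same round trip, with illustrative helper names that are not part of hivemind's API:

import numpy as np
import torch


def pack(tensor: torch.Tensor):
    dtype_name = str(tensor.dtype).removeprefix("torch.")
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)  # numpy cannot represent bfloat16
    return tensor.numpy().tobytes(), dtype_name, tuple(tensor.shape)


def unpack(buffer: bytes, dtype_name: str, shape: tuple) -> torch.Tensor:
    if dtype_name == "bfloat16":
        array = np.frombuffer(buffer, dtype=np.float32)
        return torch.as_tensor(array, dtype=torch.bfloat16).reshape(shape)
    array = np.frombuffer(buffer, dtype=np.dtype(dtype_name))
    return torch.as_tensor(array).reshape(shape)


x = torch.randn(8, 8, dtype=torch.bfloat16)
assert torch.equal(unpack(*pack(x)), x)  # bfloat16 -> float32 -> bfloat16 is lossless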

+ 12 - 5
hivemind/compression/quantization.py

@@ -120,8 +120,8 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
    return np.quantile(partition_quantiles, quantiles)


-BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly. 
-Please install it with `pip install bitsandbytes` 
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
or using the instruction from https://github.com/TimDettmers/bitsandbytes."""


@@ -139,7 +139,12 @@ class BlockwiseQuantization(Quantization):
        return quantized.numpy(), (absmax.numpy(), codebook.numpy())

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
+        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

        serialized_data = (
            np.int64(len(absmax)).tobytes(),
@@ -153,7 +158,7 @@ class BlockwiseQuantization(Quantization):
            buffer=b"".join(serialized_data),
            size=tensor.shape,
            requires_grad=tensor.requires_grad,
-            dtype=tensor.numpy().dtype.name,
+            dtype=dtype_name,
            compression=self.compression_type,
        )

@@ -172,6 +177,8 @@ class BlockwiseQuantization(Quantization):
        codebook = torch.as_tensor(codebook)
        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
        try:
-            return dequantize_blockwise(quantized, (absmax, codebook))
+            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
        except NameError:
            raise ImportError(BNB_MISSING_MESSAGE)
+        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
+        return result
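Since `dequantize_blockwise` always yields float32, the final `result.to(dtype=getattr(torch, serialized_tensor.dtype))` is what restores the sender's dtype. Assuming bitsandbytes is installed, a quick round-trip check using the same serialization helpers as the tests below might look like:

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(256, 16, dtype=torch.bfloat16)
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.BLOCKWISE_8BIT))
assert restored.dtype == torch.bfloat16  # dtype survives the float32 detour
assert torch.allclose(restored.float(), tensor.float(), rtol=0.1, atol=0.01)  # 8-bit codes are lossy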

+ 4 - 0
hivemind/hivemind_cli/run_dht.py

@@ -1,5 +1,6 @@
import time
from argparse import ArgumentParser
+from secrets import token_hex

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -18,6 +19,9 @@ async def report_status(dht: DHT, node: DHTNode):
    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
    logger.debug(f"Local storage contents: {node.protocol.storage}")

+    # Contact peers and keep the routing table healthy (remove stale PeerIDs)
+    await node.get(f"heartbeat_{token_hex(16)}", latest=True)
+


def main():
    parser = ArgumentParser()
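The key is freshly random on every call, so the lookup is all but guaranteed to miss; its value is the traversal itself, which contacts peers along the route and evicts unresponsive PeerIDs from the routing table. A sketch of driving this refresh from synchronous code, where the 60-second period is an assumption rather than anything taken from this diff:

import time
from secrets import token_hex

import hivemind


async def refresh_routing_table(dht, node):
    # a lookup for a nonexistent key pings peers and prunes stale routing-table entries
    await node.get(f"heartbeat_{token_hex(16)}", latest=True)


dht = hivemind.DHT(start=True)
for _ in range(3):
    dht.run_coroutine(refresh_routing_table)
    time.sleep(60)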

+ 1 - 1
hivemind/moe/client/beam_search.py

@@ -171,7 +171,7 @@ class MoEBeamSearcher:
    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
        """
        :param prefixes: a list of prefix for which to find active successor uids
-        :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
+        :param grid_size: if specified, only return successors if they are in range [0, grid_size)
        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix

+ 5 - 1
hivemind/p2p/p2p_daemon.py

@@ -104,6 +104,8 @@ class P2P:
        use_relay_hop: Optional[bool] = None,
        use_relay_discovery: Optional[bool] = None,
        check_if_identity_free: bool = True,
+        no_listen: bool = False,
+        trusted_relays: Optional[Sequence[Union[Multiaddr, str]]] = None,
    ) -> "P2P":
        """
        Start a new p2pd process and connect to it.
@@ -171,10 +173,12 @@ class P2P:
            ("bootstrapPeers", initial_peers),
            ("hostAddrs", host_maddrs),
            ("announceAddrs", announce_maddrs),
+            ("trustedRelays", trusted_relays),
        ]:
            if value:
                process_kwargs[param] = self._maddrs_to_str(value)
-
+        if no_listen:
+            process_kwargs["noListenAddrs"] = 1
        if identity_path is not None:
            if os.path.isfile(identity_path):
                if check_if_identity_free and need_bootstrap:
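Both new options are plumbed straight into p2pd: `trusted_relays` joins the multiaddress parameters, while `no_listen` maps to the daemon's `noListenAddrs` flag. hivemind.DHT forwards extra keyword arguments down to `P2P.create`, so a relay-only client can be configured roughly as in the sketch below, mirroring the new test at the bottom of this diff:

import hivemind

# a publicly reachable peer that can act as a relay
relay_peer = hivemind.DHT(start=True, use_relay=True, use_auto_relay=True, force_reachability="public")

# a client that opens no listening sockets of its own and dials through relays;
# no_listen and the relay flags are forwarded from DHT through to P2P.create
client = hivemind.DHT(
    initial_peers=relay_peer.get_visible_maddrs(),
    host_maddrs=[],
    no_listen=True,
    use_relay=True,
    use_auto_relay=True,
    start=True,
)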

+ 2 - 1
hivemind/utils/mpfuture.py

@@ -127,7 +127,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
 
    @_state.setter
    def _state(self, new_state: State):
-        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        with torch.inference_mode():
+            self._shared_state_code[...] = ALL_STATES.index(new_state)
        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
            self._set_event_threadsafe()
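`torch.inference_mode()` keeps autograd machinery entirely out of this write, which is plain bookkeeping on a shared-memory tensor. inference_mode first appeared in PyTorch 1.9, which is presumably why requirements.txt below now asks for torch>=1.9.0. A minimal sketch of the pattern, similar in spirit to MPFuture's shared state cell:

import torch

# a tiny state cell visible to child processes
shared_state_code = torch.zeros(1, dtype=torch.uint8).share_memory_()

with torch.inference_mode():
    shared_state_code[...] = 3  # in-place write with no autograd tracking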
 
 

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
PyYAML
-torch>=1.6.0
+torch>=1.9.0
numpy>=1.17
scipy>=1.2.1
prefetch_generator>=1.0.1

+ 2 - 2
setup.py

@@ -13,14 +13,14 @@ from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop

-P2PD_VERSION = "v0.3.11"
+P2PD_VERSION = "v0.3.16"

P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"

# The value is sha256 of the binary from the release page
EXECUTABLES = {
-    "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
+    "p2pd": "057ec61edbe926cf049e9532d43ea9540da55db7b2d8c816d2bbdddce23f3cdf",
}

here = os.path.abspath(os.path.dirname(__file__))
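The new digest pins the v0.3.16 p2pd binary: setup.py compares the SHA-256 of what it downloads against this table before accepting it. A generic sketch of such a check, with an illustrative helper name:

import hashlib


def sha256_matches(path: str, expected_hex: str, chunk_size: int = 1 << 20) -> bool:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex


# e.g.: sha256_matches("p2pd", "057ec61edbe926cf049e9532d43ea9540da55db7b2d8c816d2bbdddce23f3cdf")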

+ 17 - 7
tests/test_compression.py

@@ -46,15 +46,18 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()


+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+    serialized_tensor = serialize_torch_tensor(tensor, compression)
+    chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = combine_from_streaming(chunks)
+    result = deserialize_torch_tensor(restored)
+    assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+    assert result.dtype == tensor.dtype
+
+
@pytest.mark.forked
def test_serialize_tensor():
-    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
-        serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
    tensor = torch.randn(512, 12288)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
@@ -65,6 +68,13 @@ def test_serialize_tensor():
    _check(torch.tensor(1.0), CompressionType.FLOAT16)


+@pytest.mark.forked
+def test_serialize_bfloat16():
+    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    _check(tensor, CompressionType.NONE)
+    _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
@pytest.mark.forked
def test_allreduce_compression():
    """this test ensures that compression works correctly when multiple tensors have different compression types"""

+ 64 - 0
tests/test_relays.py

@@ -0,0 +1,64 @@
+import time
+from functools import partial
+
+import pytest
+
+import hivemind
+
+
+async def ping_to_client(dht, node, peer_id: str):
+    return await node.protocol.call_ping(hivemind.PeerID.from_base58(str(peer_id)))
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_auto_relay,use_relay",
+    [
+        (True, True),
+        (False, False),
+    ],
+)
+def test_autorelay(use_auto_relay: bool, use_relay: bool):
+    dht_first_peer = hivemind.DHT(
+        start=True,
+        use_auto_relay=use_auto_relay,
+        use_relay=use_relay,
+        force_reachability="public",
+    )
+    dht_first_peer_id = dht_first_peer.peer_id
+    initial_peers = dht_first_peer.get_visible_maddrs()
+    assert dht_first_peer_id is not None
+
+    dht_third_peer = hivemind.DHT(
+        initial_peers=initial_peers,
+        host_maddrs=[],
+        start=True,
+        no_listen=True,
+        use_relay=use_relay,
+        client_mode=False,
+        use_auto_relay=use_auto_relay,
+    )
+    time.sleep(5)
+    dht_second_peer = hivemind.DHT(
+        initial_peers=initial_peers,
+        start=True,
+        client_mode=False,
+        no_listen=False,
+        use_relay=use_relay,
+        use_auto_relay=use_auto_relay,
+    )
+
+    assert dht_first_peer.is_alive() and dht_second_peer.is_alive() and dht_third_peer.is_alive()
+
+    time_start = time.perf_counter()
+    while time.perf_counter() - time_start < 30:
+        reached_ip = dht_second_peer.run_coroutine(partial(ping_to_client, peer_id=dht_third_peer.peer_id))
+        if reached_ip:
+            assert use_relay
+            break
+        time.sleep(2)
+    else:
+        assert not use_relay
+
+    for peer in dht_first_peer, dht_second_peer, dht_third_peer:
+        peer.shutdown()