Browse Source

minimalistic tests

justheuristic 3 years ago
parent
commit
a3def86444

+ 99 - 0
.github/WORKFLOWS/run-tests.yaml

@@ -0,0 +1,99 @@
+name: Tests
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  run_tests:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Test
+        run: |
+          cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          pytest --durations=0 --durations-min=1.0 -v
+  build_and_test_p2pd:
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-go@v3
+        with:
+          go-version: '1.16'
+          check-latest: true
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install . --global-option=build_py --global-option="--buildgo" --no-use-pep517
+      - name: Test
+        run: |
+          cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          pytest -k "p2p" -v
+  codecov_in_develop_mode:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install -e . --no-use-pep517
+      - name: Test
+        run: |
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          pytest --cov hivemind -v tests
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v1

+ 6 - 0
requirements-dev.txt

@@ -0,0 +1,6 @@
+pytest==6.2.5  # see https://github.com/pytest-dev/pytest/issues/9621
+pytest-forked
+pytest-asyncio==0.16.0
+black==22.3.0
+isort==5.10.1
+psutil

+ 6 - 0
requirements.txt

@@ -0,0 +1,6 @@
+torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+accelerate==0.10.0
+huggingface-hub==0.7.0
+bitsandbytes-cuda113==0.26.0
+https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 39 - 14
src/client/remote_sequential.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import contextlib
 import logging
 import random
+from typing import Union, Optional
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
@@ -25,7 +26,15 @@ class RemoteSequential(nn.Module):
     A sequence of transformer blocks hosted by the swarm.
     """
 
-    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+    def __init__(
+        self,
+        config: src.DistributedBloomConfig,
+        dht: DHT,
+        prefix: str,
+        max_retries: int = 3,
+        p2p: Optional[P2P] = None,
+        sequence_manager: Optional[RemoteSequenceManager] = None,
+    ):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         if prefix.endswith(UID_DELIMITER):
             logger.warning(
@@ -39,12 +48,17 @@ class RemoteSequential(nn.Module):
         self.dht = dht
         self.prefix = prefix
         self.max_retries = max_retries
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-
-        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
-
-        logger.debug(f"Remote block uids: {block_uids}")
-        self.remote_sequence_info = RemoteSequenceManager(dht, block_uids)
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
+
+        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+        if sequence_manager is None:
+            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
+            self.is_subsequence = False
+        else:
+            assert isinstance(sequence_manager.block_uids, list)
+            logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
+            self.is_subsequence = self.sequence_manager.block_uids == block_uids
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -64,21 +78,32 @@ class RemoteSequential(nn.Module):
                         logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
         return inputs
 
-    def __getitem__(self, block_index: int):
-        assert 0 <= block_index < self.config.n_layer
-        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
-        return module
+    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+        assert isinstance(ix, (int, slice))
+        if isinstance(ix, int):
+            assert 0 <= ix < self.config.n_layer
+            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
+            return module
+        else:
+            return RemoteSequential(
+                self.config,
+                self.dht,
+                prefix=self.prefix,
+                max_retries=self.max_retries,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
 
     def __iter__(self):
         for block_index in range(self.config.n_layer):
             yield self[block_index]
 
     def __len__(self):
-        return len(self.remote_sequence_info)
+        return len(self.sequence_manager)
 
     def inference_session(self) -> RemoteSequentialInferenceSession:
-        self.remote_sequence_info.update_()
-        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
+        self.sequence_manager.update_()
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
 
 
 class RemoteSequentialInferenceSession:

+ 16 - 2
src/client/sequence_manager.py

@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import threading
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT
+from hivemind import DHT, DHTExpiration
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
@@ -21,6 +21,7 @@ class RemoteSequenceManager:
     block_infos: List[Optional[RemoteModuleInfo]]
     spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
     spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
+    last_update_time: DHTExpiration
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
@@ -29,6 +30,7 @@ class RemoteSequenceManager:
         self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.last_update_time = -float("inf")
         self.lock_changes = threading.Lock()
         self.update_()
 
@@ -36,6 +38,18 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
+        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
+        assert isinstance(ix, (int, slice))
+        if not isinstance(ix, slice):
+            ix = slice(int(ix), int(ix) + 1, 1)
+        with self.lock_changes:
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
+            subseq.block_infos = self.block_infos[ix]
+            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq.last_update_time = self.last_update_time
+        return subseq
+
     def update_(self):
         with self.lock_changes:
             self.update_block_infos_()