
Fix floating point issues in block_selection.py (#89)

Alexander Borzunov, 2 years ago
parent commit 898f614515
2 changed files with 25 additions and 19 deletions
  1. README.md (+5 -5)
  2. src/server/block_selection.py (+20 -14)

README.md (+5 -5)

@@ -60,7 +60,7 @@ A stable version of the code and a public swarm open to everyone will be release
 
 ### 📋 Terms of use
 
-Before using Petals to run a language model, please make sure that you are familiar with its terms of use, risks, and limitations. For BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
+Before using Petals to run a language model, please make sure that you are familiar with its terms of use, risks, and limitations. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
 
 ### 🔒 Privacy and security
 
@@ -101,7 +101,7 @@ For macOS, you can *probably* run everything normally if you manage to install d
 
 ## 🚀 Getting Started
 
-This is a toy example running on a local machine without GPU and with a tiny model. 
+This is a toy example running on a local machine without GPU and with a tiny model.
 For a detailed instruction with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
 
 First, run a couple of servers, each in a separate shell. To launch your first server, run:
@@ -133,7 +133,7 @@ You can assign `--initial_peers` to one or multiple addresses of other servers,
 The only requirement is that at least one of them is running at the time.
 
 Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
-make sure your servers have enough total `--num_blocks` to cover that model. 
+make sure your servers have enough total `--num_blocks` to cover that model.
 
 Once your have enough servers, you can use them to train and/or inference the model:
 ```python
@@ -162,8 +162,8 @@ print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
 ```
 
 Of course, this is a simplified code snippet. For actual training, see the example notebooks with "deep" prompt-tuning:
-- Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb).
-- A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
+- Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb)
+- A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb)
 
 Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals.
 

src/server/block_selection.py (+20 -14)

@@ -32,7 +32,10 @@ def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict
         if module is None:
             continue
 
-        for peer_id, server in module.servers.items():
+        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+        # If the order were not defined, we would get slightly different values due to floating point errors,
+        # which may cause excess block replacements.
+        for peer_id, server in sorted(module.servers.items()):
             if server.state == ServerState.OFFLINE:
                 continue
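
To see why a stable iteration order matters here: floating point addition is not associative, so accumulating the same set of per-server throughputs in a different order can give a slightly different per-block total. A minimal illustration (the values below are arbitrary, not real server throughputs):

```python
# Floating point addition is not associative: the same three throughput values
# summed in different orders give totals that differ in the last bits.
values = [0.1, 0.2, 0.3]

total_forward = (values[0] + values[1]) + values[2]  # 0.6000000000000001
total_reverse = (values[2] + values[1]) + values[0]  # 0.6
print(total_forward == total_reverse)                # False

# Iterating servers in sorted(peer_id) order fixes the accumulation order,
# so repeated computations over the same swarm state agree bit-for-bit.
```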
 
 
@@ -47,17 +50,14 @@ def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict
     return spans, throughputs
 
 
-def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
-    options = (
-        (sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
-        for i in range(0, len(throughputs) - num_blocks + 1)
-    )
+def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
+    options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
     return min(options)[-1]
 
 
 def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
     _, throughputs = _compute_spans(module_infos)
-    start = _choose_best_start(throughputs, num_blocks, None)
+    start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))
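For a concrete sense of the updated selection rule, here is a standalone sketch of `_choose_best_start` as it reads after this change, with a toy call (the throughput numbers are made up). Each candidate window of `num_blocks` consecutive blocks is ranked by its sorted throughputs, so the window covering the weakest blocks wins, and the start index breaks exact ties deterministically:

```python
import numpy as np

def choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
    # Rank each candidate window by its sorted throughputs (weakest block first);
    # the trailing index i makes the choice deterministic when windows tie exactly.
    options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
    return min(options)[-1]

block_throughputs = np.array([3.0, 1.0, 0.5, 0.5, 2.0, 4.0])  # made-up per-block throughputs
print(choose_best_start(block_throughputs, num_blocks=2))      # 2 -> the window over the two weakest blocks
```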
 
 
 
 
@@ -69,16 +69,22 @@ def should_choose_other_blocks(
 
     spans, throughputs = _compute_spans(module_infos)
     initial_throughput = throughputs.min()
+    eps = 1e-3
 
     assert local_peer_id in spans, "Span served by this server is not present in the DHT"
     local_span = spans[local_peer_id]
-    throughputs[local_span.start : local_span.end] -= local_span.throughput
+    throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + eps)
+    # Without (1 + eps) here, we would sometimes subtract a value slightly less than local_span.throughput
+    # due to the floating point error, which would cause excess block replacements.
+    # Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer
+    # the previous server position in case of other things being almost equal.
 
-    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
+    new_start = _choose_best_start(throughputs, local_span.length)
     if local_span.start == new_start:
         return False  # This server is on its best place already
-    local_span.move_to(new_start)
 
+    throughputs[local_span.start : local_span.end] += local_span.throughput * eps
+    local_span.move_to(new_start)
     throughputs[local_span.start : local_span.end] += local_span.throughput
 
     moved = True
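
The `(1 + eps)` cushion addresses the rounding effect described in the comment above: subtracting a server's throughput from a floating point sum does not exactly undo the earlier addition. A minimal illustration with made-up numbers:

```python
# Made-up numbers: one block served by two servers with throughputs 0.1 and 0.2.
other_server = 0.1
this_server = 0.2
total = other_server + this_server            # 0.30000000000000004

# Plain subtraction leaves a residual slightly ABOVE the true remaining throughput (0.1),
# so an otherwise-identical window elsewhere would look weaker to _choose_best_start()
# and win the comparison, triggering a pointless block replacement.
remaining = total - this_server               # 0.10000000000000003
print(remaining == other_server)              # False

# Subtracting throughput * (1 + eps) biases the current window slightly downward instead,
# so near-ties resolve in favor of keeping the server where it already is.
eps = 1e-3
remaining_cushioned = total - this_server * (1 + eps)
print(remaining_cushioned < other_server)     # True
```
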
@@ -89,18 +95,18 @@ def should_choose_other_blocks(
         moved = False
         for peer_id in servers:
             span = spans[peer_id]
-            throughputs[span.start : span.end] -= span.throughput
+            throughputs[span.start : span.end] -= span.throughput * (1 + eps)
 
-            new_start = _choose_best_start(throughputs, span.length, span.start)
+            new_start = _choose_best_start(throughputs, span.length)
+
+            throughputs[span.start : span.end] += span.throughput * eps
             if span.start != new_start:
                 span.move_to(new_start)
                 moved = True
-
             throughputs[span.start : span.end] += span.throughput
 
     new_throughput = throughputs.min()
     actual_quality = initial_throughput / new_throughput
     logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
 
-    eps = 1e-6
     return actual_quality < balance_quality - eps
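
The final check reuses the same `eps` as a margin against the caller-supplied `balance_quality` target, so the server only chooses other blocks when the current layout falls meaningfully short of the simulated rebalanced one. A small numeric illustration (the 0.75 target and the throughputs below are hypothetical, not project defaults):

```python
balance_quality = 0.75  # hypothetical target passed in by the caller
eps = 1e-3

initial_throughput = 90.0   # made-up min per-block throughput with the current layout
new_throughput = 120.0      # made-up min per-block throughput after the simulated rebalancing

actual_quality = initial_throughput / new_throughput  # 0.75: current layout reaches 75% of the rebalanced one
print(actual_quality < balance_quality - eps)         # False -> keep the current blocks

# A layout clearly below the target does trigger choosing other blocks:
print(60.0 / 120.0 < balance_quality - eps)           # True
```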