justheuristic 3 years ago
parent commit a41f3b8c39

+ 12 - 10
src/bloom/model.py

@@ -11,8 +11,11 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -442,10 +445,9 @@ class LMHead(nn.Module):
 
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
-        
+
         # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
-        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
-            word_embeddings.device.type == 'cpu':
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
@@ -454,18 +456,18 @@ class LMHead(nn.Module):
         return lm_logits
 
     def chunked_forward(self, hidden_states):
-        """ Splits word embeddings into chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. 
-            chunk_size: provides a trade-off between efficiency and extra memory consumption. 
+        """Splits word embeddings into chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
+        chunk_size: provides a trade-off between efficiency and extra memory consumption.
         """
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
 
         word_embeddings = self.word_embeddings.weight
         num_embeddings = self.word_embeddings.num_embeddings
 
-        hidden_states = hidden_states.float()    
+        hidden_states = hidden_states.float()
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
 
         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i: i + self.chunk_size].float()
-            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
+            chunk = word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
         return output
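
The chunked forward pass above is worth reading in isolation: instead of casting the whole [vocab, hidden] embedding matrix to fp32 at once, it converts one vocabulary slice at a time, trading a little repeated work for a much smaller peak-memory overhead. A minimal sketch of the same technique (function and argument names here are illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    def chunked_lm_head(hidden_states, word_embeddings, chunk_size=8192):
        # hidden_states: [..., hidden_size]; word_embeddings: [vocab, hidden_size]
        # stored in fp16/bf16 on CPU, where half-precision matmul is slow.
        hidden_states = hidden_states.float()
        vocab_size = word_embeddings.shape[0]
        output = torch.zeros(*hidden_states.shape[:-1], vocab_size)
        for i in range(0, vocab_size, chunk_size):
            # Cast one slice of the vocabulary at a time: peak extra memory is
            # chunk_size * hidden_size * 4 bytes rather than the full matrix.
            chunk = word_embeddings[i : i + chunk_size].float()
            output[..., i : i + chunk_size] = F.linear(hidden_states, chunk)
        return output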

+ 6 - 2
src/dht_utils.py

@@ -136,8 +136,12 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 state, throughput = server_info.value
-                if not (isinstance(state, int) and isinstance(throughput, float) and
-                        math.isfinite(throughput) and throughput >= 0.0):
+                if not (
+                    isinstance(state, int)
+                    and isinstance(throughput, float)
+                    and math.isfinite(throughput)
+                    and throughput >= 0.0
+                ):
                     raise ValueError(f"Invalid server info: {server_info}")
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
             except (TypeError, ValueError) as e:
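
For context, the check above unpacks each DHT record into a (state, throughput) pair and rejects anything malformed before it can reach the routing tables. The same validation in isolation, as a minimal sketch (the record layout is inferred from this hunk alone, not from the full schema):

    import math

    def parse_server_info(value):
        # value is expected to be a (state, throughput) tuple from the DHT;
        # a bad shape raises TypeError/ValueError, which the caller catches.
        state, throughput = value
        if not (
            isinstance(state, int)
            and isinstance(throughput, float)
            and math.isfinite(throughput)
            and throughput >= 0.0
        ):
            raise ValueError(f"Invalid server info: {value}")
        return state, throughput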

+ 4 - 4
src/server/block_selection.py

@@ -9,10 +9,10 @@ def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[Remot
         if module is None:
             throughputs.append(0)
             continue
-        throughputs.append(sum(server.throughput for server in module.servers.values()
-                               if server.state != ServerState.OFFLINE))
+        throughputs.append(
+            sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
+        )
 
-    options = [(sorted(throughputs[i:i + num_blocks]), i)
-               for i in range(0, len(throughputs) - num_blocks + 1)]
+    options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
     best_start = min(options)[1]
     return list(range(best_start, best_start + num_blocks))
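
choose_best_blocks selects the contiguous span of num_blocks whose sorted per-block throughputs compare smallest, i.e. it places the new server on the blocks that currently have the least serving capacity, breaking ties by the next-weakest block. A self-contained sketch of that selection rule on toy data:

    def choose_best_span(throughputs, num_blocks):
        # min() compares (sorted_window, start) tuples lexicographically, so the
        # window whose weakest block is weakest overall wins; ties fall through
        # to the second-weakest block, and so on.
        options = [
            (sorted(throughputs[i : i + num_blocks]), i)
            for i in range(0, len(throughputs) - num_blocks + 1)
        ]
        return min(options)[1]

    start = choose_best_span([3.0, 0.5, 0.5, 4.0, 1.0], num_blocks=2)
    assert start == 1  # blocks 1 and 2 are currently the least served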

+ 14 - 13
src/server/throughput.py

@@ -20,10 +20,10 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-DEFAULT_CACHE_PATH = Path(Path.home(), '.cache', project_name, 'throughput.json')
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, 'throughput.lock')
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
 
-SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], 'cli', 'speed_test.py')
+SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
 
 
 @dataclass
@@ -43,7 +43,7 @@ def get_host_throughput(
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, 'wb') as lock_fd:
+    with open(lock_path, "wb") as lock_fd:
         logger.info("Loading throughput info")
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
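
The locking scheme in this hunk leans on flock semantics: the OS drops the lock when the descriptor is closed or the process dies, so there is no explicit unlock path to get wrong. A minimal sketch of the same pattern (the path is a placeholder; fcntl is Unix-only):

    import fcntl
    import os

    lock_path = "/tmp/example/throughput.lock"  # placeholder, not the real path
    os.makedirs(os.path.dirname(lock_path), exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        # Blocks until no other process holds the lock; released automatically
        # when lock_fd is closed (end of the with-block) or the process exits.
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        ...  # read or refresh the shared cache here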
@@ -63,7 +63,7 @@ def get_host_throughput(
             info = measure_throughput_info()
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
-                with open(cache_path, 'w') as cache_fd:
+                with open(cache_path, "w") as cache_fd:
                     json.dump(asdict(info), cache_fd)
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")
@@ -73,29 +73,30 @@ def get_host_throughput(
 
 
 def measure_throughput_info() -> ThroughputInfo:
-    logger.info("Measuring network, CPU, and GPU throughput. "
-                "This takes about a minute and will be cached for future runs")
+    logger.info(
+        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
+    )
 
     # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
-    config = BloomConfig.from_pretrained('bigscience/test-bloomd-6b3')
+    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
 
     network_rps = measure_network_rps(config)
 
-    device_rps = {'cpu': measure_device_rps('cpu', config)}
+    device_rps = {"cpu": measure_device_rps("cpu", config)}
     if torch.cuda.is_available():
-        device_rps['cuda'] = measure_device_rps('cuda', config)
+        device_rps["cuda"] = measure_device_rps("cuda", config)
 
     return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
 
 
 def measure_network_rps(config: BloomConfig) -> float:
-    proc = subprocess.run([SPEED_TEST_PATH, '--json'], capture_output=True)
+    proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
     if proc.returncode != 0:
         raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
     network_info = json.loads(proc.stdout)
 
     bits_per_request = config.hidden_size * 32
-    network_rps = min(network_info['download'], network_info['upload']) / bits_per_request
+    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
 
     logger.info(
         f"Network throughput: "
@@ -120,7 +121,7 @@ def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n
             elapsed += time.perf_counter() - start_time
         device_rps = n_steps / elapsed
 
-    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == 'cuda' else 'CPU'
+    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
     logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
 
     return device_rps
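
The device benchmark whose tail is shown above follows the usual wall-clock pattern: accumulate perf_counter deltas over n_steps forward passes and report steps per second. A stripped-down sketch of that loop (the real function also constructs the layer being benchmarked, which is omitted here):

    import time

    def measure_rps(step_fn, n_steps: int = 100) -> float:
        # Times n_steps calls to step_fn and returns throughput in steps/sec.
        elapsed = 0.0
        for _ in range(n_steps):
            start_time = time.perf_counter()
            step_fn()
            elapsed += time.perf_counter() - start_time
        return n_steps / elapsed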