Browse Source

Fix psutil-related AccessDenied crash, disable --load_in_8bit by default in case of TP (#188)

* Don't count open fds since it leads to AccessDenied crashes on some machines
* Use --load_in_8bit=False by default in case of tensor parallelism
* Install petals from PyPI in fine-tuning tutorials
Alexander Borzunov 2 years ago
parent
commit
a617ce3cfa

+ 1 - 2
examples/prompt-tuning-personachat.ipynb

@@ -36,8 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install -q datasets wandb"
+    "%pip install -q petals datasets wandb"
    ]
   },
   {

+ 1 - 2
examples/prompt-tuning-sst2.ipynb

@@ -36,8 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install -q datasets wandb"
+    "%pip install -q petals datasets wandb"
    ]
   },
   {

+ 13 - 14
src/petals/server/server.py

@@ -9,7 +9,6 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
-import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import check_reachability
-from petals.server.throughput import get_host_throughput
+from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
@@ -146,12 +145,6 @@ class Server:
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
         self.torch_dtype = torch_dtype
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        if load_in_8bit:
-            logger.info("Model weights will be loaded in 8-bit format")
-        self.load_in_8bit = load_in_8bit
-
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
         self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
@@ -159,6 +152,17 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
+        if load_in_8bit is None:
+            load_in_8bit = device.type == "cuda"
+            if load_in_8bit and len(self.tensor_parallel_devices) > 1:
+                load_in_8bit = False
+                logger.warning(
+                    "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
+                    "You can explicitly set `--load_in_8bit True` to override this"
+                )
+        self.load_in_8bit = load_in_8bit
+        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -167,8 +171,7 @@ class Server:
                 first_block_index, last_block_index = block_indices.split(":")
                 first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
             except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
+                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
             block_indices = range(first_block_index, last_block_index)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
@@ -301,10 +304,6 @@ class Server:
         del self.module_container
         gc.collect()  # In particular, this closes unused file descriptors
 
-        cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
-        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
-
         if self.device.type == "cuda":
             torch.cuda.empty_cache()