瀏覽代碼

Use bitsandbytes==0.34.0, update readme (#76)

* unlock bnb backward
* Fix bnb version in README
* Update requirements.txt
justheuristic 2 年之前
父節點
當前提交
fef48d7d99
共有 4 個文件被更改,包括 8 次插入9 次删除
  1. 2 7
      README.md
  2. 1 1
      requirements.txt
  3. 0 1
      src/server/server.py
  4. 5 0
      src/utils/convert_8bit.py

+ 2 - 7
README.md

@@ -81,7 +81,6 @@ This is important because it's technically possible for peers serving model laye
 Here's how to install the dependencies with conda:
 ```
 conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
-pip install bitsandbytes==0.33.2  # for 8-bit quantization
 pip install -r requirements.txt
 ```
 
@@ -94,12 +93,12 @@ __OS support:__ currently, PETALS only supports Linux operating systems. On Wind
 For macOS, you can *probably* run everything normally if you manage to install dependencies, but we do not guarantee this.
 
 
-### Getting Started
+## Getting Started
 
 This is a toy example running on a local machine without GPU and with a tiny model. 
 For a more detailed instruction with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
 
-First, run a couple of servers, each in a separate shell. First server runs like this
+First, run a couple of servers, each in a separate shell. To launch your first server, run:
 ```bash
 python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
   --host_maddrs /ip4/127.0.0.1/tcp/31337   # use port 31337, local connections only
@@ -146,10 +145,6 @@ inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
 remote_outputs = model.generate(inputs, max_length=10)
 print(tokenizer.decode(remote_outputs[0]))  # "a cat sat in the back of the car,"
 
-model = DistributedBloomForCausalLM.from_pretrained(
-  "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
-)  # this model has only embeddings / logits, all transformer blocks rely on remote servers 
-
 # "train" input embeddings by backprop through distributed transformer blocks
 model.transformer.word_embeddings.weight.requires_grad = True
 outputs = model.forward(input_ids=inputs)

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 torch>=1.12
-bitsandbytes==0.33.0  #TODO update this to 0.33.2 asap
+bitsandbytes==0.34.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3

+ 0 - 1
src/server/server.py

@@ -202,7 +202,6 @@ class Server(threading.Thread):
 
             if load_in_8bit:
                 dtype = block.input_layernorm.weight.dtype
-                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)

+ 5 - 0
src/utils/convert_8bit.py

@@ -1,6 +1,10 @@
+import os
+
 import bitsandbytes as bnb
 import torch
 
+PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
+
 
 def replace_8bit_linear(model, threshold=6.0):
     """
@@ -29,6 +33,7 @@ def replace_8bit_linear(model, threshold=6.0):
                 module.bias is not None,
                 has_fp16_weights=False,
                 threshold=threshold,
+                memory_efficient_backward=PETALS_8BIT_BACKWARD,
             )
             model._modules[n].weight = bnb.nn.Int8Params(
                 module.weight.data, requires_grad=False, has_fp16_weights=False