
quality of life

justheuristic, 2 years ago
Parent
Commit
6700ae16de
3 files changed, 70 insertions and 69 deletions
  1. README.md (+60, -64)
  2. cli/run_server.py (+7, -2)
  3. requirements.txt (+3, -3)

+ 60 - 64
README.md

@@ -35,7 +35,7 @@ This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a s
 ```python
 # Initialize distributed BLOOM and connect to the swarm
 model = DistributedBloomForCausalLM.from_pretrained(
-    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+    "bigscience/bloom-petals", tuning_mode="ptune", initial_peers=SEE_BELOW
 )  # Embeddings & prompts are on your device, BLOOM blocks are distributed
 
 print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
@@ -78,90 +78,86 @@ This is important because it's technically possible for peers serving model laye
 
 ## Installation
 
-🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two).
-
---------------------------------------------------------------------------------
-
-```bash
-conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
-pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+Here's how to install the dependencies with conda:
+```
+conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+pip install bitsandbytes==0.33.2  # for 8-bit quantization
 pip install -r requirements.txt
-pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 ```
 
+This script uses Anaconda to install CUDA-enabled PyTorch.
+If you don't have Anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
+If you don't want Anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
+If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+
+__OS support:__ currently, PETALS only supports Linux operating systems. On Windows 11, you can run PETALS with GPU enabled inside WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)).
+For macOS, you can *probably* run everything normally if you manage to install dependencies, but we do not guarantee this.
+
+
 ### Basic functionality
 
-All tests is run on localhost
+This is a toy example that runs on a local machine without a GPU, using a small BLOOM model.
+For more detailed instructions with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
 
-First, run one or more servers like this:
+First, run a couple of servers, each in a separate shell. Start the first server like this:
 ```bash
-# minimalistic server with non-trained bloom blocks
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
-# when running multiple servers:
-# - give each server a unique --identity_path (or remote --identity_path arg when debugging)
-# - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
-# - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+  --host_maddrs /ip4/127.0.0.1/tcp/31337   # use port 31337, local connections only
 ```
 
-Then open a python notebook or console and run:
-```python
-import torch
-import hivemind
-from src import DistributedBloomConfig, get_remote_module
-
-
-dht = hivemind.DHT(
-    initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
-    client_mode=True, start=True,
-)
-config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
-assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
-# test forward/backward, two blocks
-outputs = layer4(layer3(torch.randn(1, 64, 4096)))
-loss = (outputs * torch.randn_like(outputs)).norm()
-loss.backward()
-
-# test inference, one block
-with layer3.inference_session(max_length=10) as sess:
-    for i in range(10):
-        res = sess.step(torch.ones(1, 1, 4096))
+Once you run the server, it will print out a ton of information, including a line like this:
+```bash
+Mon Day 01:23:45.678 [INFO] Running DHT node on ['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters'], initial peers = []
 ```
 
-
-### Convert regular BLOOM into distributed
+You can use this address (/ip4/whatever/else) to connect additional servers. Open another terminal and run:
 ```bash
-
-# convert model from HF hub to a distributed format (can take hours depending on your connection!)
-MY_WRITE_TOKEN=TODO_WRITE_TOKEN_FROM_https://huggingface.co/settings/token
-python -m cli.convert_model --model bigscience/bloom-6b3  \
-  --output_path ./converted_model --output_repo bigscience/test-bloomd-6b3 \
-  --use_auth_token $MY_WRITE_TOKEN  # ^-- todo replace output repo with something you have access to
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+  --host_maddrs /ip4/127.0.0.1/tcp/0 --initial_peers /ip4/127.0...<TODO! copy the address of another server>
+# e.g. --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS1GecIfYouAreReadingThisYouNeedToCopyYourServerAddressCBBq
 ```
 
+You can set `--initial_peers` to one or more addresses of other servers, not necessarily the first one.
+The only requirement is that at least one of them is alive, i.e. running at the time.
 
-### Test local vs remote block (allclose)
+Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
+make sure your servers have enough total `--num_blocks` to cover that model. 
 
-To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
-```bash
-# shell A: serve model
-python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
 
-# shell B:
-export PYTHONPATH=.
-export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
-export MODEL_NAME="bigscience/test-bloomd-6b3"
+Once you have enough servers, you can use them to train the model and/or run inference:
+```python
+import torch
+import torch.nn.functional as F
+import transformers
+from src import DistributedBloomForCausalLM
+
+initial_peers = [TODO_put_one_or_more_server_addresses_here]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
+tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
 
-# test individual random blocks for exact match
-pytest tests/test_block_exact_match.py
+model = DistributedBloomForCausalLM.from_pretrained(
+  "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
+)  # this model has only embeddings / logits, all transformer blocks rely on remote servers 
+inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
+remote_outputs = model.generate(inputs, max_length=10)
+print(tokenizer.decode(remote_outputs[0]))  # "a cat sat in the back of the car,"
 
-# test the full model
-pytest tests/test_full_model.py
+model = DistributedBloomForCausalLM.from_pretrained(
+  "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
+)  # this model has only embeddings / logits, all transformer blocks rely on remote servers 
+
+# "train" input embeddings by backprop through distributed transformer blocks
+model.transformer.word_embeddings.weight.requires_grad = True
+outputs = model.forward(input_ids=inputs)
+loss = F.cross_entropy(outputs.logits.flatten(0, 1), inputs.flatten())
+loss.backward()
+print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
 ```
 
+Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
+
+Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running PETALS.
+
+
 --------------------------------------------------------------------------------
 
 <p align="center">
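
The README hunk above asks the reader to copy the server address by hand from the `Running DHT node on [...]` log line. As a minimal sketch of how that step could be scripted, the snippet below pulls the address out of a log line shaped like the one in the README. The helper name `extract_multiaddrs` and the regular expression are assumptions made for illustration; they are not part of this repository or of hivemind, and the pattern only covers the `/ip4/.../tcp/.../p2p/...` form shown in the example.

```python
import re
from typing import List

# Hypothetical pattern: matches only the /ip4/<addr>/tcp/<port>/p2p/<peer id>
# form that appears in the README's sample log line.
MADDR_PATTERN = re.compile(r"/ip4/[0-9.]+/tcp/\d+/p2p/\w+")


def extract_multiaddrs(log_line: str) -> List[str]:
    """Return every address in log_line that matches the assumed pattern."""
    return MADDR_PATTERN.findall(log_line)


# Sample line copied from the README hunk above
log_line = (
    "Mon Day 01:23:45.678 [INFO] Running DHT node on "
    "['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters'], initial peers = []"
)
initial_peers = extract_multiaddrs(log_line)
print(initial_peers)  # ['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters']
```

The resulting list has the same shape as the `initial_peers` placeholder in the README's client snippet.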

+ 7 - 2
cli/run_server.py

@@ -15,8 +15,11 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
-    parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
-                        help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('--converted_model_name_or_path', type=str, default=None,
+                       help="path or name of a pretrained model, converted with cli/convert_model.py")
+    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
+
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
@@ -83,6 +86,8 @@ def main():
     args = vars(parser.parse_args())
     args.pop("config", None)
 
+    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
+
     if args.pop("increase_file_limit"):
         increase_file_limit()
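
The two hunks above let the model name be passed either positionally or through `--converted_model_name_or_path`, then fold both into one key. The sketch below reproduces that pattern in isolation. It assumes the standard-library `argparse` in place of `configargparse`, and the wrapper function `parse_model_name` is a hypothetical name introduced only for this illustration.

```python
import argparse
from typing import List, Optional


def parse_model_name(argv: List[str]) -> Optional[str]:
    """Accept the model name as a positional argument or as a named flag."""
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--converted_model_name_or_path", type=str, default=None)
    # nargs="?" makes the positional optional, which a mutually exclusive group requires
    group.add_argument("model", nargs="?", type=str)

    args = vars(parser.parse_args(argv))
    # Same merge as in the diff: the positional value wins, otherwise use the flag
    return args.pop("model") or args["converted_model_name_or_path"]


name = "bloom-testing/test-bloomd-560m-main"
print(parse_model_name([name]))
print(parse_model_name(["--converted_model_name_or_path", name]))
print(parse_model_name([]))  # neither form given, so the result is None
```

Because the flag's default changed to `None` in this commit, calling the parser with neither form leaves the model name unset, so downstream code has to handle that case.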
 

+ 3 - 3
requirements.txt

@@ -1,8 +1,8 @@
-torch==1.12.0
-bitsandbytes==0.33.0
+torch>=1.12
+bitsandbytes==0.33.0  #TODO update this to 0.33.2 asap
 accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.12.2,<4.0.0
-https://github.com/learning-at-home/hivemind/archive/131f82c97ea67510d552bb7a68138ad27cbfa5d4.zip
+hivemind==1.1.1
 humanfriendly