@@ -33,28 +33,32 @@ A stable version of the code and a public swarm open to everyone will be release

## Code examples

-Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
+PETALS integrates seamlessly with PyTorch and Transformers.
+For instance, solving sequence classification with soft prompt tuning of BLOOM-176B looks like this:

```python
-# Initialize distributed BLOOM with soft prompts
-model = AutoModelForPromptTuning.from_pretrained(
- "bigscience/distributed-bloom")
-# Define optimizer for prompts and linear head
-optimizer = torch.optim.AdamW(model.parameters())
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+ "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+) # embeddings & prompts are on your device, transformer blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))

+# Training: update only local prompts / adapters
+optimizer = torch.optim.AdamW(model.parameters())
for input_ids, labels in data_loader:
- # Forward pass with local and remote layers
outputs = model.forward(input_ids)
loss = cross_entropy(outputs.logits, labels)
-
- # Distributed backward w.r.t. local params
- loss.backward() # Compute model.prompts.grad
- optimizer.step() # Update local params only
optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
```

## Installation

+__[TO BE UPDATED]__
+
```bash
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html