@@ -28,24 +28,26 @@

### Examples

-Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
+Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
+
+This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:

```python
-# Initialize distributed BLOOM with soft prompts
-model = AutoModelForPromptTuning.from_pretrained(
-    "bigscience/distributed-bloom")
-# Define optimizer for prompts and linear head
-optimizer = torch.optim.AdamW(model.parameters())
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+) # Embeddings & prompts are on your device, BLOOM blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
+
+# Training (updates only local prompts / adapters)
+optimizer = torch.optim.AdamW(model.parameters())
for input_ids, labels in data_loader:
-    # Forward pass with local and remote layers
    outputs = model.forward(input_ids)
    loss = cross_entropy(outputs.logits, labels)
-
-    # Distributed backward w.r.t. local params
-    loss.backward() # Compute model.prompts.grad
-    optimizer.step() # Update local params only
    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
```
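The added snippet above leaves a few names open: `tokenized_prefix`, `data_loader`, and `cross_entropy` are never defined, and `initial_peers=SEE_BELOW` is a placeholder the README resolves further down. Below is a minimal sketch (not part of the diff) of how those missing pieces might be prepared with standard PyTorch and Transformers utilities; the `bigscience/bloom` tokenizer checkpoint, the prefix string, and the two-example toy dataset are illustrative assumptions.

```python
# Illustrative sketch only, not part of the diff above.
# Assumes the public "bigscience/bloom" tokenizer and a tiny toy dataset.
import torch
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")

# Prefix to continue during generation (example text)
tokenized_prefix = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Toy sequence classification data: (token ids, class label) pairs
texts = ["This movie was great", "This movie was terrible"]
labels = torch.tensor([1, 0])
input_ids = tokenizer(texts, return_tensors="pt", padding=True)["input_ids"]
data_loader = DataLoader(TensorDataset(input_ids, labels), batch_size=2)
```

With these definitions, the training loop from the snippet iterates over `(input_ids, labels)` batches while only the local prompt parameters receive gradient updates.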
### 🚧 This project is in active development
@@ -76,6 +78,8 @@ This is important because it's technically possible for peers serving model laye
## Installation
+__[To be updated soon]__
+
```bash
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html