Browse Source

Update README.md

justheuristic 3 years ago
parent
commit
83523c18b5
1 changed files with 15 additions and 11 deletions
  1. 15 11
      README.md

+ 15 - 11
README.md

@@ -33,28 +33,32 @@ A stable version of the code and a public swarm open to everyone will be release
 
 ## Code examples
 
-Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
+PETALS integrates seamlessly with PyTorch and Transformers.
+For instance, solving sequence classification with soft prompt tuning of BLOOM-176B looks like this:
 
 ```python
-# Initialize distributed BLOOM with soft prompts
-model = AutoModelForPromptTuning.from_pretrained(
-       "bigscience/distributed-bloom")
-# Define optimizer for prompts and linear head
-optimizer = torch.optim.AdamW(model.parameters())
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+)  # embeddings & prompts are on your device, transfromer blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
 
+# Training: update only local prompts / adapters
+optimizer = torch.optim.AdamW(model.parameters())
 for input_ids, labels in data_loader:
-    # Forward pass with local and remote layers
     outputs = model.forward(input_ids)
     loss = cross_entropy(outputs.logits, labels)
-
-    # Distributed backward w.r.t. local params
-    loss.backward() # Compute model.prompts.grad
-    optimizer.step() # Update local params only
     optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
 ```
 
 ## Installation
 
+__[TO BE UPDATED]__
+
 ```bash
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html