<div>
<img src="https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67" width="40%">  
</div>

# Distributed Bloom for Text Classification using Prompt Tuning

In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers host the BLOOM blocks, which stay unchanged during adaptation, while gradient descent learns a few prefix tokens stored on the Petals client.
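
To make the idea concrete, here is a minimal local sketch of prompt tuning in plain PyTorch. It is not the Petals implementation: `TinyPromptTunedClassifier`, its sizes, the single linear layer standing in for the frozen transformer blocks, the mean pooling, and the trainable head are all assumptions chosen for illustration. The point is only that the backbone stays frozen while gradients update the prefix vectors prepended to the token embeddings.

```python
import torch
import torch.nn as nn


class TinyPromptTunedClassifier(nn.Module):
    """Toy stand-in: a frozen backbone plus trainable prefix embeddings."""

    def __init__(self, vocab_size=100, hidden_size=16, num_prefix_tokens=4, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # A single linear layer stands in for the frozen transformer blocks.
        self.backbone = nn.Linear(hidden_size, hidden_size)
        self.head = nn.Linear(hidden_size, num_classes)
        # Freeze everything that belongs to the "pretrained" model.
        for module in (self.embed, self.backbone):
            for param in module.parameters():
                param.requires_grad = False
        # The prefix is a small trainable matrix: one vector per prefix token.
        self.prefix = nn.Parameter(torch.randn(num_prefix_tokens, hidden_size) * 0.02)

    def forward(self, input_ids):
        tokens = self.embed(input_ids)  # (batch, seq_len, hidden)
        prefix = self.prefix.unsqueeze(0).expand(tokens.size(0), -1, -1)
        # Prepend the prefix vectors to the token embeddings.
        hidden = torch.tanh(self.backbone(torch.cat([prefix, tokens], dim=1)))
        return self.head(hidden.mean(dim=1))  # (batch, num_classes)


torch.manual_seed(0)
toy_model = TinyPromptTunedClassifier()
trainable_names = [name for name, p in toy_model.named_parameters() if p.requires_grad]
toy_optimizer = torch.optim.AdamW(
    [p for p in toy_model.parameters() if p.requires_grad], lr=1e-2
)

input_ids = torch.randint(0, 100, (8, 10))
labels = torch.randint(0, 2, (8,))
backbone_before = toy_model.backbone.weight.clone()

# One optimization step: only the prefix and the toy head receive updates.
loss = nn.functional.cross_entropy(toy_model(input_ids), labels)
loss.backward()
toy_optimizer.step()

print(sorted(trainable_names))
```

In the real setup the frozen part lives on remote servers, so only the small client-side tensors need gradients and optimizer state.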

We will adapt BLOOM for classification using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). SST-2 is a binary classification dataset: the goal is to predict whether a sentence expresses positive or negative sentiment. It is a subset of the Stanford Sentiment Treebank and is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.

To use this notebook in Colab:

1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator.

First, we have to prepare all dependencies.

In [None]:
%pip install -q petals datasets wandb scikit-learn

In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import wandb
from datasets import load_dataset, load_metric
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BloomTokenizerFast, get_scheduler

from petals import DistributedBloomForSequenceClassification

Let's set some hyperparameters for training:

In [None]:
# Choose a model you'd like to prompt-tune. We recommend starting with
# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.
# Once your code is ready, you can switch to full-scale
# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).
MODEL_NAME = "bigscience/bloom-7b1-petals"

# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').
# The latter fine-tunes separate prefixes for each transformer block,
# so prompt-tuning will take more time but yield better results.
# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf
TUNING_MODE = 'ptune'

NUM_PREFIX_TOKENS = 16
DEVICE = 'cuda'
BATCH_SIZE = 16
LR = 1e-2
WEIGHT_DECAY = 0.0
NUM_EPOCHS = 3
SEED = 42
MODEL_MAX_LENGTH = 64
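
To get a feel for how the two tuning modes differ in size, the following sketch estimates the number of trainable prefix values. It assumes one vector of size `hidden_size` per prefix token in `ptune` mode and one such vector per prefix token per transformer block in `deep_ptune` mode; the actual Petals implementation may differ in details (for example, it may keep input-level prompts as well), and the classification head is ignored. The bloom-7b1 dimensions below are assumptions that should be checked against `model.config`.

```python
def prefix_parameter_count(num_prefix_tokens, hidden_size, num_blocks, tuning_mode):
    """Rough count of trainable prefix values (ignores any classification head)."""
    if tuning_mode == "ptune":
        # One trainable vector per prefix token, used at the input only.
        return num_prefix_tokens * hidden_size
    if tuning_mode == "deep_ptune":
        # One trainable vector per prefix token for every transformer block.
        return num_prefix_tokens * hidden_size * num_blocks
    raise ValueError(f"unknown tuning mode: {tuning_mode!r}")


# Assumed dimensions for bloom-7b1; read model.config to confirm them.
HIDDEN_SIZE_7B1 = 4096
NUM_BLOCKS_7B1 = 30

shallow_count = prefix_parameter_count(16, HIDDEN_SIZE_7B1, NUM_BLOCKS_7B1, "ptune")
deep_count = prefix_parameter_count(16, HIDDEN_SIZE_7B1, NUM_BLOCKS_7B1, "deep_ptune")
print(shallow_count, deep_count)
```

Under these assumptions both modes train a tiny fraction of the 7.1B model parameters, which is why the client can run on a single GPU.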

Next, we prepare the tokenizer and the distributed model, which connects to the Petals servers.

In [None]:
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.padding_side = 'right'
tokenizer.model_max_length = MODEL_MAX_LENGTH
model = DistributedBloomForSequenceClassification.from_pretrained(
    MODEL_NAME,
    pre_seq_len=NUM_PREFIX_TOKENS,
    tuning_mode=TUNING_MODE
).to(DEVICE)

Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset.

In [None]:
task = 'sst2'

dataset = load_dataset("glue", task)

def preprocess_function(examples):
    return tokenizer(examples["sentence"], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(preprocess_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx", "attention_mask"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

train_dataset = tokenized_datasets["train"].shuffle(seed=SEED)
valid_dataset = tokenized_datasets["validation"].shuffle(seed=SEED)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
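
Because the tokenizer is called with `padding='max_length'` and `truncation=True`, every example ends up with exactly `MODEL_MAX_LENGTH` token ids, so the default `DataLoader` collation can stack them into one tensor per batch. The helper below is a simplified illustration of that rule, not the tokenizer's code: `pad_or_truncate` and the `pad_id=0` value are hypothetical.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Right-pad with pad_id, or cut on the right, to exactly max_length ids."""
    clipped = token_ids[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


short_example = pad_or_truncate([5, 6, 7], max_length=5, pad_id=0)
long_example = pad_or_truncate(list(range(10)), max_length=4, pad_id=0)
print(short_example)
print(long_example)
```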

To monitor training, we need a metric function. For the SST-2 task, the metric is accuracy. We will load it from the datasets library.

In [None]:
metric = load_metric('glue', task)

def eval_metrics(model, dataloader, device='cpu'):
    model.eval()
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    model.train()
    return metric.compute()
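
For reference, the accuracy that the metric accumulates over batches is simply the fraction of predictions that match the labels. A minimal pure-Python sketch of the same computation follows; the `accuracy` helper is hypothetical and only mirrors the dictionary shape of the result.

```python
def accuracy(predictions, references):
    """Fraction of predictions equal to the reference labels."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(int(p == r) for p, r in zip(predictions, references))
    return {"accuracy": correct / len(references)}


# Accumulate predictions batch by batch, then compute once at the end.
all_predictions, all_references = [], []
for batch_predictions, batch_labels in [([1, 0], [1, 0]), ([1, 1], [0, 1])]:
    all_predictions.extend(batch_predictions)
    all_references.extend(batch_labels)

toy_result = accuracy(all_predictions, all_references)
print(toy_result)
```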

Before setting up optimizers, check the model parameters that will be trained.

In [None]:
for n, p in model.named_parameters():
    if p.requires_grad:
        print(n, p.requires_grad, p.device)

The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.

In [None]:
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

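lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    # The schedule must span all epochs; otherwise the learning rate reaches zero after the first one.
    num_training_steps=len(train_dataloader) * NUM_EPOCHS,
)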
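
The linear schedule multiplies the base learning rate by a factor that falls from 1 to 0 at `num_training_steps` and stays at 0 afterwards, so the step count passed to the scheduler should cover the whole run. The sketch below re-implements that rule for illustration only (it is not the library code), and `steps_per_epoch` is a hypothetical value.

```python
def linear_lr_multiplier(step, num_training_steps, num_warmup_steps=0):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    # Linear decay from 1 to 0, clamped at 0 once the schedule is exhausted.
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


steps_per_epoch = 1000  # hypothetical number of batches per epoch
base_lr = 1e-2

for step in (0, 500, 1000, 2000):
    one_epoch = base_lr * linear_lr_multiplier(step, steps_per_epoch)
    three_epochs = base_lr * linear_lr_multiplier(step, 3 * steps_per_epoch)
    print(step, one_epoch, three_epochs)
```

With a schedule sized for one epoch, every step from the second epoch onward runs with a learning rate of zero.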

Let's initialize wandb for logging and start the training loop!

In [None]:
wandb.init(
    project="bloom-sst-2",
    config={
        "num_epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "weight_decay": WEIGHT_DECAY,
        "num_prefix_tokens": NUM_PREFIX_TOKENS,
        "model_name": MODEL_NAME,
        "seed": SEED,
    }
)

for epoch in range(NUM_EPOCHS):
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        model.train()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        wandb.log({"Train Loss": loss.item()})

    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)
    wandb.log({"Valid Accuracy": accuracy}, commit=False)

Our model has been trained!
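
To interpret the classifier output for a single sentence, the two logits can be turned into a label and a probability. In the GLUE version of SST-2, label 0 means negative and label 1 means positive. The helper below is a small pure-Python sketch; `logits_to_prediction` is a hypothetical name, and it expects the logits as a plain list of floats (for example, `outputs.logits[0].tolist()`).

```python
import math

SST2_LABELS = ("negative", "positive")


def logits_to_prediction(logits):
    """Return the most likely SST-2 label and its softmax probability."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return SST2_LABELS[best], probs[best]


label, probability = logits_to_prediction([-1.0, 2.0])
print(label, round(probability, 4))
```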

## Beyond soft-prompt tuning

Let's try to tune the model using an adapter inserted between its transformer blocks.

In [None]:
class BloomBasedClassifier(nn.Module):
    def __init__(
        self,
        model,
        intermediate_size: int = 32,
        num_classes: int = 2,
        adapter_layer_position: int = 6,
        head_layer_position: int = 10,
    ):
        super().__init__()
        # Remote BLOOM blocks served by Petals.
        self.distributed_layers = model.transformer.h

        self.hidden_size = model.config.hidden_size
        self.dtype = model.config.torch_dtype
        self.intermediate_size = intermediate_size
        self.num_classes = num_classes
        self.adapter_layer_position = adapter_layer_position
        self.head_layer_position = head_layer_position

        self.word_embeddings = model.transformer.word_embeddings
        # Trainable bottleneck adapter applied after the first `adapter_layer_position` blocks.
        self.adapter = nn.Sequential(
            nn.Linear(self.hidden_size, self.intermediate_size),
            nn.Linear(self.intermediate_size, self.hidden_size),
        ).to(self.dtype)
        # Trainable classification head applied after block `head_layer_position`.
        self.head = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, self.num_classes),
        ).to(self.dtype)

    def forward(self, embeddings):
        before_layers = self.distributed_layers[0:self.adapter_layer_position]
        after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]

        hidden_states = before_layers(embeddings)
        hidden_states = self.adapter(hidden_states)
        hidden_states = after_layers(hidden_states)
        # Mean-pool over the sequence dimension before classification.
        pooled_states = torch.mean(hidden_states, dim=1)
        return self.head(pooled_states)
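
The data flow of this classifier can be checked locally without any servers. The sketch below replaces the remote BLOOM blocks with small frozen local layers, which is purely an assumption for illustration (`toy_blocks` and the tiny sizes are hypothetical), and verifies the tensor shapes and which parts receive gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, intermediate_size, num_classes = 16, 4, 2

# Hypothetical local stand-ins for ten remote transformer blocks (kept frozen).
toy_blocks = nn.Sequential(
    *[nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh()) for _ in range(10)]
)
for param in toy_blocks.parameters():
    param.requires_grad = False

# Trainable parts: a bottleneck adapter and a classification head.
adapter = nn.Sequential(
    nn.Linear(hidden_size, intermediate_size),
    nn.Linear(intermediate_size, hidden_size),
)
head = nn.Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, num_classes))

embeddings = torch.randn(3, 7, hidden_size)  # (batch, seq_len, hidden)
hidden = toy_blocks[0:6](embeddings)         # blocks before the adapter
hidden = adapter(hidden)                     # (batch, seq_len, hidden)
hidden = toy_blocks[6:10](hidden)            # blocks between adapter and head
pooled = hidden.mean(dim=1)                  # (batch, hidden)
logits = head(pooled)                        # (batch, num_classes)

# Gradients pass through the frozen blocks and reach the adapter.
logits.sum().backward()
print(logits.shape)
```

Only the first ten blocks are used in this setup, so the head classifies from an intermediate representation rather than from the final layer.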

Delete the prompt-tuned model and free the cached GPU memory.

In [None]:
del model, optimizer, lr_scheduler
torch.cuda.empty_cache()

Create a new model with adapters.

In [None]:
INTERMEDIATE_SIZE = 32
ADAPTER_LAYER_POSITION = 6
HEAD_LAYER_POSITION = 10
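
With these settings the trainable part of the classifier stays small. The arithmetic below counts the weights and biases of the adapter and the head, assuming a hidden size of 4096 for bloom-7b1 (an assumption to verify against `model.config.hidden_size`); the helper names are hypothetical.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def layer_norm_params(features):
    """Scale and shift vectors of a layer normalization."""
    return 2 * features


assumed_hidden_size = 4096  # assumed value for bloom-7b1
bottleneck_size = 32        # matches INTERMEDIATE_SIZE
label_count = 2             # SST-2 is binary

adapter_params = linear_params(assumed_hidden_size, bottleneck_size) + linear_params(
    bottleneck_size, assumed_hidden_size
)
head_params = layer_norm_params(assumed_hidden_size) + linear_params(
    assumed_hidden_size, label_count
)
print(adapter_params, head_params, adapter_params + head_params)
```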

In [None]:
cls_model = BloomBasedClassifier(
    DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),
    intermediate_size=INTERMEDIATE_SIZE,
    adapter_layer_position=ADAPTER_LAYER_POSITION,
    head_layer_position=HEAD_LAYER_POSITION,
).to(DEVICE)
cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
cls_criterion = nn.CrossEntropyLoss()

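lr_scheduler = get_scheduler(
    name="linear",
    optimizer=cls_optimizer,
    num_warmup_steps=0,
    # The schedule must span all epochs; otherwise the learning rate reaches zero after the first one.
    num_training_steps=len(train_dataloader) * NUM_EPOCHS,
)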

Now we can train the new adapter-based model.

In [None]:
wandb.init(
    project="bloom_based_cls-sst-2",
    config={
        "num_epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "weight_decay": WEIGHT_DECAY,
        "model_name": MODEL_NAME,
        "seed": SEED,
        "intermediate_size": INTERMEDIATE_SIZE,
        "adapter_layer_position": ADAPTER_LAYER_POSITION,
        "head_layer_position": HEAD_LAYER_POSITION,
    }
)

for epoch in range(NUM_EPOCHS):
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        cls_model.train()
        with torch.no_grad():
            embeddings_output = cls_model.word_embeddings(batch["input_ids"])
        outputs = cls_model(embeddings_output)
        loss = cls_criterion(outputs, batch["labels"])
        loss.backward()

        cls_optimizer.step()
        lr_scheduler.step()
        cls_optimizer.zero_grad()

        wandb.log({"Train Loss": loss.item()})

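    # eval_metrics() calls model(**batch) and reads outputs.logits, which does not match
    # BloomBasedClassifier.forward(embeddings), so we evaluate with an explicit loop.
    cls_model.eval()
    with torch.no_grad():
        for batch in valid_dataloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            logits = cls_model(cls_model.word_embeddings(batch["input_ids"]))
            metric.add_batch(predictions=torch.argmax(logits, dim=-1), references=batch["labels"])
    accuracy = metric.compute()
    wandb.log({"Valid Accuracy": accuracy}, commit=False)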