|
@@ -11,9 +11,9 @@
|
|
|
"\n",
|
|
|
"# Distributed Bloom for Text Generation using Prompt Tuning\n",
|
|
|
"\n",
|
|
|
- "In this example, we showcase how the a test 6B version of BLOOM model can be efficiently adapted in a decentralized fashion using Petals. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens.\n",
|
|
|
+ "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and gradient descent will learn only a few prefix tokens stored on a Petals client.\n",
|
|
|
"\n",
|
|
|
- "This example will train the BLOOM model for the chatbot task. In a given dialogue context the model has to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)."
|
|
|
+ "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
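For orientation, a minimal sketch of the client-side call this setup leads to. The keyword names (`initial_peers`, `pre_seq_len`, `tuning_mode="ptune"`) are assumptions about the Petals API rather than something shown in this diff, and the constants come from the hyperparameter cell further down:

```python
# Hedged sketch: load a distributed BLOOM client whose transformer blocks live on
# remote Petals servers; only a small set of prefix tokens is trained locally.
# Keyword names below are assumptions about the Petals client API.
from src.client.remote_model import DistributedBloomForCausalLM

model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,                     # the test 6B BLOOM checkpoint served by your swarm
    initial_peers=INITIAL_PEERS,    # multiaddrs of running Petals servers
    pre_seq_len=NUM_PREFIX_TOKENS,  # number of trainable prefix tokens kept on the client
    tuning_mode="ptune",            # assumed flag: freeze remote blocks, train only the prefix
).to(DEVICE)
```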
@@ -36,21 +36,18 @@
|
|
|
"import os\n",
|
|
|
"import sys\n",
|
|
|
"sys.path.insert(0, \"..\")\n",
|
|
|
- "\n",
|
|
|
- "# General \n",
|
|
|
+ " \n",
|
|
|
"import torch\n",
|
|
|
+ "import transformers\n",
|
|
|
+ "import wandb\n",
|
|
|
+ "from datasets import load_dataset\n",
|
|
|
"from tqdm import tqdm\n",
|
|
|
"from torch.optim import AdamW\n",
|
|
|
"from torch.utils.data import DataLoader\n",
|
|
|
+ "from transformers import get_scheduler\n",
|
|
|
"\n",
|
|
|
- "# Distributed\n",
|
|
|
- "from src.client.remote_model import DistributedBloomForCausalLM\n",
|
|
|
- "\n",
|
|
|
- "# HF imports\n",
|
|
|
- "import transformers\n",
|
|
|
- "import wandb\n",
|
|
|
- "from datasets import load_dataset\n",
|
|
|
- "from transformers import get_scheduler"
|
|
|
+ "# Import a Petals model\n",
|
|
|
+ "from src.client.remote_model import DistributedBloomForCausalLM"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -58,7 +55,7 @@
|
|
|
"id": "1bf07b5d",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Set some hyperparameters for training."
|
|
|
+ "Let's set some hyperparameters for training:"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -71,7 +68,7 @@
|
|
|
"MODEL_NAME = ... # select model you like\n",
|
|
|
"INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
|
|
|
"NUM_PREFIX_TOKENS = 16\n",
|
|
|
- "DEVICE ='cpu'\n",
|
|
|
+ "DEVICE = 'cpu'\n",
|
|
|
"BATCH_SIZE = 4\n",
|
|
|
"LR = 1e-2\n",
|
|
|
"WEIGHT_DECAY = 0.0\n",
|
|
@@ -112,7 +109,7 @@
|
|
|
"id": "042e3786",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Prepare personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
|
|
|
+ "Let's prepare the Personachat dataset. We need two mapping functions: one to concatenate the dialogue history with candidate answers, and another for tokenization."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
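The actual preprocessing cells are outside this hunk; here is a rough sketch of the two mapping functions described above. The Personachat columns `history` and `candidates` are real (the last candidate is the ground-truth reply), but the separators, maximum length, and the assumption that a `tokenizer` was created elsewhere in the notebook are mine:

```python
# Hedged sketch of the two mapping functions: concatenation, then tokenization.
from datasets import load_dataset

def concatenate(batch):
    # Join the dialogue history with the correct answer (the last candidate) into one string.
    return {
        "chunks": [
            "\n".join(history) + "\n" + candidates[-1]
            for history, candidates in zip(batch["history"], batch["candidates"])
        ]
    }

def tokenize(batch):
    # Tokenize to a fixed length so batches can be stacked; labels mirror the inputs.
    tokens = tokenizer(batch["chunks"], padding="max_length", truncation=True, max_length=256)
    tokens["labels"] = tokens["input_ids"]
    return tokens

dataset = load_dataset("bavard/personachat_truecased")
dataset = dataset.map(concatenate, batched=True, remove_columns=dataset["train"].column_names)
dataset = dataset.map(tokenize, batched=True)  # assumes `tokenizer` is defined in another cell
```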
@@ -184,7 +181,7 @@
|
|
|
"id": "59cffce7",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "The optimizer will only work on **prompts**, they are only trainable parameters. So initialize optimizer and learning rate scheduler."
|
|
|
+ "The optimizer will only work on **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
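A minimal sketch of that cell, assuming the learned prompts are the only client-side parameters with `requires_grad=True`, and that names like `train_dataloader` and `NUM_EPOCHS` exist in the surrounding cells:

```python
# Hedged sketch: optimize only the parameters that require gradients (the prompt embeddings);
# the remote BLOOM blocks never receive parameter updates.
from torch.optim import AdamW
from transformers import get_scheduler

prompt_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(prompt_params, lr=LR, weight_decay=WEIGHT_DECAY)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * NUM_EPOCHS,  # assumed names from other cells
)
```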
@@ -252,7 +249,7 @@
|
|
|
"Try to talk with the trained model! Submit an empty input to stop the execution.\n",
|
|
|
"\n",
|
|
|
"\n",
|
|
|
- "In this example we have to pass the whole dialogue. In the future, we will support a much faster interactive dialogue mode."
|
|
|
+ "__Note__: In this example, we pass the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous one."
|
|
|
]
|
|
|
},
|
|
|
{
|
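A rough sketch of such a chat loop, re-encoding the entire dialogue as the prefix for every turn. The separator string and the generation arguments are assumptions; the empty-input exit condition follows the notebook text:

```python
# Hedged sketch of the chat loop: the full dialogue so far is re-fed on every turn.
dialogue = ""
while True:
    phrase = input("You: ")
    if not phrase:
        break  # empty input stops the chat, as described above
    dialogue += phrase + "\n-----\n"
    inputs = tokenizer([dialogue], return_tensors="pt")["input_ids"].to(DEVICE)
    outputs = model.generate(inputs, max_new_tokens=50, do_sample=True, temperature=0.6)
    answer = tokenizer.decode(outputs[0, inputs.shape[1]:])  # keep only the new tokens
    dialogue += answer + "\n-----\n"
    print("Bot:", answer)
```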