@@ -11,9 +11,9 @@
"\n",
"# Distributed Bloom for Text Generation using Prompt Tuning\n",
"\n",
- "In this example, we showcase how the Bloom model can be efficiently adapted in a decentralized fashion. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens.\n",
+ "In this example, we showcase how the a test 6B version of BLOOM model can be efficiently adapted in a decentralized fashion using Petals. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens.\n",
"\n",
- "This example will train the Bloom model for chatbot task. On a given dialog context the model have to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)."
+ "This example will train the BLOOM model for the chatbot task. In a given dialogue context the model has to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)."
]
},
{
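As a rough illustration of what learning "only a few prefix tokens" means, here is a minimal prompt-tuning sketch in plain PyTorch. This is not the Petals API: `PromptTunedLM`, `frozen_lm`, `hidden_size`, and `num_prefix_tokens` are placeholder names, and in the notebook the frozen transformer actually lives on remote servers rather than in a local module.

```python
import torch
import torch.nn as nn

class PromptTunedLM(nn.Module):
    """Concept sketch: a frozen transformer plus a few trainable prefix embeddings."""

    def __init__(self, frozen_lm: nn.Module, hidden_size: int, num_prefix_tokens: int = 16):
        super().__init__()
        self.frozen_lm = frozen_lm
        for p in self.frozen_lm.parameters():
            p.requires_grad = False  # the transformer itself stays unchanged
        # The only trainable parameters: one embedding vector per prefix token.
        self.prompt_embeddings = nn.Parameter(torch.randn(num_prefix_tokens, hidden_size) * 0.02)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # inputs_embeds: (batch, seq_len, hidden_size); prepend the learned prefix.
        prefix = self.prompt_embeddings.unsqueeze(0).expand(inputs_embeds.shape[0], -1, -1)
        return self.frozen_lm(torch.cat([prefix, inputs_embeds], dim=1))
```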
@@ -21,7 +21,7 @@
"id": "a3f8526f",
"metadata": {},
"source": [
- "Firslt, we have to prepare all dependicies."
+ "First, we have to prepare all dependencies."
]
},
{
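For reference, the dependency cell boils down to something like the line below; the exact install source is an assumption (earlier Petals versions were installed straight from the project repository rather than from PyPI).

```python
# Notebook-style install cell; package sources are assumptions.
!pip install --quiet petals transformers datasets wandb tqdm
```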
@@ -64,7 +64,7 @@
"id": "1bf07b5d",
"metadata": {},
"source": [
- "Set some hyperparameters for training. To setup petals servers, please read \\<link here\\> or use public available one \\<link here\\>."
+ "Set some hyperparameters for training."
]
},
{
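A sketch of what such a hyperparameter cell might hold; every name and value below is an assumption, including the checkpoint name standing in for the test 6B BLOOM model.

```python
import torch

# Assumed hyperparameters; tune these for your own setup.
MODEL_NAME = "bigscience/test-bloomd-6b3"  # placeholder for the test 6B BLOOM checkpoint
NUM_PREFIX_TOKENS = 16                     # number of trainable prompt tokens
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
LR = 1e-2
WEIGHT_DECAY = 0.0
NUM_EPOCHS = 1
SEED = 42
```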
@@ -118,7 +118,7 @@
"id": "042e3786",
"metadata": {},
"source": [
- "Prepare personachat dataset. We need two mapping function, one to concatinating history and candidate answers, another for tokenization."
+ "Prepare personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
]
},
{
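A sketch of those two mapping functions, assuming a Hugging Face tokenizer and the `history`/`candidates` columns of `bavard/personachat_truecased` (the last candidate is the gold reply); the separator, sequence length, and `MODEL_NAME` are assumptions carried over from the sketches above.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset = load_dataset("bavard/personachat_truecased")

def concatenate(examples):
    # Join the dialogue history with the gold reply (the last candidate).
    chunks = [
        "\n".join(history) + "\n" + candidates[-1]
        for history, candidates in zip(examples["history"], examples["candidates"])
    ]
    return {"chunks": chunks}

def tokenize(examples):
    outputs = tokenizer(examples["chunks"], padding="max_length", truncation=True, max_length=256)
    outputs["labels"] = outputs["input_ids"]  # causal LM: labels mirror the inputs
    return outputs

tokenized_dataset = (
    dataset.map(concatenate, batched=True, remove_columns=dataset["train"].column_names)
           .map(tokenize, batched=True, remove_columns=["chunks"])
)
tokenized_dataset.set_format("torch")
```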
@@ -170,7 +170,7 @@
"id": "ef4323fd",
"metadata": {},
"source": [
- "Before setting up optimizers, check model parameters that will be trained."
+ "Before setting up optimizers, check the model parameters that will be trained."
]
},
{
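A generic PyTorch way to run that check, assuming the prompt-tuned model is available as `model`:

```python
# With prompt tuning, only the prompt embeddings should show up as trainable.
trainable, frozen = 0, 0
for name, param in model.named_parameters():
    if param.requires_grad:
        print("trainable:", name, tuple(param.shape))
        trainable += param.numel()
    else:
        frozen += param.numel()
print(f"{trainable:,} trainable parameters vs. {frozen:,} frozen parameters")
```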
@@ -190,7 +190,7 @@
"id": "59cffce7",
"metadata": {},
"source": [
- "Optimizer will only work on **prompts**, they are only trainable parameters. So initialize optimizer and learning rate scheduler."
+ "The optimizer will only work on **prompts**, they are only trainable parameters. So initialize optimizer and learning rate scheduler."
]
},
{
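A sketch of that step with AdamW and a linear schedule from `transformers`; the notebook's actual choices may differ, and `model`, `LR`, `WEIGHT_DECAY`, `NUM_EPOCHS`, and `train_dataloader` are assumed to be defined as in the sketches above.

```python
from torch.optim import AdamW
from transformers import get_scheduler

# Only parameters with requires_grad=True (the prompt embeddings) are optimized.
optimizer = AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * NUM_EPOCHS,
)
```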
@@ -212,7 +212,7 @@
"id": "423c56d5",
"metadata": {},
"source": [
- "Let's initialize wandb for logging and start training loop!"
+ "Let's initialize wandb for logging and start the training loop!"
]
},
{
@@ -235,8 +235,7 @@
" }\n",
")\n",
"\n",
- "progress_bar = tqdm(range(len(train_dataloader)))\n",
- "for batch in train_dataloader:\n",
+ "for batch in tqdm(train_dataloader):\n",
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
" model.train()\n",
@@ -248,8 +247,7 @@
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
- " wandb.log({\"Train Loss\": loss})\n",
- " progress_bar.update(1)"
+ " wandb.log({\"Train Loss\": loss})"
]
},
{
@@ -257,7 +255,10 @@
"id": "0f36cb80",
"metadata": {},
"source": [
- "Try to talk with the trained model! To break from dialog mode, press Ctrl+D."
+ "Try to talk with the trained model! Submit an empty input to stop the execution.\n",
+ "\n",
+ "\n",
+ "In this example we have to pass the whole dialogue. In the future, we will support a much faster interactive dialogue mode."
]
},
{
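A rough sketch of such a dialogue loop, assuming the model exposes a Hugging Face-style `generate` method and reusing `tokenizer` and `DEVICE` from the sketches above; the prompt format and generation settings are assumptions.

```python
import torch

# Re-encode and pass the entire dialogue on every turn; an empty input stops the loop.
model.eval()
dialogue = ""
while True:
    user_phrase = input("You: ")
    if not user_phrase.strip():
        break
    dialogue += user_phrase + "\n"
    input_ids = tokenizer(dialogue, return_tensors="pt")["input_ids"].to(DEVICE)
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=50)
    reply = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
    print("Bot:", reply)
    dialogue += reply + "\n"
```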