{ "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": {}, "source": [ "# Distributed Bloom for Text Generation using Prompt Tuning\n", "\n", "In this example, we showcase how the Bloom model can be efficiently adapted in a decentralized fashion. In particular, servers host the Bloom transformer, which is kept unchanged during adaptation, while only a few prefix tokens are learned.\n", "\n", "This example trains the Bloom model for a chatbot task: given a dialog context, the model has to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)."
] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": {}, "source": [ "First, we have to prepare all dependencies."
] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": {}, "outputs": [], "source": [ "%env OMP_NUM_THREADS=24\n", "\n", "import os\n", "import sys\n", "import inspect\n", "sys.path.insert(0, \"..\")\n", "\n", "# General\n", "import torch\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "\n", "# Distributed\n", "import hivemind\n", "from src.client.remote_model import DistributedBloomForCausalLM\n", "\n", "# HF imports\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset\n", "from transformers import get_scheduler\n", "\n", "# Visualization dependencies\n", "from IPython.display import clear_output"
] }, { "cell_type": "markdown", "id": "1bf07b5d", "metadata": {}, "source": [ "Set some hyperparameters for training. To set up Petals servers, please read \\ or use a publicly available one \\."
] }, { "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": {}, "outputs": [], "source": [ "MODEL_NAME = ... # select the model you like\n", "INITIAL_PEERS = [...] # add your peers' addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE = 'cpu'\n", "BATCH_SIZE = 4\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_SAMPLES = 1000\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 256\n", "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
] }, { "cell_type": "markdown", "id": "d38316bd", "metadata": {}, "source": [ "Prepare the tokenizer and the distributed model, then connect it to the servers."
] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": {}, "outputs": [], "source": [ "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "model = DistributedBloomForCausalLM.from_pretrained(\n", "    MODEL_NAME,\n", "    initial_peers=INITIAL_PEERS,\n", "    pre_seq_len=NUM_PREFIX_TOKENS,\n", "    tuning_mode=TUNING_MODE\n", ").to(DEVICE)"
] }, { "cell_type": "markdown", "id": "042e3786", "metadata": {}, "source": [ "Prepare the PersonaChat dataset. We need two mapping functions: one to concatenate the dialog history with each candidate answer, and another for tokenization."
] }, { "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"bavard/personachat_truecased\")\n", "\n", "\n", "def chunking(examples):\n", "    inputs = [\n", "        \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n", "        for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n", "        for candidate in candidates\n", "    ]\n", "    return {\"chunks\": inputs}\n", "\n", "\n", "def tokenize(examples):\n", "    outputs = {\n", "        \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n", "    }\n", "    outputs[\"labels\"] = outputs[\"input_ids\"]\n", "    return outputs\n", "\n", "\n", "tokenized_datasets = (\n", "    dataset\n", "    .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n", "    .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n", ")\n", "\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "train_dataloader = DataLoader(\n", "    train_dataset.select(list(range(NUM_SAMPLES))),\n", "    shuffle=True,\n", "    batch_size=BATCH_SIZE,\n", "    drop_last=True,\n", ")" ] }, { "cell_type": "markdown", "id": "ef4323fd", "metadata": {}, "source": [ "Before setting up the optimizer, check which model parameters will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": {}, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", "    if p.requires_grad:\n", "        print(n, p.requires_grad, p.device)" ] }, { "cell_type": "markdown", "id": "59cffce7", "metadata": {}, "source": [ "The optimizer will only update the **prompts**, since they are the only trainable parameters. So initialize the optimizer and the learning rate scheduler."
] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", "metadata": {}, "source": [ "Let's initialize wandb for logging and start the training loop!" ] }, { "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": {}, "outputs": [], "source": [ "wandb.init(\n", "    project=\"bloom-personachat\",\n", "    config={\n", "        \"num_samples\": NUM_SAMPLES,\n", "        \"batch_size\": BATCH_SIZE,\n", "        \"learning_rate\": LR,\n", "        \"weight_decay\": WEIGHT_DECAY,\n", "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", "        \"model_name\": MODEL_NAME,\n", "        \"seed\": SEED,\n", "    }\n", ")\n", "\n", "progress_bar = tqdm(range(len(train_dataloader)))\n", "for batch in train_dataloader:\n", "    batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", "    model.train()\n", "    outputs = model(**batch)\n", "    loss = outputs.loss\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "    lr_scheduler.step()\n", "    optimizer.zero_grad()\n", "\n", "    # Log a plain Python float rather than the tensor itself\n", "    wandb.log({\"Train Loss\": loss.item()})\n", "    progress_bar.update(1)" ] }, { "cell_type": "markdown", "id": "0f36cb80", "metadata": {}, "source": [ "Try to talk with the trained model! To leave dialog mode, enter an empty line."
] }, { "cell_type": "code", "execution_count": null, "id": "720181b7", "metadata": {}, "outputs": [], "source": [ "MAX_TOKENS = 16\n", "TOP_K = 100\n", "TEMPERATURE = 0.6\n", "dialog = \"\"\n", "\n", "while True:\n", "    user_phrase = input()\n", "    # An empty line ends the dialog\n", "    if len(user_phrase) == 0:\n", "        break\n", "    dialog += f\"{user_phrase}\\n-----\\n\"\n", "    inputs = tokenizer([dialog], return_tensors='pt')['input_ids'].to(DEVICE)\n", "    outputs = model.generate(\n", "        inputs,\n", "        temperature=TEMPERATURE,\n", "        do_sample=True,\n", "        top_k=TOP_K,\n", "        eos_token_id=tokenizer.eos_token_id,\n", "        max_new_tokens=MAX_TOKENS,\n", "    )\n", "    bloom_answer = tokenizer.batch_decode(outputs)[0]\n", "    # Keep only the newly generated text, up to the first line break\n", "    bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n", "    print(bloom_answer)\n", "    dialog += f\"{bloom_answer}\\n-----\\n\"" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }