|
@@ -0,0 +1,317 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "a07e0f5e",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "<div>\n",
|
|
|
+ "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
|
|
|
+ "</div>\n",
|
|
|
+ "\n",
|
|
|
+ "# Distributed Bloom for Text Generation using Prompt Tuning\n",
|
|
|
+ "\n",
|
|
|
+ "In this example, we showcase how the Bloom model can be efficiently adapted in a decentralized fashion. In particular, servers maintain the Bloom transformer, which is kept unchanged during adaptation, and learn only a few prefix tokens.\n",
|
|
|
+ "\n",
|
|
|
+ "This example will train the Bloom model for chatbot task. On a given dialog context the model have to provide a relevant answer. [Link to dataset](https://huggingface.co/datasets/bavard/personachat_truecased)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "a3f8526f",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Firslt, we have to prepare all dependicies."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "b4ab6ca7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "%env OMP_NUM_THREADS=24\n",
|
|
|
+ "\n",
|
|
|
+ "import os\n",
|
|
|
+ "import sys\n",
|
|
|
+ "import inspect\n",
|
|
|
+ "sys.path.insert(0, \"..\")\n",
|
|
|
+ "\n",
|
|
|
+ "# General \n",
|
|
|
+ "import torch\n",
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "from tqdm import tqdm\n",
|
|
|
+ "from torch.optim import AdamW\n",
|
|
|
+ "from torch.utils.data import DataLoader\n",
|
|
|
+ "\n",
|
|
|
+ "# Distributed\n",
|
|
|
+ "import hivemind\n",
|
|
|
+ "from src.client.remote_model import DistributedBloomForCausalLM\n",
|
|
|
+ "\n",
|
|
|
+ "# HF imports\n",
|
|
|
+ "import transformers\n",
|
|
|
+ "import wandb\n",
|
|
|
+ "from datasets import load_dataset\n",
|
|
|
+ "from transformers import get_scheduler\n",
|
|
|
+ "\n",
|
|
|
+ "# Visualization dependencies\n",
|
|
|
+ "from IPython.display import clear_output"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "1bf07b5d",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Set some hyperparameters for training. To setup petals servers, please read \\<link here\\> or use public available one \\<link here\\>."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "f04ba4d2",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "MODEL_NAME = ... # select model you like\n",
|
|
|
+ "INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
|
|
|
+ "NUM_PREFIX_TOKENS = 16\n",
|
|
|
+ "DEVICE ='cpu'\n",
|
|
|
+ "BATCH_SIZE = 4\n",
|
|
|
+ "LR = 1e-2\n",
|
|
|
+ "WEIGHT_DECAY = 0.0\n",
|
|
|
+ "NUM_SAMPLES = 1000\n",
|
|
|
+ "SEED = 42\n",
|
|
|
+ "MODEL_MAX_LENGTH = 256\n",
|
|
|
+ "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "d38316bd",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Prepare tokenizer and distributed model, connect it to servers."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "03c6e53e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
|
|
|
+ "tokenizer.padding_side = 'right'\n",
|
|
|
+ "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
|
|
|
+ "model = DistributedBloomForCausalLM.from_pretrained(\n",
|
|
|
+ " MODEL_NAME, \n",
|
|
|
+ " initial_peers=INITIAL_PEERS, \n",
|
|
|
+ " pre_seq_len=NUM_PREFIX_TOKENS, \n",
|
|
|
+ " tuning_mode=TUNING_MODE\n",
|
|
|
+ ").to(DEVICE)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "042e3786",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Prepare personachat dataset. We need two mapping function, one to concatinating history and candidate answers, another for tokenization."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "9c44d516",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "dataset = load_dataset(\"bavard/personachat_truecased\")\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def chunking(examples):\n",
|
|
|
+ " inputs = [\n",
|
|
|
+ " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
|
|
|
+ " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
|
|
|
+ " for candidate in candidates\n",
|
|
|
+ " ]\n",
|
|
|
+ " return {\"chunks\": inputs}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def tokenize(examples):\n",
|
|
|
+ " outputs = {\n",
|
|
|
+ " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
|
|
|
+ " }\n",
|
|
|
+ " outputs[\"labels\"] = outputs[\"input_ids\"]\n",
|
|
|
+ " return outputs\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "tokenized_datasets = (\n",
|
|
|
+ " dataset\n",
|
|
|
+ " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
|
|
|
+ " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "tokenized_datasets.set_format(\"torch\")\n",
|
|
|
+ "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
|
|
|
+ "train_dataloader = DataLoader(\n",
|
|
|
+ " train_dataset.select(list(range(NUM_SAMPLES))),\n",
|
|
|
+ " shuffle=True,\n",
|
|
|
+ " batch_size=BATCH_SIZE,\n",
|
|
|
+ " drop_last=True,\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "ef4323fd",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Before setting up optimizers, check model parameters that will be trained."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "9cc0ba34",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "for n, p in model.named_parameters():\n",
|
|
|
+ " if p.requires_grad:\n",
|
|
|
+ " print(n, p.requires_grad, p.device)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "59cffce7",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Optimizer will only work on **prompts**, they are only trainable parameters. So initialize optimizer and learning rate scheduler."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "ef9bf344",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
|
|
+ "\n",
|
|
|
+ "lr_scheduler = get_scheduler(\n",
|
|
|
+ " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "423c56d5",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Let's initialize wandb for logging and start training loop!"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "d9e46807",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "wandb.init(\n",
|
|
|
+ " project=\"bloom-personachat\",\n",
|
|
|
+ " config={\n",
|
|
|
+ " \"num_samples\": NUM_SAMPLES,\n",
|
|
|
+ " \"batch_size\": BATCH_SIZE,\n",
|
|
|
+ " \"learning_rate\": LR,\n",
|
|
|
+ " \"weight_decay\": WEIGHT_DECAY,\n",
|
|
|
+ " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
|
|
|
+ " \"model_name\": MODEL_NAME,\n",
|
|
|
+ " \"seed\": SEED,\n",
|
|
|
+ " }\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "progress_bar = tqdm(range(len(train_dataloader)))\n",
|
|
|
+ "for batch in train_dataloader:\n",
|
|
|
+ " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
|
|
|
+ "\n",
|
|
|
+ " model.train()\n",
|
|
|
+ " outputs = model(**batch)\n",
|
|
|
+ " loss = outputs.loss\n",
|
|
|
+ " loss.backward()\n",
|
|
|
+ "\n",
|
|
|
+ " optimizer.step()\n",
|
|
|
+ " lr_scheduler.step()\n",
|
|
|
+ " optimizer.zero_grad()\n",
|
|
|
+ "\n",
|
|
|
+ " wandb.log({\"Train Loss\": loss})\n",
|
|
|
+ " progress_bar.update(1)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "0f36cb80",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Try to talk with the trained model! To break from dialog mode, press Ctrl+D."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "720181b7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "MAX_TOKENS = 16\n",
|
|
|
+ "TOP_K = 100\n",
|
|
|
+ "TEMPERATURE = 0.6\n",
|
|
|
+ "dialog = \"\"\n",
|
|
|
+ "\n",
|
|
|
+ "while True:\n",
|
|
|
+ " user_phrase = input()\n",
|
|
|
+ " if len(user_phrase) == 0:\n",
|
|
|
+ " break\n",
|
|
|
+ " dialog += f\"{user_phrase}\\n-----\\n\"\n",
|
|
|
+ " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
|
|
|
+ " outputs = model.generate(\n",
|
|
|
+ " inputs,\n",
|
|
|
+ " temperature=TEMPERATURE,\n",
|
|
|
+ " do_sample=True,\n",
|
|
|
+ " top_k=TOP_K,\n",
|
|
|
+ " eos_token_id=tokenizer.eos_token_id,\n",
|
|
|
+ " max_new_tokens=MAX_TOKENS,\n",
|
|
|
+ " )\n",
|
|
|
+ " bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
|
|
|
+ " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
|
|
|
+ " print(bloom_answer)\n",
|
|
|
+ " dialog += f\"{bloom_answer}\\n-----\\n\""
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.12"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|