Browse Source

Don't expose BloomTokenizerFast

Aleksandr Borzunov 2 years ago
parent
commit
654aa644de

+ 2 - 1
README.md

@@ -140,7 +140,8 @@ Once you have enough servers, you can use them to train and/or run inference with the mo
 ```python
 import torch
 import torch.nn.functional as F
-from petals.client import BloomTokenizerFast, DistributedBloomForCausalLM
+from transformers import BloomTokenizerFast
+from petals.client import DistributedBloomForCausalLM
 
 initial_peers = [TODO_put_one_or_more_server_addresses_here]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
 tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")

+ 2 - 3
examples/prompt-tuning-personachat.ipynb

@@ -55,7 +55,6 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "import sys\n",
     " \n",
     "import torch\n",
     "import transformers\n",
@@ -64,7 +63,7 @@
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
     "from torch.utils.data import DataLoader\n",
-    "from transformers import get_scheduler\n",
+    "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
     "# Import a Petals model\n",
     "from petals.client.remote_model import DistributedBloomForCausalLM"
@@ -113,7 +112,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
     "tokenizer.padding_side = 'right'\n",
     "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
     "model = DistributedBloomForCausalLM.from_pretrained(\n",

+ 2 - 3
examples/prompt-tuning-sst2.ipynb

@@ -55,7 +55,6 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "import sys\n",
     " \n",
     "import torch\n",
     "import transformers\n",
@@ -64,7 +63,7 @@
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
     "from torch.utils.data import DataLoader\n",
-    "from transformers import get_scheduler\n",
+    "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
     "# Import a Petals model\n",
     "from petals.client.remote_model import DistributedBloomForSequenceClassification"
@@ -114,7 +113,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
     "tokenizer.padding_side = 'right'\n",
     "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
     "model = DistributedBloomForSequenceClassification.from_pretrained(\n",

+ 0 - 2
src/petals/client/__init__.py

@@ -1,5 +1,3 @@
-from transformers import BloomTokenizerFast
-
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock