Przeglądaj źródła

Fix issues related to `petals` as a module (#159)

1. Added `from petals.client import *` to `petals/__init__.py`, so you can write just that:

    ```python
    from petals import DistributedBloomForCausalLM
    ```

    I didn't do the same with server, since its classes are supposed to by used by `petals.cli.run_server`, not end-users. Though it's still possible to do `from petals.server.smth import smth` if necessary.

2. Fixed one more logging issue: log lines from hivemind were shown twice due to a bug in #156.

3. Removed unused `runtime.py`, since the server actually uses `hivemind.moe.Runtime`, and `runtime.py` has no significant changes comparing to it.
Alexander Borzunov 2 lat temu
rodzic
commit
523a7cad33

+ 5 - 5
README.md

@@ -8,7 +8,7 @@
 Generate text using distributed BLOOM and fine-tune it for your own tasks:
 
 ```python
-from petals.client import DistributedBloomForCausalLM
+from petals import DistributedBloomForCausalLM
 
 # Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet
 model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune")
@@ -68,13 +68,13 @@ Check out more tutorials:
     📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
 </p>
 
-### 📋 Model's terms of use
+### 🔒 Privacy and security
 
-Before building your own application that runs a language model with Petals, please make sure that you are familiar with the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
+The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask for that because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
-### 🔒 Privacy and security
+### 📋 Model's terms of use
 
-**If you work with sensitive data, do not use the public swarm.** This is important because it's technically possible for peers serving model layers to recover input data and model outputs, or modify the outputs in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process this data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
+Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
 
 ## FAQ
 

+ 2 - 3
examples/prompt-tuning-personachat.ipynb

@@ -48,7 +48,7 @@
    "outputs": [],
    "source": [
     "import os\n",
-    " \n",
+    "\n",
     "import torch\n",
     "import transformers\n",
     "import wandb\n",
@@ -58,8 +58,7 @@
     "from torch.utils.data import DataLoader\n",
     "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
-    "# Import a Petals model\n",
-    "from petals.client.remote_model import DistributedBloomForCausalLM"
+    "from petals import DistributedBloomForCausalLM"
    ]
   },
   {

+ 4 - 7
examples/prompt-tuning-sst2.ipynb

@@ -48,22 +48,19 @@
    "outputs": [],
    "source": [
     "import os\n",
-    " \n",
-    "import torch\n",
-    "import transformers\n",
-    "import wandb\n",
     "\n",
+    "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
-    "\n",
+    "import transformers\n",
+    "import wandb\n",
     "from datasets import load_dataset, load_metric\n",
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
     "from torch.utils.data import DataLoader\n",
     "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
-    "# Import a Petals model\n",
-    "from petals.client.remote_model import DistributedBloomForSequenceClassification"
+    "from petals import DistributedBloomForSequenceClassification"
    ]
   },
   {

+ 3 - 2
src/petals/__init__.py

@@ -1,5 +1,6 @@
-import petals.utils.logging
+from petals.client import *
+from petals.utils.logging import initialize_logs as _initialize_logs
 
 __version__ = "1.0alpha1"
 
-petals.utils.logging.initialize_logs()
+_initialize_logs()

+ 6 - 1
src/petals/client/__init__.py

@@ -1,5 +1,10 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_model import (
+    DistributedBloomConfig,
+    DistributedBloomForCausalLM,
+    DistributedBloomForSequenceClassification,
+    DistributedBloomModel,
+)
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 0 - 198
src/petals/server/runtime.py

@@ -1,198 +0,0 @@
-import multiprocessing as mp
-import multiprocessing.pool
-import threading
-from collections import defaultdict
-from itertools import chain
-from queue import SimpleQueue
-from selectors import EVENT_READ, DefaultSelector
-from statistics import mean
-from time import time
-from typing import Dict, NamedTuple, Optional
-
-import torch
-from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.utils import get_logger
-from prefetch_generator import BackgroundGenerator
-
-logger = get_logger(__name__)
-
-
-class Runtime(threading.Thread):
-    """
-    A group of processes that processes incoming requests for multiple module backends on a shared device.
-    Runtime is usually created and managed by Server, humans need not apply.
-
-    For debugging, you can start runtime manually with .start() or .run()
-
-    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
-    >>> runtime = Runtime(module_backends)
-    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
-    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
-    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
-    >>> print("Returned:", future.result())
-    >>> runtime.shutdown()
-
-    :param module_backends: a dict [block uid -> ModuleBackend]
-    :param prefetch_batches: form up to this many batches in advance
-    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
-    :param device: if specified, moves all blocks and data to this device via .to(device=device).
-      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
-
-    :param stats_report_interval: interval to collect and log statistics about runtime performance
-    """
-
-    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
-
-    def __init__(
-        self,
-        module_backends: Dict[str, ModuleBackend],
-        prefetch_batches: int = 1,
-        sender_threads: int = 1,
-        device: torch.device = None,
-        stats_report_interval: Optional[int] = None,
-    ):
-        super().__init__()
-        self.module_backends = module_backends
-        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
-        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
-        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
-        self.shutdown_trigger = mp.Event()
-        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
-
-        self.stats_report_interval = stats_report_interval
-        if self.stats_report_interval is not None:
-            self.stats_reporter = StatsReporter(self.stats_report_interval)
-
-    def run(self):
-        for pool in self.pools:
-            if not pool.is_alive():
-                pool.start()
-        if self.device is not None:
-            for backend in self.module_backends.values():
-                backend.module.to(self.device)
-
-        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
-            try:
-                self.ready.set()
-                if self.stats_report_interval is not None:
-                    self.stats_reporter.start()
-                logger.info("Started")
-
-                batch_iterator = self.iterate_minibatches_from_pools()
-                if self.prefetch_batches > 0:
-                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
-
-                for pool, batch_index, batch in batch_iterator:
-                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
-
-                    start = time()
-                    try:
-                        outputs = pool.process_func(*batch)
-                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
-
-                        batch_processing_time = time() - start
-
-                        batch_size = outputs[0].size(0)
-                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
-
-                        if self.stats_report_interval is not None:
-                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
-
-                    except KeyboardInterrupt:
-                        raise
-                    except BaseException as exception:
-                        logger.exception(f"Caught {exception}, attempting to recover")
-                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
-
-            finally:
-                if not self.shutdown_trigger.is_set():
-                    self.shutdown()
-
-    def shutdown(self):
-        """Gracefully terminate a running runtime."""
-        logger.info("Shutting down")
-        self.ready.clear()
-
-        if self.stats_report_interval is not None:
-            self.stats_reporter.stop.set()
-            self.stats_reporter.join()
-
-        logger.debug("Terminating pools")
-        for pool in self.pools:
-            if pool.is_alive():
-                pool.shutdown()
-        logger.debug("Pools terminated")
-
-        # trigger background thread to shutdown
-        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
-        self.shutdown_trigger.set()
-
-    def iterate_minibatches_from_pools(self, timeout=None):
-        """
-        Chooses pool according to priority, then copies exposed batch and frees the buffer
-        """
-        with DefaultSelector() as selector:
-            for pool in self.pools:
-                selector.register(pool.batch_receiver, EVENT_READ, pool)
-            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
-
-            while True:
-                # wait until at least one batch_receiver becomes available
-                logger.debug("Waiting for inputs from task pools")
-                ready_fds = selector.select()
-                ready_objects = {key.data for (key, events) in ready_fds}
-                if self.SHUTDOWN_TRIGGER in ready_objects:
-                    break  # someone asked us to shutdown, break from the loop
-
-                logger.debug("Choosing the pool with first priority")
-
-                pool = min(ready_objects, key=lambda pool: pool.priority)
-
-                logger.debug(f"Loading batch from {pool.name}")
-                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
-                logger.debug(f"Loaded batch from {pool.name}")
-                yield pool, batch_index, batch_tensors
-
-
-BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
-
-
-class StatsReporter(threading.Thread):
-    def __init__(self, report_interval: int):
-        super().__init__()
-        self.report_interval = report_interval
-        self.stop = threading.Event()
-        self.stats_queue = SimpleQueue()
-
-    def run(self):
-        while not self.stop.wait(self.report_interval):
-            pool_batch_stats = defaultdict(list)
-            while not self.stats_queue.empty():
-                pool_uid, batch_stats = self.stats_queue.get()
-                pool_batch_stats[pool_uid].append(batch_stats)
-
-            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
-            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
-            for pool_uid, pool_stats in pool_batch_stats.items():
-                total_batches = len(pool_stats)
-                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
-                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
-                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
-                batches_to_time = total_batches / total_time
-                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
-
-                examples_to_time = total_examples / total_time
-                example_performance = f"{examples_to_time:.2f} " + (
-                    "examples/s" if examples_to_time > 1 else "s/example"
-                )
-
-                logger.info(
-                    f"{pool_uid}: "
-                    f"{total_batches} batches ({batch_performance}), "
-                    f"{total_examples} examples ({example_performance}), "
-                    f"avg batch size {avg_batch_size:.2f}"
-                )
-
-    def report_stats(self, pool_uid, batch_size, processing_time):
-        batch_stats = BatchStats(batch_size, processing_time)
-        self.stats_queue.put_nowait((pool_uid, batch_stats))

+ 4 - 1
src/petals/utils/logging.py

@@ -25,7 +25,10 @@ def initialize_logs():
         os.environ["HIVEMIND_COLORS"] = "True"
     importlib.reload(hm_logging)
 
-    hm_logging.get_logger().handlers.clear()  # Remove extra default handlers on Colab
+    # Remove log handlers from previous import of hivemind.utils.logging and extra handlers on Colab
+    hm_logging.get_logger().handlers.clear()
+    hm_logging.get_logger("hivemind").handlers.clear()
+
     hm_logging.use_hivemind_log_handler("in_root_logger")
 
     # We suppress asyncio error logs by default since they are mostly not relevant for the end user,

+ 1 - 1
tests/test_priority_pool.py

@@ -3,8 +3,8 @@ import time
 
 import pytest
 import torch
+from hivemind.moe.server.runtime import Runtime
 
-from petals.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool