
Add CollaborativeOptimizer, TrainingAverager (#215)

* Implemented CollaborativeOptimizer - a training tool that accumulates a large shared batch over decentralized peers and performs optimizer steps over that batch
* Implemented TrainingAverager - a wrapper over DecentralizedAverager that handles standard training scenarios

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Roman Zhytar <roma213423@gmail.com>
Alexey Bukhtiyarov, 4 years ago
Parent
Commit
7bb6565674
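
A minimal usage sketch of the new CollaborativeOptimizer (illustrative only: the model, dataloader, bootstrap address and hyperparameters below are placeholders, and the extra keyword arguments are simply forwarded to the underlying averager as its docstring describes):

import torch
from hivemind.dht import DHT
from hivemind.client.optim import CollaborativeOptimizer

model = torch.nn.Linear(512, 10)                        # toy model; LAMB/LARS would suit large batches better
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

dht = DHT(initial_peers=["127.0.0.1:1337"], start=True)  # hypothetical bootstrap peer
opt = CollaborativeOptimizer(base_opt, dht=dht, prefix="my_run",
                             target_batch_size=4096, batch_size_per_step=32,
                             target_group_size=16, start=True, verbose=True)

for inputs, targets in dataloader:                      # `dataloader` is assumed to yield 32-sample batches
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    opt.step()  # reports 32 accumulated samples; parameters are only updated (and gradients zeroed)
                # once the collaboration reaches target_batch_size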

+ 1 - 1
hivemind/client/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD
+from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD, CollaborativeOptimizer

+ 157 - 0
hivemind/client/averaging/training.py

@@ -0,0 +1,157 @@
+""" An extension of averager that supports common optimization use cases. """
+from itertools import chain
+from threading import Lock
+from typing import Sequence, Dict, Iterator
+
+import torch
+
+from hivemind.client.averaging import DecentralizedAverager
+from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
+
+logger = get_logger(__name__)
+
+
+class TrainingAverager(DecentralizedAverager):
+    """
+    A high-level interface to DecentralizedAverager that averages trainable params or gradients for an optimizer.
+
+    This averager implements a number of typical use cases that arise in collaborative optimization:
+    - averaging parameters or gradients or both (in the future, this will support averaging learning rates as well)
+    - this peer's weight (e.g. based on its batch size) can be specified via averager.step(weight=...)
+    - when out of sync, the averager will load the entire optimizer state from an up-to-date peer
+
+    :param opt: a pytorch optimizer to be averaged between peers (complete with model parameters)
+    :param average_parameters: whether or not to average model parameters in self.step(...)
+    :param average_gradients: whether or not to average model gradients in self.step(...)
+    :param initialize_optimizer: if True, this will run a speculative optimizer step with
+      zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for averaging tensors that are updated outside of opt.step (e.g. batchnorm stats)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+    def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
+                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
+
+        self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
+        self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.lock_averager_step = Lock()
+        if initialize_optimizer:
+            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
+
+        with torch.no_grad():
+            averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
+        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+    @torch.no_grad()
+    def step(self, wait: bool = True, **kwargs):
+        """ Average optimizer weights and gradients with peers. """
+        if not wait:
+            return run_in_background(self.step, wait=False, **kwargs)
+
+        local_tensors = list(self.local_tensors())
+        with self.lock_averager_step:
+            # fill averager's tensors with current local tensors, scaled by peer's weight
+            with self.get_tensors() as averaged_tensors:
+                assert len(local_tensors) == len(averaged_tensors)
+                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
+                    averaged_tensor[...] = local_tensor.detach().cpu().float()
+
+            # find a group and hopefully average tensors with peers
+            gathered = super().step(**kwargs)
+
+            # load averaged tensors back into model
+            with self.get_tensors() as averaged_tensors:
+                assert len(averaged_tensors) == len(local_tensors)
+                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
+                    local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+
+            self.local_step += 1
+            return gathered
+
+    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
+        """
+        Iterate local trainer's tensors that should be averaged with peers
+
+        :param replace_none: if True and average_gradients is True, None gradients will be replaced with zero tensors.
+          Otherwise, such gradients will be skipped (this may cause inconsistencies with averaged_tensors).
+        """
+        if self.average_parameters:
+            for param_group in self.opt.param_groups:
+                yield from param_group['params']
+        if self.average_gradients:
+            for param_group in self.opt.param_groups:
+                for param in param_group['params']:
+                    if param.grad is not None:
+                        yield param.grad
+                    elif replace_none:
+                        yield torch.zeros_like(param)
+        yield from iter(self.extra_tensors)
+
+    def get_current_state(self):
+        """
+        Get the current model/optimizer state when requested by a newly joined peer. Executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with torch.no_grad():
+            optimized_parameters = tuple(param.detach().cpu() for param_group in self.opt.param_groups
+                                         for param in param_group['params'])
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+
+        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
+        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        parameters_and_extras = [param for param_group in self.opt.param_groups for param in param_group['params']]
+        parameters_and_extras.extend(self.extra_tensors)
+        num_local_tensors = len(parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+        metadata, flat_tensors = loaded_state
+        loaded_parameters_and_extras = flat_tensors[:num_local_tensors]
+        loaded_opt_tensors = flat_tensors[num_local_tensors:]
+
+        with torch.no_grad():
+            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+                local_param[...] = loaded_param
+            load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
+
+        self.local_step = max(self.local_step, metadata['step'])
+
+
+def initialize_optimizer_state(opt: torch.optim.Optimizer):
+    for param_group in opt.param_groups:
+        for param in param_group['params']:
+            if param.grad is None:
+                (0 * param.sum()).backward()
+    opt.step()
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """ Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers """
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type='value', value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get('type') == 'tensor' and isinstance(elem.get('index'), int):
+            flat_optimizer_state.append(flat_tensors[elem['index']])
+        elif elem.get('type') == 'value' and 'value' in elem:
+            flat_optimizer_state.append(elem['value'])
+    with torch.no_grad():
+        return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
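
As a rough sketch of how the helper functions at the bottom of training.py fit together (toy model and optimizer, no DHT needed for this part):

import torch
from hivemind.client.averaging.training import (
    initialize_optimizer_state, dump_optimizer_state, load_optimizer_state)

model = torch.nn.Linear(16, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
initialize_optimizer_state(opt)          # one zero-gradient step so that Adam's buffers exist

# dump_optimizer_state splits opt.state_dict() into picklable metadata and raw tensors;
# load_optimizer_state re-assembles the same structure, e.g. on a freshly joined peer
metadata, tensors = dump_optimizer_state(opt)
load_optimizer_state(opt, metadata, tensors)

TrainingAverager itself additionally needs a running DHT; see the CollaborativeOptimizer sketch above.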

+ 1 - 0
hivemind/client/optim/__init__.py

@@ -1 +1,2 @@
 from hivemind.client.optim.simple import ParameterAveragingOptimizer, DecentralizedSGD
+from hivemind.client.optim.collaborative import CollaborativeOptimizer

+ 284 - 0
hivemind/client/optim/collaborative.py

@@ -0,0 +1,284 @@
+from __future__ import annotations
+import warnings
+from dataclasses import dataclass
+from threading import Thread, Lock, Event
+from typing import Optional, Type
+import logging
+
+import torch
+import numpy as np
+
+from hivemind.dht import DHT
+from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.client.averaging.training import TrainingAverager
+from hivemind.utils import get_logger, get_dht_time, run_in_background, ValueWithExpiration
+from hivemind.client.optim.performance_ema import PerformanceEMA
+
+logger = get_logger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+
+
+@dataclass(frozen=False)
+class CollaborationState:
+    optimizer_step: int
+    samples_accumulated: int
+    target_batch_size: int
+    num_peers: int
+    num_clients: int
+    eta_next_step: float
+    next_fetch_time: float
+
+    @property
+    def ready_for_step(self):
+        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
+
+    def register_step(self):
+        self.optimizer_step += 1
+        self.samples_accumulated = 0
+        self.eta_next_step = float('inf')
+
+
+class CollaborativeOptimizer(DecentralizedOptimizerBase):
+    """
+    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
+
+    These optimizers use the DHT to track how much progress the collaboration has made towards the target batch size.
+    Once enough samples are accumulated, the optimizers compute a weighted average of their statistics.
+
+    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
+
+    - calling .step will periodically zero-out gradients w.r.t. model parameters after each step
+    - it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
+
+    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
+    :param dht: a running hivemind.DHT daemon connected to other peers
+    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
+    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
+    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
+    :param target_group_size: maximum group size for DecentralizedAverager's all-reduce
+    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
+    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
+    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
+    :param expected_drift_peers: assume that this many new peers can join between steps
+    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between steps
+    :note: the expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
+      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
+    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
+    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
+    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
+    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param scheduler: if specified, use this scheduler to update optimizer learning rate
+    :note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
+      explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
+    """
+
+    def __init__(self, opt: torch.optim.Optimizer, *, dht: DHT, prefix: str, target_batch_size: int,
+                 batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
+                 min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
+                 expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
+                 metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
+                 **kwargs):
+        super().__init__(opt, dht)
+        self.prefix, self.scheduler = prefix, scheduler
+        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
+        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
+            min_refresh_period, max_refresh_period, default_refresh_period
+        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.averager = self._make_averager(**kwargs)
+
+        self.training_progress_key = f"{self.prefix}_progress"
+        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
+        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
+        self.last_step_time = None
+
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
+        self.lock_local_progress, self.should_report_progress = Lock(), Event()
+        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
+        self.progress_reporter.start()
+        self.collaboration_state_updater = Thread(target=self.check_collaboration_state_periodically, daemon=True,
+                                                  name=f"{self}.collaboration_state_updater")
+        self.collaboration_state_updater.start()
+
+    def _make_averager(self, **kwargs):
+        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=True,
+                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout, **kwargs)
+
+    @property
+    def local_step(self) -> int:
+        return self.averager.local_step
+
+    @property
+    def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    def is_alive(self) -> bool:
+        return self.averager.is_alive()
+
+    def load_state_from_peers(self, **kwargs):
+        """ Attempt to fetch the newest collaboration state from other peers """
+        with self.lock_collaboration_state:
+            self.averager.load_state_from_peers(**kwargs)
+            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.update_scheduler()
+            self.opt.zero_grad()
+
+    def step(self, batch_size: Optional[int] = None, **kwargs):
+        """
+        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
+
+        :param batch_size: optional override for batch_size_per_step from init
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if batch_size is None and self.batch_size_per_step is None:
+            raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
+        batch_size = self.batch_size_per_step if batch_size is None else batch_size
+
+        if not self.is_synchronized:
+            self.load_state_from_peers()
+            return
+
+        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
+            logger.warning(f"Training step took {get_dht_time() - self.last_step_time} seconds, "
+                           f"but peer metadata expires after {self.metadata_expiration} s.")
+
+        with self.lock_local_progress:
+            self.local_samples_accumulated += batch_size
+            self.local_steps_accumulated += 1
+            self.performance_ema.update(num_processed=batch_size)
+            self.should_report_progress.set()
+
+        if not self.collaboration_state.ready_for_step:
+            return
+
+        logger.log(self.status_loglevel, "Averaging parameters and gradients with peers...")
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state_updated.set()
+
+        if not self.is_synchronized:
+            self.load_state_from_peers()
+            return
+
+        with self.performance_ema.pause(), self.lock_collaboration_state:
+            if self.collaboration_state.num_peers > 1:
+                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
+                weight = self.local_samples_accumulated / mean_samples_per_worker
+                output = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
+                                                 f"{self.collaboration_state.num_peers} peer(s).")
+                output = None
+                self.averager.local_step += 1
+
+            self.opt.step()
+            self.opt.zero_grad()
+            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.collaboration_state.register_step()
+            self.collaboration_state_updated.set()
+            self.update_scheduler()
+
+            logger.log(self.status_loglevel, f"Optimizer step: done!")
+            return output
+
+    def report_training_progress(self):
+        """ Periodically publish metadata and the current number of samples accumulated towards the next step """
+        while self.is_alive():
+            self.should_report_progress.wait()
+            self.should_report_progress.clear()
+            with self.lock_local_progress:
+                current_time = get_dht_time()
+                local_state_info = [self.local_step, self.local_samples_accumulated,
+                                    self.performance_ema.samples_per_second, current_time, not self.averager.listen]
+
+            assert self.is_valid_peer_state(local_state_info), local_state_info
+            self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=local_state_info,
+                           expiration_time=current_time + self.metadata_expiration, return_future=True)
+
+    def check_collaboration_state_periodically(self):
+        """
+        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
+        """
+        while self.is_alive():
+            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
+            if self.collaboration_state_updated.wait(time_to_next_update):
+                self.collaboration_state_updated.clear()
+                continue  # if state was updated externally, reset timer
+
+            with self.lock_collaboration_state:
+                self.collaboration_state = self.fetch_collaboration_state()
+
+    def fetch_collaboration_state(self) -> CollaborationState:
+        """ Read performance statistics reported by peers, estimate progress towards next batch """
+        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float('inf'))
+        current_time = get_dht_time()
+
+        if not isinstance(response, dict) or len(response) == 0:
+            logger.log(self.status_loglevel, f"Found no active peers: {response}")
+            local_eta_next_step = max(0, self.target_batch_size - self.local_samples_accumulated
+                                      ) / self.performance_ema.samples_per_second
+            return CollaborationState(self.local_step, self.local_samples_accumulated, self.target_batch_size,
+                                      num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
+                                      next_fetch_time=current_time + self.default_refresh_period)
+
+        valid_peer_states = [peer_state.value for peer_state in response.values()
+                             if isinstance(peer_state, ValueWithExpiration)
+                             and self.is_valid_peer_state(peer_state.value)]
+
+        num_peers = len(valid_peer_states)
+        num_clients = sum(is_client for *_, is_client in valid_peer_states)
+        global_optimizer_step = self.local_step
+        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
+            if not is_client:
+                global_optimizer_step = max(global_optimizer_step, opt_step)
+
+        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
+
+        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
+            total_samples_per_second += samples_per_second
+            if opt_step == global_optimizer_step:
+                total_samples_accumulated += samples_accumulated
+                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
+            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
+
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
+        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
+
+        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
+        time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
+                                           a_min=self.min_refresh_period, a_max=self.max_refresh_period))
+        logger.log(self.status_loglevel, f"Collaboration accumulated {total_samples_accumulated} samples from "
+                                         f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
+                                         f"(refresh in {time_to_next_fetch:.2f}s.)")
+        return CollaborationState(
+            global_optimizer_step, total_samples_accumulated, target_batch_size=self.target_batch_size,
+            num_peers=num_peers, num_clients=num_clients, eta_next_step=current_time + estimated_time_to_next_step,
+            next_fetch_time=current_time + time_to_next_fetch)
+
+    def zero_grad(self, *args, **kwargs):
+        warnings.warn("CollaborativeOptimizer.zero_grad is a no-op and doesn't need to be called")
+
+    @staticmethod
+    def is_valid_peer_state(state):
+        return isinstance(state, (list, tuple)) and len(state) == 5 \
+               and all(map(isinstance, state, (int, int, float, float, bool)))
+
+    def update_scheduler(self):
+        if self.scheduler:
+            while self.scheduler._step_count < self.local_step:
+                self.scheduler.step()
+
+    def shutdown(self):
+        logger.debug("Shutting down averager...")
+        self.averager.shutdown()
+        logger.debug("Sending goodbye to peers...")
+        self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=None,
+                       expiration_time=get_dht_time() + self.metadata_expiration)
+        logger.debug(f"{self.__class__.__name__} is shut down.")
+
+    def __del__(self):
+        self.shutdown()
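
To make the refresh-interval arithmetic in fetch_collaboration_state concrete, here is the same computation on made-up numbers (all values illustrative):

import numpy as np

target_batch_size = 4096
estimated_current_samples = 3000      # samples peers are estimated to have accumulated by now
total_samples_per_second = 200.0      # summed throughput reported by all peers
num_peers, expected_drift_peers, expected_drift_rate = 8, 3, 0.2
min_refresh_period, max_refresh_period = 0.5, 30

eta_seconds = max(0, target_batch_size - estimated_current_samples) / total_samples_per_second     # ~5.5 s
expected_max_peers = max(num_peers + expected_drift_peers, num_peers * (1 + expected_drift_rate))  # 11
time_to_next_fetch = float(np.clip(eta_seconds * num_peers / expected_max_peers,
                                   a_min=min_refresh_period, a_max=max_refresh_period))            # ~4.0 s
# the ETA is shrunk by num_peers / expected_max_peers so that a peer does not miss the step
# if new peers join and the collective batch fills up faster than currently estimated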

+ 40 - 0
hivemind/client/optim/performance_ema.py

@@ -0,0 +1,40 @@
+from contextlib import contextmanager
+
+from hivemind.utils import get_dht_time
+
+
+class PerformanceEMA:
+    """
+    A running estimate of performance (operations/sec) using adjusted exponential moving average
+    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
+    """
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
+        self.alpha, self.eps, self.num_updates = alpha, eps, 0
+        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.timestamp = get_dht_time()
+        self.paused = False
+
+    def update(self, num_processed: int) -> float:
+        """
+        :param num_processed: how many items were processed since last call
+        :returns: current estimate of performance (samples per second), but at most 1 / eps
+        """
+        assert not self.paused, "PerformanceEMA is currently paused"
+        assert num_processed > 0, f"Can't register processing {num_processed} samples"
+        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
+        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
+        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
+        self.num_updates += 1
+        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
+        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
+        return self.samples_per_second
+
+    @contextmanager
+    def pause(self):
+        """ While inside this context, EMA will not count the time passed towards the performance estimate """
+        self.paused, was_paused = True, self.paused
+        try:
+            yield
+        finally:
+            self.timestamp = get_dht_time()
+            self.paused = was_paused
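
A small self-contained example of how PerformanceEMA is meant to be driven (timings are arbitrary):

import time
from hivemind.client.optim.performance_ema import PerformanceEMA

ema = PerformanceEMA(alpha=0.1)
for _ in range(10):
    time.sleep(0.05)                           # pretend each batch takes ~50 ms to process
    throughput = ema.update(num_processed=32)  # 32 samples processed in this step

print(f"~{throughput:.0f} samples/s")          # debiased estimate, roughly 32 / 0.05 = 640

with ema.pause():                              # time spent here (e.g. all-reduce) is not counted
    time.sleep(0.5)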