
Merge branch 'decentralized_lr_scheduler' of https://github.com/learning-at-home/hivemind into decentralized_lr_scheduler

xtinkt 4 years ago
parent
revision
10b3775914

+ 150 - 0
examples/albert/arguments.py

@@ -0,0 +1,150 @@
+from typing import Optional, List
+from dataclasses import dataclass, field
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class BaseTrainingArguments:
+    experiment_prefix: str = field(
+        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
+    )
+    initial_peers: List[str] = field(
+        default_factory=list,
+        metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"}
+    )
+    dht_listen_on: str = field(
+        default="[::]:*",
+        metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"}
+    )
+
+
+@dataclass
+class AveragerArguments:
+    averaging_expiration: float = field(
+        default=5.0,
+        metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
+    )
+    averaging_timeout: float = field(
+        default=30.0,
+        metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    listen_on: str = field(
+        default="[::]:*",
+        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"}
+    )
+    min_refresh_period: float = field(
+        default=0.5,
+        metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
+    )
+    max_refresh_period: float = field(
+        default=30,
+        metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
+    )
+    default_refresh_period: float = field(
+        default=3,
+        metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
+    )
+    expected_drift_peers: float = field(
+        default=3,
+        metadata={"help": "Trainer assumes that this many new peers can join per step"}
+    )
+    expected_drift_rate: float = field(
+        default=0.2,
+        metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
+    )
+    performance_ema_alpha: float = field(
+        default=0.1,
+        metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
+    )
+    target_group_size: int = field(
+        default=256,
+        metadata={"help": "Maximum group size for all-reduce"}
+    )
+    metadata_expiration: float = field(
+        default=30,
+        metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+    )
+
+
+@dataclass
+class CollaborativeOptimizerArguments:
+    target_batch_size: int = field(
+        default=4096,
+        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}
+    )
+    client_mode: bool = field(
+        default=False,
+        metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"}
+    )
+    batch_size_lead: int = field(
+        default=0,
+        metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"}
+    )
+    bandwidth: float = field(
+        default=100.0,
+        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"}
+    )
+    compression: str = field(
+        default="FLOAT16",
+        metadata={"help": "Use this compression when averaging parameters/gradients"}
+    )
+
+
+@dataclass
+class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
+    statistics_expiration: float = field(
+        default=600,
+        metadata={"help": "Statistics will be removed if not updated in this many seconds"}
+    )
+    endpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"}
+    )
+
+
+@dataclass
+class DatasetArguments:
+    dataset_path: Optional[str] = field(
+        default='data/albert_tokenized_wikitext',
+        metadata={"help": "Path to the tokenized dataset"}
+    )
+    tokenizer_path: Optional[str] = field(
+        default='data/tokenizer',
+        metadata={"help": "Path to the tokenizer"}
+    )
+    config_path: Optional[str] = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"}
+    )
+    cache_dir: Optional[str] = field(
+        default='data',
+        metadata={"help": "Path to the cache"}
+    )
+
+
+@dataclass
+class AlbertTrainingArguments(TrainingArguments):
+    dataloader_num_workers: int = 4
+    per_device_train_batch_size: int = 4
+    per_device_eval_batch_size: int = 4
+    gradient_accumulation_steps: int = 2
+    seq_length: int = 512
+
+    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
+    learning_rate: float = 0.00176
+    warmup_steps: int = 5000
+    adam_epsilon: float = 1e-6
+    weight_decay: float = 0.01
+    max_grad_norm: float = 1.0
+    clamp_value: float = 10000.0
+
+    fp16: bool = True
+    fp16_opt_level: str = 'O2'
+    do_train: bool = True
+
+    logging_steps: int = 100
+    save_total_limit: int = 2
+    save_steps: int = 500
+
+    output_dir: str = 'outputs'
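
For context, these dataclasses are consumed through transformers' HfArgumentParser, in the same way run_first_peer.py below parses its own argument classes. A minimal usage sketch (the exact class combination and CLI flags are illustrative assumptions, not taken from this diff):

from transformers import HfArgumentParser
from arguments import AlbertTrainingArguments, DatasetArguments, CollaborationArguments

# e.g. python run_trainer.py --experiment_prefix my_albert_run --initial_peers 1.2.3.4:1337
parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()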

+ 25 - 0
examples/albert/metrics_utils.py

@@ -0,0 +1,25 @@
+from typing import Dict, List, Tuple
+
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
+from hivemind.dht.validation import RecordValidatorBase
+from pydantic import BaseModel, StrictFloat, confloat, conint
+
+
+class LocalMetrics(BaseModel):
+    step: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    loss: StrictFloat
+    mini_steps: conint(ge=0, strict=True)
+
+
+class MetricSchema(BaseModel):
+    metrics: Dict[BytesWithPublicKey, LocalMetrics]
+
+
+def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+    signature_validator = RSASignatureValidator()
+    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
+                  signature_validator]
+    return validators, signature_validator.local_public_key
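
Both new scripts wire these validators into the DHT and publish LocalMetrics under the writer's own public-key subkey. A rough sketch assembled from the calls made in run_first_peer.py and run_trainer.py below (the experiment prefix and metric values are placeholders):

import hivemind
from metrics_utils import LocalMetrics, make_validators

validators, local_public_key = make_validators("my_experiment")
dht = hivemind.DHT(start=True, record_validators=validators)

# every peer writes to its own signed subkey; SchemaValidator rejects malformed metrics
metrics = LocalMetrics(step=0, samples_per_second=1.0, samples_accumulated=0, loss=0.0, mini_steps=0)
dht.store(key="my_experiment_metrics", subkey=local_public_key, value=metrics.dict(),
          expiration_time=hivemind.get_dht_time() + 600)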

+ 159 - 36
examples/albert/run_first_peer.py

@@ -1,49 +1,167 @@
 #!/usr/bin/env python
 
+from dataclasses import dataclass, field, asdict
+import subprocess
 import time
-import argparse
+from typing import Optional
+
+import torch
+from torch_optimizer import Lamb
+from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
 import wandb
 from whatsmyip.providers import GoogleDnsProvider
 from whatsmyip.ip import get_ip
 
+from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 import hivemind
 from hivemind.utils.logging import get_logger
+import metrics_utils
 
 
 logger = get_logger(__name__)
 
+
+@dataclass
+class CoordinatorArguments(BaseTrainingArguments):
+    """
+    Note: you might want to have several initial peers so that if one dies,
+    new workers can still join the collaboration via the remaining initial peers' addresses.
+    Specify the initial_peers argument for that purpose.
+    """
+    address: Optional[str] = field(
+        default=None,
+        metadata={"help": "This machine's network address. Use public IP for global experiments, "
+                          "local address for private runs"}
+    )
+    refresh_period: float = field(
+        default=30,
+        metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "Learning curves will be published there"}
+    )
+    save_checkpoint_step_interval: int = field(
+        default=5,
+        metadata={"help": "Coordinator will load and save state from peers once every that many steps"}
+    )
+    model_config_path: str = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"}
+    )
+    repo_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
+    )
+    upload_interval: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coordinator will upload model once in this many seconds"}
+    )
+
+
+class CheckpointHandler:
+    def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args: CollaborativeOptimizerArguments,
+                 averager_args: AveragerArguments, dht: hivemind.DHT):
+        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
+        self.repo_path = coordinator_args.repo_path
+        self.upload_interval = coordinator_args.upload_interval
+        self.previous_step = -1
+
+        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
+        self.model = AlbertForPreTraining(config)
+
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": 0.01,
+            },
+            {
+                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+
+        opt = Lamb(
+            optimizer_grouped_parameters,
+            lr=0.00176, weight_decay=0.01, clamp_value=10000.0, debias=True,
+        )
+
+        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
+
+        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
+            opt=opt, dht=dht, prefix=experiment_prefix,
+            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
+            throughput=collab_optimizer_args.bandwidth,
+            target_batch_size=adjusted_target_batch_size, client_mode=collab_optimizer_args.client_mode,
+            verbose=True, start=True, **asdict(averager_args)
+        )
+        self.previous_timestamp = time.time()
+
+    def is_time_to_save_state(self, cur_step):
+        if self.save_checkpoint_step_interval is None:
+            return False
+        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
+            return True
+        else:
+            return False
+
+    def save_state(self, cur_step):
+        self.collaborative_optimizer.load_state_from_peers()
+        self.previous_step = cur_step
+
+    def is_time_to_upload(self):
+        if self.repo_path is None:
+            return False
+        elif time.time() - self.previous_timestamp >= self.upload_interval:
+            return True
+        else:
+            return False
+
+    def upload_checkpoint(self, current_loss):
+        self.model.save_pretrained(self.repo_path)
+        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        self.previous_timestamp = time.time()
+        try:
+            subprocess.run("git add --all", shell=True, check=True, cwd=self.repo_path)
+            current_step = self.collaborative_optimizer.collaboration_state.optimizer_step
+            subprocess.run(f"git commit -m 'Step {current_step}, loss {current_loss:.3f}'",
+                           shell=True, check=True, cwd=self.repo_path)
+            subprocess.run("git push", shell=True, check=True, cwd=self.repo_path)
+        except subprocess.CalledProcessError as e:
+            logger.warning("Error while uploading model:", e.output)
+
+
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument('--address', type=str, required=False,
-                        help="this machine's network address. Use public IP for global experiments, "
-                             "local address for private runs.")
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
-    parser.add_argument('--refresh_period', type=float, default=30, required=False,
-                        help="coordinator will fetch keys from DHT once in this many seconds")
-    parser.add_argument('--experiment_prefix', type=str, required=True,
-                        help="a prefix where peers store their metrics for aggregation")
-    parser.add_argument('--wandb_project', type=str, required=True,
-                        help="Weights & Biases project name to publish learning curves")
-
-    args = parser.parse_args()
-    if args.address is None:
+    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
+    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+
+    if coordinator_args.address is None:
         logger.warning("No address specified. Attempting to infer address from DNS.")
-        args.address = get_ip(GoogleDnsProvider)
+        coordinator_args.address = get_ip(GoogleDnsProvider)
+
+    experiment_prefix = coordinator_args.experiment_prefix
+    validators, local_public_key = metrics_utils.make_validators(experiment_prefix)
+    dht = hivemind.DHT(start=True, listen_on=coordinator_args.dht_listen_on,
+                       endpoint=f"{coordinator_args.address}:*", initial_peers=coordinator_args.initial_peers,
+                       record_validators=validators)
 
-    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*")
-    logger.info(f"Running DHT root at {args.address}:{dht.port}")
+    logger.info(f"Running DHT root at {coordinator_args.address}:{dht.port}")
+
+    if coordinator_args.wandb_project is not None:
+        wandb.init(project=coordinator_args.wandb_project)
 
-    wandb.init(project=args.wandb_project)
     current_step = 0
 
+    checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
+
     while True:
-        metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
+        metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
-            metrics = [metrics_dict[peer].value for peer in metrics_dict]
-            latest_step = max(metrics)[0]
+            metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
+                       for peer in metrics_dict]
+            latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
                 current_step = latest_step
                 alive_peers = 0
@@ -52,18 +170,23 @@ if __name__ == '__main__':
                 num_samples = 0
                 sum_perf = 0
                 sum_mini_steps = 0
-                for step, perf, samples, loss, mini_steps in metrics:
-                    sum_loss += loss
+                for item in metrics:
+                    sum_loss += item.loss
                     alive_peers += 1
-                    sum_perf += perf
-                    num_samples += samples
-                    sum_mini_steps += mini_steps
-                wandb.log({
-                    "loss": sum_loss / sum_mini_steps,
-                    "alive peers": alive_peers,
-                    "samples": num_samples,
-                    "performance": sum_perf
-                })
+                    sum_perf += item.samples_per_second
+                    num_samples += item.samples_accumulated
+                    sum_mini_steps += item.mini_steps
+                if coordinator_args.wandb_project is not None:
+                    wandb.log({
+                        "loss": sum_loss / sum_mini_steps,
+                        "alive peers": alive_peers,
+                        "samples": num_samples,
+                        "performance": sum_perf
+                    })
+                if checkpoint_handler.is_time_to_save_state(current_step):
+                    checkpoint_handler.save_state(current_step)
+                    if checkpoint_handler.is_time_to_upload():
+                        checkpoint_handler.upload_checkpoint(sum_loss / sum_mini_steps)
                 logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
         logger.debug("Peer is still alive...")
-        time.sleep(args.refresh_period)
+        time.sleep(coordinator_args.refresh_period)

+ 34 - 90
examples/albert/run_trainer.py

@@ -2,13 +2,13 @@
 
 import logging
 import os
-from dataclasses import dataclass, field, asdict
+from dataclasses import asdict
 from pathlib import Path
-from typing import Optional, Dict, Any, List
-import uuid
+from typing import Dict, Any
 
-from datasets import load_from_disk
+import torch
 import transformers
+from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from transformers import (set_seed, HfArgumentParser, TrainingArguments,
                           DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
@@ -16,84 +16,16 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
 from torch_optimizer import Lamb
-import torch
 
 import hivemind
+from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments
+import metrics_utils
+
 
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
-@dataclass
-class CollaborationArguments:
-    """ define how peers interact with each other while training"""
-
-    # primary parameters
-    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
-    experiment_prefix: str  # a unique "name" of this experiment, used to store metadata on the DHT
-    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
-    averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
-    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
-    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
-    trainer_uuid: str = uuid.uuid4().hex  # this peer's name - used when publishing metadata to DHT, default = random
-
-    # optional tweaks
-    target_group_size: int = 256  # maximum group size for all-reduce
-    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
-    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
-    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
-    listen_on: str = '[::]:*'  # network interface used for incoming averager communication. Default: all ipv6
-    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running from behind a proxy
-    batch_size_lead: int = 0  # optional: begin looking for group in advance, this many samples before target_batch_size
-    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
-
-    min_refresh_period: float = 0.5  # wait for at least this many seconds before fetching new collaboration state
-    max_refresh_period: float = 30  # wait for at most this many seconds before fetching new collaboration state
-    default_refresh_period: float = 3  # attempt to fetch collaboration state every this often until successful
-    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
-    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of current size can join per step
-
-    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
-    performance_ema_alpha: float = 0.1  # uses this alpha for moving average estimate of samples per second
-
-
-@dataclass
-class DatasetArguments:
-    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
-                                        metadata={"help": "Path to the tokenized dataset"})
-    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
-                                          metadata={"help": "Path to the tokenizer"})
-    config_path: Optional[str] = field(
-        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
-        metadata={"help": "Path to the model config"})
-    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})
-
-
-@dataclass
-class AlbertTrainingArguments(TrainingArguments):
-    dataloader_num_workers: int = 4
-    per_device_train_batch_size: int = 4
-    per_device_eval_batch_size: int = 4
-    gradient_accumulation_steps: int = 2
-    seq_length: int = 512
-
-    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
-    learning_rate: float = 0.00176
-    warmup_steps: int = 5000
-    adam_epsilon: float = 1e-6
-    weight_decay: float = 0.01
-    max_grad_norm: float = 1.0
-    clamp_value: float = 10000.0
-
-    fp16: bool = True
-    fp16_opt_level: str = 'O2'
-    do_train: bool = True
-
-    logging_steps: int = 100
-    save_total_limit: int = 2
-    save_steps: int = 500
-
-
 def setup_logging(training_args):
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
@@ -165,17 +97,23 @@ def get_optimizer_and_scheduler(training_args, model):
 
 class CollaborativeCallback(transformers.TrainerCallback):
     def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
-                 model: torch.nn.Module, trainer_uuid: str, statistics_expiration: float):
+                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
         super().__init__()
         self.model = model
         self.dht, self.collaborative_optimizer = dht, optimizer
-        self.trainer_uuid, self.statistics_expiration = trainer_uuid, statistics_expiration
+        self.local_public_key = local_public_key
+        self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
         self.previous_state = self.get_current_state()
         self.samples = 0
         self.steps = 0
         self.loss = 0
 
+    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
+                       control: transformers.TrainerControl, **kwargs):
+        logger.warning('Loading state from peers')
+        self.collaborative_optimizer.load_state_from_peers()
+
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                     control: transformers.TrainerControl, **kwargs):
         control.should_log = True
@@ -190,15 +128,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
 
-                statistics = [self.collaborative_optimizer.local_step,
-                              self.collaborative_optimizer.performance_ema.samples_per_second,
-                              self.samples,
-                              self.loss,
-                              self.steps]
+                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                statistics = metrics_utils.LocalMetrics(
+                    step=self.collaborative_optimizer.local_step,
+                    samples_per_second=samples_per_second,
+                    samples_accumulated=self.samples,
+                    loss=self.loss,
+                    mini_steps=self.steps)
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
-                               value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                               subkey=self.local_public_key, value=statistics.dict(),
+                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
@@ -270,13 +211,15 @@ def main():
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
+    validators, local_public_key = metrics_utils.make_validators(
+        collaboration_args_dict['experiment_prefix'])
     dht = hivemind.DHT(
-        initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'], listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), start=True)
+        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
+        listen=not collaboration_args_dict['client_mode'],
+        listen_on=collaboration_args_dict.pop('dht_listen_on'),
+        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
-    trainer_uuid = collaboration_args_dict.pop('trainer_uuid')
     statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
     adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
                                  - collaboration_args_dict.pop('batch_size_lead')
@@ -292,7 +235,7 @@ def main():
     class TrainerWithIndependentShuffling(Trainer):
         def get_train_dataloader(self) -> DataLoader:
             """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
-            torch.manual_seed(hash(trainer_uuid))
+            torch.manual_seed(hash(local_public_key))
             return super().get_train_dataloader()
 
     trainer = TrainerWithIndependentShuffling(
@@ -300,7 +243,8 @@ def main():
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
-        callbacks=[CollaborativeCallback(dht, collaborative_optimizer, model, trainer_uuid, statistics_expiration)]
+        callbacks=[CollaborativeCallback(
+            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
     )
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
     trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

+ 47 - 33
hivemind/client/averaging/__init__.py

@@ -158,6 +158,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return f"{self.__class__.__name__}({self.endpoint})"
 
     def run(self):
+        """
+        Run the averager in a background thread; this is needed to avoid a heisenbug with broken OMP on fork.
+        It turns out that using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted.
+        Read more: https://github.com/pytorch/pytorch/issues/17199
+        """
+        thread = threading.Thread(target=self._run_internal, daemon=True)
+        thread.start()
+        thread.join()
+
+    def _run_internal(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
@@ -240,41 +250,45 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         start_time = get_dht_time()
         group_id = None
 
-        while not future.done():
-            try:
-                self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
-                if group_info is None:
-                    raise AllreduceException("Averaging step failed: could not find a group.")
-                group_id = group_info.group_id
-                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                self._running_groups[group_id] = allreduce_runner
-                self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                # averaging is finished, exit the loop
-                future.set_result(allreduce_runner.gathered)
-
-            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
-                time_elapsed = get_dht_time() - start_time
-                if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.exception(f"Averager caught {repr(e)}")
-                    future.set_exception(e)
-                else:
-                    logger.warning(f"Averager caught {repr(e)}, retrying")
+        try:
+            while not future.done():
+                try:
+                    self._pending_group_assembled.clear()
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    group_id = group_info.group_id
+                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                    self._running_groups[group_id] = allreduce_runner
+                    self._pending_group_assembled.set()
+                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+
+                    # averaging is finished, exit the loop
+                    future.set_result(allreduce_runner.gathered)
+
+                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                    time_elapsed = get_dht_time() - start_time
+                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                        logger.exception(f"Averager caught {repr(e)}")
+                        future.set_exception(e)
+                    else:
+                        logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except BaseException as e:
+                finally:
+                    _ = self._running_groups.pop(group_id, None)
+                    self._pending_group_assembled.set()
+
+        except BaseException as e:
+            if not future.done():
                 future.set_exception(e)
-                raise
-            finally:
-                _ = self._running_groups.pop(group_id, None)
-                self._pending_group_assembled.set()
-                if not future.done():
-                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
-                                                      " Please report this to hivemind issues."))
+            raise
+        finally:
+            if not future.done():
+                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
+                                                  " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """

+ 5 - 0
hivemind/client/averaging/matchmaking.py

@@ -10,6 +10,7 @@ import concurrent.futures
 import asyncio
 
 import grpc
+import grpc._cython.cygrpc
 
 from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
@@ -199,6 +200,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if call is not None:
                 call.cancel()
             return None
+        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            return None
+
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None

+ 4 - 1
hivemind/client/averaging/training.py

@@ -196,4 +196,7 @@ def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict,
         elif elem.get('type') == 'value' and 'value' in elem:
             flat_optimizer_state.append(elem['value'])
     with torch.no_grad():
-        return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
+        try:
+            return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
+        except StopIteration:
+            return optimizer

+ 30 - 17
hivemind/dht/crypto.py

@@ -16,38 +16,50 @@ logger = get_logger(__name__)
 class RSASignatureValidator(RecordValidatorBase):
     """
     Introduces a notion of *protected records* whose key/subkey contains substring
-    "[owner:ssh-rsa ...]" (the format can be changed) with an RSA public key of the owner.
+    "[owner:ssh-rsa ...]" with an RSA public key of the owner.
 
     If this validator is used, changes to such records always must be signed with
     the corresponding private key (so only the owner can change them).
     """
 
-    def __init__(self,
-                 marker_format: bytes=b'[owner:_key_]',
-                 signature_format: bytes=b'[signature:_value_]'):
-        self._marker_re = re.compile(re.escape(marker_format).replace(b'_key_', rb'(.+?)'))
+    PUBLIC_KEY_FORMAT = b'[owner:_key_]'
+    SIGNATURE_FORMAT = b'[signature:_value_]'
 
-        self._signature_format = signature_format
-        self._signature_re = re.compile(re.escape(signature_format).replace(b'_value_', rb'(.+?)'))
+    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b'_key_', rb'(.+?)')
+    _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
+    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b'_value_', rb'(.+?)'))
 
-        self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+    _cached_private_key = None
+
+    def __init__(self, *, ignore_cached_key=False):
+        if self._cached_private_key is None or ignore_cached_key:
+            # Since generating a private key takes ~100 ms, we cache it for future validator
+            # instances in the same process (unless ignore_cached_key=True)
+            self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+            if not ignore_cached_key:
+                RSASignatureValidator._cached_private_key = self._private_key
+        else:
+            self._private_key = RSASignatureValidator._cached_private_key
 
         serialized_public_key = self._private_key.public_key().public_bytes(
             encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
-        self._ownership_marker = marker_format.replace(b'_key_', serialized_public_key)
+        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
+
+        self._init_signature_params()
 
+    def _init_signature_params(self) -> None:
         self._padding = padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                                     salt_length=padding.PSS.MAX_LENGTH)
         self._hash_algorithm = hashes.SHA256()
 
     @property
-    def ownership_marker(self) -> bytes:
-        return self._ownership_marker
+    def local_public_key(self) -> bytes:
+        return self._local_public_key
 
     def validate(self, record: DHTRecord) -> bool:
-        public_keys = self._marker_re.findall(record.key)
+        public_keys = self._PUBLIC_KEY_RE.findall(record.key)
         if record.subkey is not None:
-            public_keys += self._marker_re.findall(record.subkey)
+            public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
         if not public_keys:
             return True  # The record is not protected with a public key
 
@@ -56,7 +68,7 @@ class RSASignatureValidator(RecordValidatorBase):
             return False
         public_key = serialization.load_ssh_public_key(public_keys[0])
 
-        signatures = self._signature_re.findall(record.value)
+        signatures = self._SIGNATURE_RE.findall(record.value)
         if len(signatures) != 1:
             logger.debug(f"Record should have exactly one signature in {record}")
             return False
@@ -73,16 +85,16 @@ class RSASignatureValidator(RecordValidatorBase):
             return False
 
     def sign_value(self, record: DHTRecord) -> bytes:
-        if self._ownership_marker not in record.key and self._ownership_marker not in record.subkey:
+        if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
             return record.value
 
         signature = self._private_key.sign(self._serialize_record(record),
                                            self._padding, self._hash_algorithm)
         signature = base64.b64encode(signature)
-        return record.value + self._signature_format.replace(b'_value_', signature)
+        return record.value + self.SIGNATURE_FORMAT.replace(b'_value_', signature)
 
     def strip_value(self, record: DHTRecord) -> bytes:
-        return self._signature_re.sub(b'', record.value)
+        return self._SIGNATURE_RE.sub(b'', record.value)
 
     def _serialize_record(self, record: DHTRecord) -> bytes:
         return MSGPackSerializer.dumps(dataclasses.astuple(record))
@@ -113,3 +125,4 @@ class RSASignatureValidator(RecordValidatorBase):
     def __setstate__(self, state):
         self.__dict__.update(state)
         self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._init_signature_params()

+ 48 - 42
hivemind/dht/schema.py

@@ -1,9 +1,11 @@
 import binascii
 import re
-from typing import Any, Dict, Type
+from contextlib import contextmanager
+from typing import Any, Dict, Optional, Type
 
 import pydantic
 
+from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTKey
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
@@ -19,7 +21,8 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
+    def __init__(self, schema: pydantic.BaseModel, *,
+                 allow_extra_keys: bool=True, prefix: Optional[str]=None):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 
@@ -28,24 +31,32 @@ class SchemaValidator(RecordValidatorBase):
             ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
             See the validate() docstring for details.
 
+            The model will be patched to adjust it for the schema validation.
+
         :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
 
             If a SchemaValidator is merged with another SchemaValidator, this option applies to
             keys that are not defined in each of the schemas.
+
+        :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
         """
 
-        self._alias_to_name = {}
+        self._patch_schema(schema)
+        self._schemas = [schema]
 
+        self._key_id_to_field_name = {}
         for field in schema.__fields__.values():
-            field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
-            self._alias_to_name[field.alias] = field.name
+            raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
+            self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
+        self._allow_extra_keys = allow_extra_keys
 
-            # Because validate() interface provides one key at a time
+    @staticmethod
+    def _patch_schema(schema: pydantic.BaseModel):
+        # We set required=False because the validate() interface provides only one key at a time
+        for field in schema.__fields__.values():
             field.required = False
-        schema.Config.extra = pydantic.Extra.forbid
 
-        self._schemas = [schema]
-        self._allow_extra_keys = allow_extra_keys
+        schema.Config.extra = pydantic.Extra.forbid
 
     def validate(self, record: DHTRecord) -> bool:
         """
@@ -69,12 +80,18 @@ class SchemaValidator(RecordValidatorBase):
            .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
         """
 
+        if record.key not in self._key_id_to_field_name:
+            if not self._allow_extra_keys:
+                logger.debug(f"Record {record} has a key ID that is not defined in any of the "
+                             f"schemas (therefore, the raw key is unknown)")
+            return self._allow_extra_keys
+
         try:
             record = self._deserialize_record(record)
         except ValueError as e:
-            logger.warning(e)
+            logger.debug(e)
             return False
-        [key_alias] = list(record.keys())
+        [field_name] = list(record.keys())
 
         n_outside_schema = 0
         validation_errors = []
@@ -82,54 +99,33 @@ class SchemaValidator(RecordValidatorBase):
             try:
                 parsed_record = schema.parse_obj(record)
             except pydantic.ValidationError as e:
-                if self._is_failed_due_to_extra_field(e):
-                    n_outside_schema += 1
-                else:
+                if not self._is_failed_due_to_extra_field(e):
                     validation_errors.append(e)
                 continue
 
-            parsed_value = parsed_record.dict(by_alias=True)[key_alias]
-            if parsed_value != record[key_alias]:
+            parsed_value = parsed_record.dict(by_alias=True)[field_name]
+            if parsed_value != record[field_name]:
                 validation_errors.append(ValueError(
-                    f"Value {record[key_alias]} needed type conversions to match "
+                    f"The record {record} needed type conversions to match "
                     f"the schema: {parsed_value}. Type conversions are not allowed"))
             else:
                 return True
 
-        readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}
-
-        if n_outside_schema == len(self._schemas):
-            if not self._allow_extra_keys:
-                logger.warning(f"Record {readable_record} contains a field that "
-                               f"is not defined in each of the schemas")
-            return self._allow_extra_keys
-
-        logger.warning(
-            f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
+        logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
         return False
 
-    @staticmethod
-    def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
-        key_alias = SchemaValidator._key_id_to_str(record.key)
+    def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
+        field_name = self._key_id_to_field_name[record.key]
         deserialized_value = DHTProtocol.serializer.loads(record.value)
         if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
             deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
-            return {key_alias: {deserialized_subkey: deserialized_value}}
+            return {field_name: {deserialized_subkey: deserialized_value}}
         else:
             if isinstance(deserialized_value, dict):
                 raise ValueError(
                     f'Record {record} contains an improperly serialized dictionary (you must use '
                     f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
-            return {key_alias: deserialized_value}
-
-    @staticmethod
-    def _key_id_to_str(key_id: bytes) -> str:
-        """
-        Represent ``key_id`` as a ``str`` since Pydantic does not support field aliases
-        of type ``bytes``.
-        """
-
-        return binascii.hexlify(key_id).decode()
+            return {field_name: deserialized_value}
 
     @staticmethod
     def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
@@ -144,11 +140,18 @@ class SchemaValidator(RecordValidatorBase):
         if not isinstance(other, SchemaValidator):
             return False
 
-        self._alias_to_name.update(other._alias_to_name)
         self._schemas.extend(other._schemas)
+        self._key_id_to_field_name.update(other._key_id_to_field_name)
         self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
         return True
 
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+        # If unpickling happens in another process, the previous model modifications may be lost
+        for schema in self._schemas:
+            self._patch_schema(schema)
+
 
 def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
     """
@@ -170,3 +173,6 @@ def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
             return value
 
     return ConstrainedBytesWithRegex
+
+
+BytesWithPublicKey = conbytes(regex=b'.*' + RSASignatureValidator.PUBLIC_KEY_REGEX + b'.*')

+ 59 - 29
hivemind/optim/collaborative.py

@@ -1,17 +1,22 @@
 from __future__ import annotations
+
+import logging
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
-from typing import Optional, Iterator
-import logging
+from typing import Dict, Optional, Iterator
 
-import torch
 import numpy as np
+import torch
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
+from hivemind.client.averaging.training import TrainingAverager
 from hivemind.dht import DHT
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.client.averaging.training import TrainingAverager
-from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.utils import Endpoint, ValueWithExpiration, get_dht_time, get_logger
+
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -37,6 +42,19 @@ class CollaborationState:
         self.eta_next_step = float('inf')
 
 
+class TrainingState(BaseModel):
+    endpoint: Endpoint
+    step: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
+
+
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
     An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
@@ -87,6 +105,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
                  client_mode: bool = False, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
+
+        signature_validator = RSASignatureValidator()
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix),
+                            signature_validator])
+
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
@@ -103,6 +127,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.samples_processed = 0
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -167,6 +192,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
+            self.samples_processed += batch_size
             self.performance_ema.update(num_processed=self.batch_size_per_step)
             self.should_report_progress.set()
 
@@ -209,6 +235,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.update_scheduler()
 
             logger.log(self.status_loglevel, f"Optimizer step: done!")
+            logger.info(f"Your current contribution: {self.samples_processed} samples")
+
             return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
@@ -263,12 +291,18 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.should_report_progress.clear()
             with self.lock_local_progress:
                 current_time = get_dht_time()
-                local_state_info = [self.local_step, self.local_samples_accumulated,
-                                    self.performance_ema.samples_per_second, current_time, not self.averager.listen]
-
-            assert self.is_valid_peer_state(local_state_info), local_state_info
-            self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=local_state_info,
-                           expiration_time=current_time + self.metadata_expiration, return_future=True)
+                local_state_info = TrainingState(
+                    endpoint=self.averager.endpoint,
+                    step=self.local_step,
+                    samples_accumulated=self.local_samples_accumulated,
+                    samples_per_second=self.performance_ema.samples_per_second,
+                    time=current_time,
+                    client_mode=not self.averager.listen)
+
+            self.dht.store(key=self.training_progress_key, subkey=self._local_public_key,
+                           value=local_state_info.dict(),
+                           expiration_time=current_time + self.metadata_expiration,
+                           return_future=True)
 
     def check_collaboration_state_periodically(self):
         """
@@ -296,24 +330,25 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                                       num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
                                       next_fetch_time=current_time + self.default_refresh_period)
 
-        valid_peer_states = [peer_state.value for peer_state in response.values()
-                             if isinstance(peer_state, ValueWithExpiration)
-                             and self.is_valid_peer_state(peer_state.value)]
+        valid_peer_states = [TrainingState.parse_obj(peer_state.value)
+                             for peer_state in response.values()
+                             if peer_state.value is not None]
 
         num_peers = len(valid_peer_states)
-        num_clients = sum(is_client for *_, is_client in valid_peer_states)
+        num_clients = sum(state.client_mode for state in valid_peer_states)
         global_optimizer_step = self.local_step
-        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
-            if not is_client:
-                global_optimizer_step = max(global_optimizer_step, opt_step)
+        for state in valid_peer_states:
+            if not state.client_mode:
+                global_optimizer_step = max(global_optimizer_step, state.step)
 
         total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
 
-        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
-            total_samples_per_second += samples_per_second
-            if opt_step == global_optimizer_step:
-                total_samples_accumulated += samples_accumulated
-                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+        for state in valid_peer_states:
+            total_samples_per_second += state.samples_per_second
+            if state.step == global_optimizer_step:
+                total_samples_accumulated += state.samples_accumulated
+                estimated_current_samples += (state.samples_accumulated +
+                                              max(0, current_time - state.time) * state.samples_per_second)
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
@@ -337,11 +372,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                              f"call zero_grad manually. Gradients will be refreshed internally.")
         return self.opt.zero_grad(*args, **kwargs)
 
-    @staticmethod
-    def is_valid_peer_state(state):
-        return isinstance(state, (list, tuple)) and len(state) == 5 \
-               and all(map(isinstance, state, (int, int, float, float, bool)))
-
     def update_scheduler(self):
         if self.scheduler:
             while self.scheduler._step_count < self.local_step:
@@ -351,7 +381,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         logger.debug("Shutting down averager...")
         self.averager.shutdown()
         logger.debug("Sending goodbye to peers...")
-        self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=None,
+        self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
                        expiration_time=get_dht_time() + self.metadata_expiration)
         logger.debug(f"{self.__class__.__name__} is shut down.")
 

+ 93 - 4
tests/test_dht_crypto.py

@@ -1,24 +1,28 @@
 import dataclasses
+import pickle
+import multiprocessing as mp
 
 import pytest
 
+import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.node import LOCALHOST
 from hivemind.dht.validation import DHTRecord
 
 
 def test_rsa_signature_validator():
     receiver_validator = RSASignatureValidator()
-    sender_validator = RSASignatureValidator()
-    mallory_validator = RSASignatureValidator()
+    sender_validator = RSASignatureValidator(ignore_cached_key=True)
+    mallory_validator = RSASignatureValidator(ignore_cached_key=True)
 
     plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                              expiration_time=get_dht_time() + 10)
     protected_records = [
         dataclasses.replace(plain_record,
-                            key=plain_record.key + sender_validator.ownership_marker),
+                            key=plain_record.key + sender_validator.local_public_key),
         dataclasses.replace(plain_record,
-                            subkey=plain_record.subkey + sender_validator.ownership_marker),
+                            subkey=plain_record.subkey + sender_validator.local_public_key),
     ]
 
     # test 1: Non-protected record (no signature added)
@@ -41,3 +45,88 @@ def test_rsa_signature_validator():
                        for record in protected_records]  # With someone else's signature
     for record in signed_records:
         assert not receiver_validator.validate(record)
+
+
+def test_cached_key():
+    first_validator = RSASignatureValidator()
+    second_validator = RSASignatureValidator()
+    assert first_validator.local_public_key == second_validator.local_public_key
+
+    third_validator = RSASignatureValidator(ignore_cached_key=True)
+    assert first_validator.local_public_key != third_validator.local_public_key
+
+
+def test_validator_instance_is_picklable():
+    # Needs to be picklable because the validator instance may be sent between processes
+
+    original_validator = RSASignatureValidator()
+    unpickled_validator = pickle.loads(pickle.dumps(original_validator))
+
+    # To check that the private key was pickled and unpickled correctly, we sign a record
+    # carrying the original public key with the unpickled validator, then validate the signature
+
+    record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
+                       value=b'value', expiration_time=get_dht_time() + 10)
+    signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
+
+    assert b'[signature:' in signed_record.value
+    assert original_validator.validate(signed_record)
+    assert unpickled_validator.validate(signed_record)
+
+
+def get_signed_record(conn: mp.connection.Connection) -> None:
+    validator = conn.recv()
+    record = conn.recv()
+
+    record = dataclasses.replace(record, value=validator.sign_value(record))
+
+    conn.send(record)
+
+
+def test_signing_in_different_process():
+    parent_conn, child_conn = mp.Pipe()
+    process = mp.Process(target=get_signed_record, args=[child_conn])
+    process.start()
+
+    validator = RSASignatureValidator()
+    parent_conn.send(validator)
+
+    record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
+                       value=b'value', expiration_time=get_dht_time() + 10)
+    parent_conn.send(record)
+
+    signed_record = parent_conn.recv()
+    assert b'[signature:' in signed_record.value
+    assert validator.validate(signed_record)
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_signatures():
+    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
+    bob = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    mallory = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    key = b'key'
+    subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key
+
+    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
+
+    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
+
+    store_ok = await mallory.store(key, b'updated_fake_value',
+                                   hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
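
A short usage sketch (not part of the diff) of the signing flow exercised by the tests above, using only calls that appear in them; it assumes hivemind is importable and that RSASignatureValidator reuses a process-wide cached key unless ignore_cached_key=True is passed.

import dataclasses

from hivemind.dht import get_dht_time
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.validation import DHTRecord

signer = RSASignatureValidator()                          # process-wide cached key pair
verifier = RSASignatureValidator(ignore_cached_key=True)  # independent key pair

# Appending the signer's public key to the subkey marks the record as owned,
# so validators will demand a matching signature in the value.
record = DHTRecord(key=b'key', subkey=b'subkey' + signer.local_public_key,
                   value=b'value', expiration_time=get_dht_time() + 10)
signed = dataclasses.replace(record, value=signer.sign_value(record))

assert b'[signature:' in signed.value  # the signature is appended to the serialized value
assert verifier.validate(signed)       # any peer verifies against the embedded public key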

+ 0 - 31
tests/test_dht_node.py

@@ -10,7 +10,6 @@ import pytest
 
 import hivemind
 from hivemind import get_dht_time, replace_port
-from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
@@ -454,33 +453,3 @@ async def test_dhtnode_edge_cases():
         assert stored is not None
         assert subkey in stored.value
         assert stored.value[subkey].value == value
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_signatures():
-    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
-    bob = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
-    mallory = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
-
-    key = b'key'
-    subkey = b'protected_subkey' + bob.protocol.record_validator.ownership_marker
-
-    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
-
-    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
-
-    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
-
-    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
-
-    store_ok = await mallory.store(key, b'updated_fake_value',
-                                   hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

+ 79 - 43
tests/test_dht_schema.py

@@ -2,22 +2,24 @@ import re
 
 import pytest
 from pydantic import BaseModel, StrictFloat, StrictInt, conint
-from typing import Dict, List
+from typing import Dict
 
+import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
-from hivemind.dht.schema import SchemaValidator, conbytes
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 
 
+class SampleSchema(BaseModel):
+    experiment_name: bytes
+    n_batches: Dict[bytes, conint(ge=0, strict=True)]
+    signed_data: Dict[BytesWithPublicKey, bytes]
+
+
 @pytest.fixture
 async def dht_nodes_with_schema():
-    class Schema(BaseModel):
-        experiment_name: bytes
-        n_batches: Dict[bytes, conint(ge=0, strict=True)]
-        signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
-
-    validator = SchemaValidator(Schema)
+    validator = SchemaValidator(SampleSchema)
 
     alice = await DHTNode.create(record_validator=validator)
     bob = await DHTNode.create(
@@ -31,17 +33,17 @@ async def test_expecting_regular_value(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Regular value (bytes) expected
-    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', 666, get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10,
+    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
+    assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
                                subkey=b'subkey')
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store(b'experiment_name', [], get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', [1, 2, 3], get_dht_time() + 10)
+    assert not await bob.store('experiment_name', [], get_dht_time() + 10)
+    assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
 
 
 @pytest.mark.forked
@@ -50,27 +52,27 @@ async def test_expecting_dictionary(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Dictionary (bytes -> non-negative int) expected
-    assert await bob.store(b'n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
-    assert await bob.store(b'n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
-    assert not await bob.store(b'n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10)
-    assert not await bob.store(b'n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
-    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10, subkey=666)
+    assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
+    assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
+    assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', 666, get_dht_time() + 10)
+    assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
+    assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
 
     # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
-    assert not await bob.store(b'n_batches', {b'uid3': 779}, get_dht_time() + 10)
+    assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store(b'n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', [], get_dht_time() + 10)
-    assert not await bob.store(b'n_batches', [(b'uid3', 779)], get_dht_time() + 10)
+    assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', [], get_dht_time() + 10)
+    assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
 
     # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
-    assert not await bob.store(b'n_batches', '', get_dht_time() + 10)
+    assert not await bob.store('n_batches', '', get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get(b'n_batches', latest=True)).value
+        dictionary = (await peer.get('n_batches', latest=True)).value
         assert (len(dictionary) == 2 and
                 dictionary[b'uid1'].value == 777 and
                 dictionary[b'uid2'].value == 778)
@@ -83,13 +85,13 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
 
     # Subkeys expected to contain a public key
     # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
-    assert await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+    assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                            subkey=b'uid[owner:public-key]')
-    assert not await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+    assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                                subkey=b'uid-without-public-key')
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get(b'signed_data', latest=True)).value
+        dictionary = (await peer.get('signed_data', latest=True)).value
         assert (len(dictionary) == 1 and
                 dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
 
@@ -111,17 +113,38 @@ async def test_keys_outside_schema(dht_nodes_with_schema):
         bob = await DHTNode.create(
             record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
 
-        store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
+        store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
         assert store_ok == allow_extra_keys
 
         for peer in [alice, bob]:
-            result = await peer.get(b'unknown_key', latest=True)
+            result = await peer.get('unknown_key', latest=True)
             if allow_extra_keys:
                 assert result.value == b'foo_bar'
             else:
                 assert result is None
 
 
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_prefix():
+    class Schema(BaseModel):
+        field: StrictInt
+
+    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
+
+    alice = await DHTNode.create(record_validator=validator)
+    bob = await DHTNode.create(
+        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    assert await bob.store('prefix_field', 777, get_dht_time() + 10)
+    assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
+    assert not await bob.store('field', 777, get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get('prefix_field', latest=True)).value == 777
+        assert (await peer.get('field', latest=True)) is None
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_merging_schema_validators(dht_nodes_with_schema):
@@ -147,18 +170,31 @@ async def test_merging_schema_validators(dht_nodes_with_schema):
         for peer in [alice, bob]:
             assert peer.protocol.record_validator.merge_with(new_validator)
 
-    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert await bob.store(b'some_field', 777, get_dht_time() + 10)
-    assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
-    assert await bob.store(b'another_field', 42, get_dht_time() + 10)
-    assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert await bob.store('some_field', 777, get_dht_time() + 10)
+    assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store('another_field', 42, get_dht_time() + 10)
+    assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
 
-    # Unkown keys are allowed since the first schema is created with allow_extra_keys=True
-    assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)
+    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
+    assert await bob.store('unknown_key', 999, get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
-        assert (await peer.get(b'some_field', latest=True)).value == 777
-        assert (await peer.get(b'another_field', latest=True)).value == 'string_value'
+        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get('some_field', latest=True)).value == 777
+        assert (await peer.get('another_field', latest=True)).value == 'string_value'
+
+        assert (await peer.get('unknown_key', latest=True)).value == 999
+
+
+@pytest.mark.forked
+def test_sending_validator_instance_between_processes():
+    alice = hivemind.DHT(start=True)
+    bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    alice.add_validators([SchemaValidator(SampleSchema)])
+    bob.add_validators([SchemaValidator(SampleSchema)])
 
-        assert (await peer.get(b'unknown_key', latest=True)).value == 999
+    assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert not bob.store('experiment_name', 777, get_dht_time() + 10)
+    assert alice.get('experiment_name', latest=True).value == b'foo_bar'
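
A brief sketch (not part of the diff) of the prefix option exercised by test_prefix above: the validator guards DHT keys of the form '<prefix>_<field>', so one pydantic schema can be reused across experiments. The schema and the 'my_experiment' prefix below are illustrative, not taken from the diff.

import hivemind
from pydantic import BaseModel, StrictInt

from hivemind.dht import get_dht_time
from hivemind.dht.schema import SchemaValidator

class LocalMetrics(BaseModel):  # illustrative schema
    step: StrictInt

dht = hivemind.DHT(start=True)
dht.add_validators([SchemaValidator(LocalMetrics, allow_extra_keys=False, prefix='my_experiment')])

assert dht.store('my_experiment_step', 42, get_dht_time() + 10)          # matches the schema
assert not dht.store('my_experiment_step', 'oops', get_dht_time() + 10)  # wrong type, rejected
assert not dht.store('step', 42, get_dht_time() + 10)                    # unprefixed key, rejected
assert dht.get('my_experiment_step', latest=True).value == 42
dht.shutdown()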

+ 16 - 16
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ import hivemind
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
-from hivemind.dht.schema import SchemaValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
 
 
@@ -18,7 +18,7 @@ class SchemaA(BaseModel):
 
 
 class SchemaB(BaseModel):
-    field_b: Dict[bytes, StrictInt]
+    field_b: Dict[BytesWithPublicKey, StrictInt]
 
 
 @pytest.fixture
@@ -40,9 +40,9 @@ def test_composite_validator(validators_for_app):
         [SchemaValidator, RSASignatureValidator])
     assert len(validator._validators[0]._schemas) == 2
 
-    public_key = validators_for_app['A'][0].ownership_marker
-    record = DHTRecord(key=DHTID.generate(source=b'field_b').to_bytes(),
-                       subkey=DHTProtocol.serializer.dumps(public_key),
+    local_public_key = validators_for_app['A'][0].local_public_key
+    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
+                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                        value=DHTProtocol.serializer.dumps(777),
                        expiration_time=hivemind.get_dht_time() + 10)
 
@@ -53,7 +53,7 @@ def test_composite_validator(validators_for_app):
     assert validator.validate(signed_record)
     assert validator.strip_value(signed_record) == record.value
 
-    record = DHTRecord(key=DHTID.generate(source=b'unknown_key').to_bytes(),
+    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                        subkey=DHTProtocol.IS_REGULAR_VALUE,
                        value=DHTProtocol.serializer.dumps(777),
                        expiration_time=hivemind.get_dht_time() + 10)
@@ -77,17 +77,17 @@ def test_dht_add_validators(validators_for_app):
     # After starting the process, other apps may add new validators to the existing DHT
     dht.add_validators(validators_for_app['B'])
 
-    assert dht.store(b'field_a', b'bytes_value', hivemind.get_dht_time() + 10)
-    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+    assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
+    assert dht.get('field_a', latest=True).value == b'bytes_value'
 
-    assert not dht.store(b'field_a', 666, hivemind.get_dht_time() + 10)
-    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+    assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
+    assert dht.get('field_a', latest=True).value == b'bytes_value'
 
-    public_key = validators_for_app['A'][0].ownership_marker
-    assert dht.store(b'field_b', 777, hivemind.get_dht_time() + 10, subkey=public_key)
-    dictionary = dht.get(b'field_b', latest=True).value
+    local_public_key = validators_for_app['A'][0].local_public_key
+    assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
+    dictionary = dht.get('field_b', latest=True).value
     assert (len(dictionary) == 1 and
-            dictionary[public_key].value == 777)
+            dictionary[local_public_key].value == 777)
 
-    assert not dht.store(b'unknown_key', 666, hivemind.get_dht_time() + 10)
-    assert dht.get(b'unknown_key', latest=True) is None
+    assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
+    assert dht.get('unknown_key', latest=True) is None
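
Finally, a combined sketch (not part of the diff) of the pattern these validation tests rely on: BytesWithPublicKey in a schema requires a dictionary's subkeys to contain a public key, and RSASignatureValidator then only accepts writes signed by the owner of that key. The schema and key names below are illustrative except for the classes and methods shown in the diff.

from typing import Dict

import hivemind
from pydantic import BaseModel, StrictInt

from hivemind.dht import get_dht_time
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator

class ProgressSchema(BaseModel):  # illustrative schema
    progress: Dict[BytesWithPublicKey, StrictInt]

signature_validator = RSASignatureValidator()
dht = hivemind.DHT(start=True)
dht.add_validators([SchemaValidator(ProgressSchema, allow_extra_keys=False), signature_validator])

# A write under our own public key is signed and accepted; a subkey without
# a public key fails the schema check before any signature is considered.
assert dht.store('progress', 777, get_dht_time() + 10, subkey=signature_validator.local_public_key)
assert not dht.store('progress', 777, get_dht_time() + 10, subkey=b'no_public_key_here')
assert dht.get('progress', latest=True).value[signature_validator.local_public_key].value == 777
dht.shutdown()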