
Protect training progress and metrics with signatures and DHT schema validation (#250)

This PR implements:

1. Validating the DHT schema for training progress and metrics
2. Signing the training progress and the metrics (so one peer cannot tamper with another peer's stats)
3. Using the local public key in place of the peer's UUID
4. Using type-validated data structures for the stats in the code instead of plain lists
5. Refactoring the validator system
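
For reference, a minimal usage sketch of how these pieces fit together (based on the metrics_utils helper added in this PR; the experiment prefix, metric values, and expiration time below are illustrative placeholders, not part of the diff):

import hivemind
from metrics_utils import LocalMetrics, make_validators

experiment_prefix = "albert_experiment"  # hypothetical prefix

# make_validators() returns a SchemaValidator (enforcing the LocalMetrics schema under
# "<prefix>_metrics") and an RSASignatureValidator, together with this peer's public key.
validators, local_public_key = make_validators(experiment_prefix)
dht = hivemind.DHT(start=True, record_validators=validators)

# The local public key replaces the old trainer UUID as the subkey, so only the holder
# of the matching private key can overwrite this peer's entry.
metrics = LocalMetrics(step=0, samples_per_second=0.0, samples_accumulated=0,
                       loss=0.0, mini_steps=0)
dht.store(key=experiment_prefix + "_metrics", subkey=local_public_key,
          value=metrics.dict(),
          expiration_time=hivemind.get_dht_time() + 600)

Records stored this way are rejected by other peers unless they both match the schema and carry a valid signature for the public key in the subkey.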
Aleksandr Borzunov 4 years ago
parent
commit
3bde6188fe

+ 25 - 0
examples/albert/metrics_utils.py

@@ -0,0 +1,25 @@
+from typing import Dict, List, Tuple
+
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
+from hivemind.dht.validation import RecordValidatorBase
+from pydantic import BaseModel, StrictFloat, confloat, conint
+
+
+class LocalMetrics(BaseModel):
+    step: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    loss: StrictFloat
+    mini_steps: conint(ge=0, strict=True)
+
+
+class MetricSchema(BaseModel):
+    metrics: Dict[BytesWithPublicKey, LocalMetrics]
+
+
+def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+    signature_validator = RSASignatureValidator()
+    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
+                  signature_validator]
+    return validators, signature_validator.local_public_key

+ 17 - 12
examples/albert/run_first_peer.py

@@ -1,13 +1,15 @@
 #!/usr/bin/env python
 
-import time
 import argparse
-import wandb
-from whatsmyip.providers import GoogleDnsProvider
-from whatsmyip.ip import get_ip
+import time
 
 import hivemind
+import wandb
 from hivemind.utils.logging import get_logger
+from whatsmyip.ip import get_ip
+from whatsmyip.providers import GoogleDnsProvider
+
+import metrics_utils
 
 
 logger = get_logger(__name__)
@@ -32,7 +34,9 @@ if __name__ == '__main__':
         logger.warning("No address specified. Attempting to infer address from DNS.")
         args.address = get_ip(GoogleDnsProvider)
 
-    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*")
+    validators, local_public_key = metrics_utils.make_validators(args.experiment_prefix)
+    dht = hivemind.DHT(start=True, listen_on=args.listen_on, endpoint=f"{args.address}:*",
+                       record_validators=validators)
     logger.info(f"Running DHT root at {args.address}:{dht.port}")
 
     wandb.init(project=args.wandb_project)
@@ -42,8 +46,9 @@ if __name__ == '__main__':
         metrics_dict = dht.get(args.experiment_prefix + '_metrics', latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
-            metrics = [metrics_dict[peer].value for peer in metrics_dict]
-            latest_step = max(metrics)[0]
+            metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
+                       for peer in metrics_dict]
+            latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
                 current_step = latest_step
                 alive_peers = 0
@@ -52,12 +57,12 @@ if __name__ == '__main__':
                 num_samples = 0
                 sum_perf = 0
                 sum_mini_steps = 0
-                for step, perf, samples, loss, mini_steps in metrics:
-                    sum_loss += loss
+                for item in metrics:
+                    sum_loss += item.loss
                     alive_peers += 1
-                    sum_perf += perf
-                    num_samples += samples
-                    sum_mini_steps += mini_steps
+                    sum_perf += item.samples_per_second
+                    num_samples += item.samples_accumulated
+                    sum_mini_steps += item.mini_steps
                 wandb.log({
                     "loss": sum_loss / sum_mini_steps,
                     "alive peers": alive_peers,

+ 27 - 20
examples/albert/run_trainer.py

@@ -5,10 +5,11 @@ import os
 from dataclasses import dataclass, field, asdict
 from pathlib import Path
 from typing import Optional, Dict, Any, List
-import uuid
 
-from datasets import load_from_disk
+import hivemind
+import torch
 import transformers
+from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from transformers import (set_seed, HfArgumentParser, TrainingArguments,
                           DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
@@ -16,9 +17,9 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
 from torch_optimizer import Lamb
-import torch
 
-import hivemind
+import metrics_utils
+
 
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -35,7 +36,6 @@ class CollaborationArguments:
     averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
     target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
     client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
-    trainer_uuid: str = uuid.uuid4().hex  # this peer's name - used when publishing metadata to DHT, default = random
 
     # optional tweaks
     target_group_size: int = 256  # maximum group size for all-reduce
@@ -165,11 +165,12 @@ def get_optimizer_and_scheduler(training_args, model):
 
 class CollaborativeCallback(transformers.TrainerCallback):
     def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
-                 model: torch.nn.Module, trainer_uuid: str, statistics_expiration: float):
+                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
         super().__init__()
         self.model = model
         self.dht, self.collaborative_optimizer = dht, optimizer
-        self.trainer_uuid, self.statistics_expiration = trainer_uuid, statistics_expiration
+        self.local_public_key = local_public_key
+        self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
         self.previous_state = self.get_current_state()
         self.samples = 0
@@ -190,15 +191,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
 
-                statistics = [self.collaborative_optimizer.local_step,
-                              self.collaborative_optimizer.performance_ema.samples_per_second,
-                              self.samples,
-                              self.loss,
-                              self.steps]
+                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                statistics = metrics_utils.LocalMetrics(
+                    step=self.collaborative_optimizer.local_step,
+                    samples_per_second=samples_per_second,
+                    samples_accumulated=self.samples,
+                    loss=self.loss,
+                    mini_steps=self.steps)
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
-                               value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                               subkey=self.local_public_key, value=statistics.dict(),
+                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
@@ -270,13 +274,15 @@ def main():
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
+    validators, local_public_key = metrics_utils.make_validators(
+        collaboration_args_dict['experiment_prefix'])
     dht = hivemind.DHT(
-        initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'], listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), start=True)
+        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
+        listen=not collaboration_args_dict['client_mode'],
+        listen_on=collaboration_args_dict.pop('dht_listen_on'),
+        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
-    trainer_uuid = collaboration_args_dict.pop('trainer_uuid')
     statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
     adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
                                  - collaboration_args_dict.pop('batch_size_lead')
@@ -292,7 +298,7 @@ def main():
     class TrainerWithIndependentShuffling(Trainer):
         def get_train_dataloader(self) -> DataLoader:
             """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
-            torch.manual_seed(hash(trainer_uuid))
+            torch.manual_seed(hash(local_public_key))
             return super().get_train_dataloader()
 
     trainer = TrainerWithIndependentShuffling(
@@ -300,7 +306,8 @@ def main():
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
-        callbacks=[CollaborativeCallback(dht, collaborative_optimizer, model, trainer_uuid, statistics_expiration)]
+        callbacks=[CollaborativeCallback(
+            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
     )
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
     trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

+ 30 - 17
hivemind/dht/crypto.py

@@ -16,38 +16,50 @@ logger = get_logger(__name__)
 class RSASignatureValidator(RecordValidatorBase):
     """
     Introduces a notion of *protected records* whose key/subkey contains substring
-    "[owner:ssh-rsa ...]" (the format can be changed) with an RSA public key of the owner.
+    "[owner:ssh-rsa ...]" with an RSA public key of the owner.
 
     If this validator is used, changes to such records always must be signed with
     the corresponding private key (so only the owner can change them).
     """
 
-    def __init__(self,
-                 marker_format: bytes=b'[owner:_key_]',
-                 signature_format: bytes=b'[signature:_value_]'):
-        self._marker_re = re.compile(re.escape(marker_format).replace(b'_key_', rb'(.+?)'))
+    PUBLIC_KEY_FORMAT = b'[owner:_key_]'
+    SIGNATURE_FORMAT = b'[signature:_value_]'
 
-        self._signature_format = signature_format
-        self._signature_re = re.compile(re.escape(signature_format).replace(b'_value_', rb'(.+?)'))
+    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b'_key_', rb'(.+?)')
+    _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
+    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b'_value_', rb'(.+?)'))
 
-        self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+    _cached_private_key = None
+
+    def __init__(self, *, ignore_cached_key=False):
+        if self._cached_private_key is None or ignore_cached_key:
+            # Since generating a private key takes ~100 ms, we cache it for future validator
+            # instances in the same process (unless ignore_cached_key=True)
+            self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+            if not ignore_cached_key:
+                RSASignatureValidator._cached_private_key = self._private_key
+        else:
+            self._private_key = RSASignatureValidator._cached_private_key
 
         serialized_public_key = self._private_key.public_key().public_bytes(
             encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
-        self._ownership_marker = marker_format.replace(b'_key_', serialized_public_key)
+        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
+
+        self._init_signature_params()
 
+    def _init_signature_params(self) -> None:
         self._padding = padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                                     salt_length=padding.PSS.MAX_LENGTH)
         self._hash_algorithm = hashes.SHA256()
 
     @property
-    def ownership_marker(self) -> bytes:
-        return self._ownership_marker
+    def local_public_key(self) -> bytes:
+        return self._local_public_key
 
     def validate(self, record: DHTRecord) -> bool:
-        public_keys = self._marker_re.findall(record.key)
+        public_keys = self._PUBLIC_KEY_RE.findall(record.key)
         if record.subkey is not None:
-            public_keys += self._marker_re.findall(record.subkey)
+            public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
         if not public_keys:
             return True  # The record is not protected with a public key
 
@@ -56,7 +68,7 @@ class RSASignatureValidator(RecordValidatorBase):
             return False
         public_key = serialization.load_ssh_public_key(public_keys[0])
 
-        signatures = self._signature_re.findall(record.value)
+        signatures = self._SIGNATURE_RE.findall(record.value)
         if len(signatures) != 1:
             logger.debug(f"Record should have exactly one signature in {record}")
             return False
@@ -73,16 +85,16 @@ class RSASignatureValidator(RecordValidatorBase):
             return False
 
     def sign_value(self, record: DHTRecord) -> bytes:
-        if self._ownership_marker not in record.key and self._ownership_marker not in record.subkey:
+        if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
             return record.value
 
         signature = self._private_key.sign(self._serialize_record(record),
                                            self._padding, self._hash_algorithm)
         signature = base64.b64encode(signature)
-        return record.value + self._signature_format.replace(b'_value_', signature)
+        return record.value + self.SIGNATURE_FORMAT.replace(b'_value_', signature)
 
     def strip_value(self, record: DHTRecord) -> bytes:
-        return self._signature_re.sub(b'', record.value)
+        return self._SIGNATURE_RE.sub(b'', record.value)
 
     def _serialize_record(self, record: DHTRecord) -> bytes:
         return MSGPackSerializer.dumps(dataclasses.astuple(record))
@@ -113,3 +125,4 @@ class RSASignatureValidator(RecordValidatorBase):
     def __setstate__(self, state):
         self.__dict__.update(state)
         self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._init_signature_params()

+ 48 - 42
hivemind/dht/schema.py

@@ -1,9 +1,11 @@
 import binascii
 import re
-from typing import Any, Dict, Type
+from contextlib import contextmanager
+from typing import Any, Dict, Optional, Type
 
 import pydantic
 
+from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTKey
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
@@ -19,7 +21,8 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
+    def __init__(self, schema: pydantic.BaseModel, *,
+                 allow_extra_keys: bool=True, prefix: Optional[str]=None):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 
@@ -28,24 +31,32 @@ class SchemaValidator(RecordValidatorBase):
             ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
             See the validate() docstring for details.
 
+            The model will be patched to adjust it for the schema validation.
+
         :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
 
             If a SchemaValidator is merged with another SchemaValidator, this option applies to
             keys that are not defined in each of the schemas.
+
+        :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
         """
 
-        self._alias_to_name = {}
+        self._patch_schema(schema)
+        self._schemas = [schema]
 
+        self._key_id_to_field_name = {}
         for field in schema.__fields__.values():
-            field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
-            self._alias_to_name[field.alias] = field.name
+            raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
+            self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
+        self._allow_extra_keys = allow_extra_keys
 
-            # Because validate() interface provides one key at a time
+    @staticmethod
+    def _patch_schema(schema: pydantic.BaseModel):
+        # We set required=False because the validate() interface provides only one key at a time
+        for field in schema.__fields__.values():
             field.required = False
-        schema.Config.extra = pydantic.Extra.forbid
 
-        self._schemas = [schema]
-        self._allow_extra_keys = allow_extra_keys
+        schema.Config.extra = pydantic.Extra.forbid
 
     def validate(self, record: DHTRecord) -> bool:
         """
@@ -69,12 +80,18 @@ class SchemaValidator(RecordValidatorBase):
            .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
         """
 
+        if record.key not in self._key_id_to_field_name:
+            if not self._allow_extra_keys:
+                logger.debug(f"Record {record} has a key ID that is not defined in any of the "
+                             f"schemas (therefore, the raw key is unknown)")
+            return self._allow_extra_keys
+
         try:
             record = self._deserialize_record(record)
         except ValueError as e:
-            logger.warning(e)
+            logger.debug(e)
             return False
-        [key_alias] = list(record.keys())
+        [field_name] = list(record.keys())
 
         n_outside_schema = 0
         validation_errors = []
@@ -82,54 +99,33 @@ class SchemaValidator(RecordValidatorBase):
             try:
                 parsed_record = schema.parse_obj(record)
             except pydantic.ValidationError as e:
-                if self._is_failed_due_to_extra_field(e):
-                    n_outside_schema += 1
-                else:
+                if not self._is_failed_due_to_extra_field(e):
                     validation_errors.append(e)
                 continue
 
-            parsed_value = parsed_record.dict(by_alias=True)[key_alias]
-            if parsed_value != record[key_alias]:
+            parsed_value = parsed_record.dict(by_alias=True)[field_name]
+            if parsed_value != record[field_name]:
                 validation_errors.append(ValueError(
-                    f"Value {record[key_alias]} needed type conversions to match "
+                    f"The record {record} needed type conversions to match "
                     f"the schema: {parsed_value}. Type conversions are not allowed"))
             else:
                 return True
 
-        readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}
-
-        if n_outside_schema == len(self._schemas):
-            if not self._allow_extra_keys:
-                logger.warning(f"Record {readable_record} contains a field that "
-                               f"is not defined in each of the schemas")
-            return self._allow_extra_keys
-
-        logger.warning(
-            f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
+        logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
         return False
 
-    @staticmethod
-    def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
-        key_alias = SchemaValidator._key_id_to_str(record.key)
+    def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
+        field_name = self._key_id_to_field_name[record.key]
         deserialized_value = DHTProtocol.serializer.loads(record.value)
         if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
             deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
-            return {key_alias: {deserialized_subkey: deserialized_value}}
+            return {field_name: {deserialized_subkey: deserialized_value}}
         else:
             if isinstance(deserialized_value, dict):
                 raise ValueError(
                     f'Record {record} contains an improperly serialized dictionary (you must use '
                     f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
-            return {key_alias: deserialized_value}
-
-    @staticmethod
-    def _key_id_to_str(key_id: bytes) -> str:
-        """
-        Represent ``key_id`` as a ``str`` since Pydantic does not support field aliases
-        of type ``bytes``.
-        """
-
-        return binascii.hexlify(key_id).decode()
+            return {field_name: deserialized_value}
 
     @staticmethod
     def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
@@ -144,11 +140,18 @@ class SchemaValidator(RecordValidatorBase):
         if not isinstance(other, SchemaValidator):
             return False
 
-        self._alias_to_name.update(other._alias_to_name)
         self._schemas.extend(other._schemas)
+        self._key_id_to_field_name.update(other._key_id_to_field_name)
         self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
         return True
 
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+        # If unpickling happens in another process, the previous model modifications may be lost
+        for schema in self._schemas:
+            self._patch_schema(schema)
+
 
 def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
     """
@@ -170,3 +173,6 @@ def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
             return value
 
     return ConstrainedBytesWithRegex
+
+
+BytesWithPublicKey = conbytes(regex=b'.*' + RSASignatureValidator.PUBLIC_KEY_REGEX + b'.*')

+ 55 - 29
hivemind/optim/collaborative.py

@@ -1,17 +1,22 @@
 from __future__ import annotations
+
+import logging
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
-from typing import Optional, Iterator
-import logging
+from typing import Dict, Optional, Iterator
 
-import torch
 import numpy as np
+import torch
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
+from hivemind.client.averaging.training import TrainingAverager
 from hivemind.dht import DHT
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.client.averaging.training import TrainingAverager
-from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.utils import Endpoint, ValueWithExpiration, get_dht_time, get_logger
+
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -37,6 +42,19 @@ class CollaborationState:
         self.eta_next_step = float('inf')
 
 
+class TrainingState(BaseModel):
+    endpoint: Endpoint
+    step: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
+
+
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
     An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
@@ -87,6 +105,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
                  client_mode: bool = False, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
+
+        signature_validator = RSASignatureValidator()
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix),
+                            signature_validator])
+
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
@@ -263,12 +287,18 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.should_report_progress.clear()
             with self.lock_local_progress:
                 current_time = get_dht_time()
-                local_state_info = [self.local_step, self.local_samples_accumulated,
-                                    self.performance_ema.samples_per_second, current_time, not self.averager.listen]
-
-            assert self.is_valid_peer_state(local_state_info), local_state_info
-            self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=local_state_info,
-                           expiration_time=current_time + self.metadata_expiration, return_future=True)
+                local_state_info = TrainingState(
+                    endpoint=self.averager.endpoint,
+                    step=self.local_step,
+                    samples_accumulated=self.local_samples_accumulated,
+                    samples_per_second=self.performance_ema.samples_per_second,
+                    time=current_time,
+                    client_mode=not self.averager.listen)
+
+            self.dht.store(key=self.training_progress_key, subkey=self._local_public_key,
+                           value=local_state_info.dict(),
+                           expiration_time=current_time + self.metadata_expiration,
+                           return_future=True)
 
     def check_collaboration_state_periodically(self):
         """
@@ -296,24 +326,25 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                                       num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
                                       next_fetch_time=current_time + self.default_refresh_period)
 
-        valid_peer_states = [peer_state.value for peer_state in response.values()
-                             if isinstance(peer_state, ValueWithExpiration)
-                             and self.is_valid_peer_state(peer_state.value)]
+        valid_peer_states = [TrainingState.parse_obj(peer_state.value)
+                             for peer_state in response.values()
+                             if peer_state.value is not None]
 
         num_peers = len(valid_peer_states)
-        num_clients = sum(is_client for *_, is_client in valid_peer_states)
+        num_clients = sum(state.client_mode for state in valid_peer_states)
         global_optimizer_step = self.local_step
-        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
-            if not is_client:
-                global_optimizer_step = max(global_optimizer_step, opt_step)
+        for state in valid_peer_states:
+            if not state.client_mode:
+                global_optimizer_step = max(global_optimizer_step, state.step)
 
         total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
 
-        for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
-            total_samples_per_second += samples_per_second
-            if opt_step == global_optimizer_step:
-                total_samples_accumulated += samples_accumulated
-                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+        for state in valid_peer_states:
+            total_samples_per_second += state.samples_per_second
+            if state.step == global_optimizer_step:
+                total_samples_accumulated += state.samples_accumulated
+                estimated_current_samples += (state.samples_accumulated +
+                                              max(0, current_time - state.time) * state.samples_per_second)
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
@@ -337,11 +368,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                              f"call zero_grad manually. Gradients will be refreshed internally.")
         return self.opt.zero_grad(*args, **kwargs)
 
-    @staticmethod
-    def is_valid_peer_state(state):
-        return isinstance(state, (list, tuple)) and len(state) == 5 \
-               and all(map(isinstance, state, (int, int, float, float, bool)))
-
     def update_scheduler(self):
         if self.scheduler:
             while self.scheduler._step_count < self.local_step:
@@ -351,7 +377,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         logger.debug("Shutting down averager...")
         self.averager.shutdown()
         logger.debug("Sending goodbye to peers...")
-        self.dht.store(self.training_progress_key, subkey=self.averager.endpoint, value=None,
+        self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
                        expiration_time=get_dht_time() + self.metadata_expiration)
         logger.debug(f"{self.__class__.__name__} is shut down.")
 

+ 93 - 4
tests/test_dht_crypto.py

@@ -1,24 +1,28 @@
 import dataclasses
+import pickle
+import multiprocessing as mp
 
 import pytest
 
+import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.node import LOCALHOST
 from hivemind.dht.validation import DHTRecord
 
 
 def test_rsa_signature_validator():
     receiver_validator = RSASignatureValidator()
-    sender_validator = RSASignatureValidator()
-    mallory_validator = RSASignatureValidator()
+    sender_validator = RSASignatureValidator(ignore_cached_key=True)
+    mallory_validator = RSASignatureValidator(ignore_cached_key=True)
 
     plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                              expiration_time=get_dht_time() + 10)
     protected_records = [
         dataclasses.replace(plain_record,
-                            key=plain_record.key + sender_validator.ownership_marker),
+                            key=plain_record.key + sender_validator.local_public_key),
         dataclasses.replace(plain_record,
-                            subkey=plain_record.subkey + sender_validator.ownership_marker),
+                            subkey=plain_record.subkey + sender_validator.local_public_key),
     ]
 
     # test 1: Non-protected record (no signature added)
@@ -41,3 +45,88 @@ def test_rsa_signature_validator():
                        for record in protected_records]  # With someone else's signature
     for record in signed_records:
         assert not receiver_validator.validate(record)
+
+
+def test_cached_key():
+    first_validator = RSASignatureValidator()
+    second_validator = RSASignatureValidator()
+    assert first_validator.local_public_key == second_validator.local_public_key
+
+    third_validator = RSASignatureValidator(ignore_cached_key=True)
+    assert first_validator.local_public_key != third_validator.local_public_key
+
+
+def test_validator_instance_is_picklable():
+    # Needs to be picklable because the validator instance may be sent between processes
+
+    original_validator = RSASignatureValidator()
+    unpickled_validator = pickle.loads(pickle.dumps(original_validator))
+
+    # To check that the private key was pickled and unpickled correctly, we sign a record
+    # with the original public key using the unpickled validator and then validate the signature
+
+    record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
+                       value=b'value', expiration_time=get_dht_time() + 10)
+    signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
+
+    assert b'[signature:' in signed_record.value
+    assert original_validator.validate(signed_record)
+    assert unpickled_validator.validate(signed_record)
+
+
+def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
+    validator = conn.recv()
+    record = conn.recv()
+
+    record = dataclasses.replace(record, value=validator.sign_value(record))
+
+    conn.send(record)
+
+
+def test_signing_in_different_process():
+    parent_conn, child_conn = mp.Pipe()
+    process = mp.Process(target=get_signed_record, args=[child_conn])
+    process.start()
+
+    validator = RSASignatureValidator()
+    parent_conn.send(validator)
+
+    record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
+                       value=b'value', expiration_time=get_dht_time() + 10)
+    parent_conn.send(record)
+
+    signed_record = parent_conn.recv()
+    assert b'[signature:' in signed_record.value
+    assert validator.validate(signed_record)
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_signatures():
+    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
+    bob = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    mallory = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    key = b'key'
+    subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key
+
+    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
+
+    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
+
+    store_ok = await mallory.store(key, b'updated_fake_value',
+                                   hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

+ 0 - 31
tests/test_dht_node.py

@@ -10,7 +10,6 @@ import pytest
 
 import hivemind
 from hivemind import get_dht_time, replace_port
-from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
@@ -454,33 +453,3 @@ async def test_dhtnode_edge_cases():
         assert stored is not None
         assert subkey in stored.value
         assert stored.value[subkey].value == value
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_signatures():
-    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
-    bob = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
-    mallory = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
-
-    key = b'key'
-    subkey = b'protected_subkey' + bob.protocol.record_validator.ownership_marker
-
-    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
-
-    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
-
-    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
-
-    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
-
-    store_ok = await mallory.store(key, b'updated_fake_value',
-                                   hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

+ 79 - 43
tests/test_dht_schema.py

@@ -2,22 +2,24 @@ import re
 
 import pytest
 from pydantic import BaseModel, StrictFloat, StrictInt, conint
-from typing import Dict, List
+from typing import Dict
 
+import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
-from hivemind.dht.schema import SchemaValidator, conbytes
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 
 
+class SampleSchema(BaseModel):
+    experiment_name: bytes
+    n_batches: Dict[bytes, conint(ge=0, strict=True)]
+    signed_data: Dict[BytesWithPublicKey, bytes]
+
+
 @pytest.fixture
 async def dht_nodes_with_schema():
-    class Schema(BaseModel):
-        experiment_name: bytes
-        n_batches: Dict[bytes, conint(ge=0, strict=True)]
-        signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
-
-    validator = SchemaValidator(Schema)
+    validator = SchemaValidator(SampleSchema)
 
     alice = await DHTNode.create(record_validator=validator)
     bob = await DHTNode.create(
@@ -31,17 +33,17 @@ async def test_expecting_regular_value(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Regular value (bytes) expected
-    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', 666, get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10,
+    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
+    assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
                                subkey=b'subkey')
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store(b'experiment_name', [], get_dht_time() + 10)
-    assert not await bob.store(b'experiment_name', [1, 2, 3], get_dht_time() + 10)
+    assert not await bob.store('experiment_name', [], get_dht_time() + 10)
+    assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
 
 
 @pytest.mark.forked
@@ -50,27 +52,27 @@ async def test_expecting_dictionary(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Dictionary (bytes -> non-negative int) expected
-    assert await bob.store(b'n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
-    assert await bob.store(b'n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
-    assert not await bob.store(b'n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10)
-    assert not await bob.store(b'n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
-    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10, subkey=666)
+    assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
+    assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
+    assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', 666, get_dht_time() + 10)
+    assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
+    assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
 
     # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
-    assert not await bob.store(b'n_batches', {b'uid3': 779}, get_dht_time() + 10)
+    assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store(b'n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store(b'n_batches', [], get_dht_time() + 10)
-    assert not await bob.store(b'n_batches', [(b'uid3', 779)], get_dht_time() + 10)
+    assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store('n_batches', [], get_dht_time() + 10)
+    assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
 
     # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
-    assert not await bob.store(b'n_batches', '', get_dht_time() + 10)
+    assert not await bob.store('n_batches', '', get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get(b'n_batches', latest=True)).value
+        dictionary = (await peer.get('n_batches', latest=True)).value
         assert (len(dictionary) == 2 and
                 dictionary[b'uid1'].value == 777 and
                 dictionary[b'uid2'].value == 778)
@@ -83,13 +85,13 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
 
     # Subkeys expected to contain a public key
     # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
-    assert await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+    assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                            subkey=b'uid[owner:public-key]')
-    assert not await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+    assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                                subkey=b'uid-without-public-key')
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get(b'signed_data', latest=True)).value
+        dictionary = (await peer.get('signed_data', latest=True)).value
         assert (len(dictionary) == 1 and
                 dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
 
@@ -111,17 +113,38 @@ async def test_keys_outside_schema(dht_nodes_with_schema):
         bob = await DHTNode.create(
             record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
 
-        store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
+        store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
         assert store_ok == allow_extra_keys
 
         for peer in [alice, bob]:
-            result = await peer.get(b'unknown_key', latest=True)
+            result = await peer.get('unknown_key', latest=True)
             if allow_extra_keys:
                 assert result.value == b'foo_bar'
             else:
                 assert result is None
 
 
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_prefix():
+    class Schema(BaseModel):
+        field: StrictInt
+
+    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
+
+    alice = await DHTNode.create(record_validator=validator)
+    bob = await DHTNode.create(
+        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    assert await bob.store('prefix_field', 777, get_dht_time() + 10)
+    assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
+    assert not await bob.store('field', 777, get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get('prefix_field', latest=True)).value == 777
+        assert (await peer.get('field', latest=True)) is None
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_merging_schema_validators(dht_nodes_with_schema):
@@ -147,18 +170,31 @@ async def test_merging_schema_validators(dht_nodes_with_schema):
         for peer in [alice, bob]:
             assert peer.protocol.record_validator.merge_with(new_validator)
 
-    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert await bob.store(b'some_field', 777, get_dht_time() + 10)
-    assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
-    assert await bob.store(b'another_field', 42, get_dht_time() + 10)
-    assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert await bob.store('some_field', 777, get_dht_time() + 10)
+    assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store('another_field', 42, get_dht_time() + 10)
+    assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
 
-    # Unkown keys are allowed since the first schema is created with allow_extra_keys=True
-    assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)
+    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
+    assert await bob.store('unknown_key', 999, get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
-        assert (await peer.get(b'some_field', latest=True)).value == 777
-        assert (await peer.get(b'another_field', latest=True)).value == 'string_value'
+        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get('some_field', latest=True)).value == 777
+        assert (await peer.get('another_field', latest=True)).value == 'string_value'
+
+        assert (await peer.get('unknown_key', latest=True)).value == 999
+
+
+@pytest.mark.forked
+def test_sending_validator_instance_between_processes():
+    alice = hivemind.DHT(start=True)
+    bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    alice.add_validators([SchemaValidator(SampleSchema)])
+    bob.add_validators([SchemaValidator(SampleSchema)])
 
-        assert (await peer.get(b'unknown_key', latest=True)).value == 999
+    assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert not bob.store('experiment_name', 777, get_dht_time() + 10)
+    assert alice.get('experiment_name', latest=True).value == b'foo_bar'

+ 16 - 16
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ import hivemind
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
-from hivemind.dht.schema import SchemaValidator
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
 
 
@@ -18,7 +18,7 @@ class SchemaA(BaseModel):
 
 
 class SchemaB(BaseModel):
-    field_b: Dict[bytes, StrictInt]
+    field_b: Dict[BytesWithPublicKey, StrictInt]
 
 
 @pytest.fixture
@@ -40,9 +40,9 @@ def test_composite_validator(validators_for_app):
         [SchemaValidator, RSASignatureValidator])
     assert len(validator._validators[0]._schemas) == 2
 
-    public_key = validators_for_app['A'][0].ownership_marker
-    record = DHTRecord(key=DHTID.generate(source=b'field_b').to_bytes(),
-                       subkey=DHTProtocol.serializer.dumps(public_key),
+    local_public_key = validators_for_app['A'][0].local_public_key
+    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
+                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                        value=DHTProtocol.serializer.dumps(777),
                        expiration_time=hivemind.get_dht_time() + 10)
 
@@ -53,7 +53,7 @@ def test_composite_validator(validators_for_app):
     assert validator.validate(signed_record)
     assert validator.strip_value(signed_record) == record.value
 
-    record = DHTRecord(key=DHTID.generate(source=b'unknown_key').to_bytes(),
+    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                        subkey=DHTProtocol.IS_REGULAR_VALUE,
                        value=DHTProtocol.serializer.dumps(777),
                        expiration_time=hivemind.get_dht_time() + 10)
@@ -77,17 +77,17 @@ def test_dht_add_validators(validators_for_app):
     # After starting the process, other apps may add new validators to the existing DHT
     dht.add_validators(validators_for_app['B'])
 
-    assert dht.store(b'field_a', b'bytes_value', hivemind.get_dht_time() + 10)
-    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+    assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
+    assert dht.get('field_a', latest=True).value == b'bytes_value'
 
-    assert not dht.store(b'field_a', 666, hivemind.get_dht_time() + 10)
-    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+    assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
+    assert dht.get('field_a', latest=True).value == b'bytes_value'
 
-    public_key = validators_for_app['A'][0].ownership_marker
-    assert dht.store(b'field_b', 777, hivemind.get_dht_time() + 10, subkey=public_key)
-    dictionary = dht.get(b'field_b', latest=True).value
+    local_public_key = validators_for_app['A'][0].local_public_key
+    assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
+    dictionary = dht.get('field_b', latest=True).value
     assert (len(dictionary) == 1 and
-            dictionary[public_key].value == 777)
+            dictionary[local_public_key].value == 777)
 
-    assert not dht.store(b'unknown_key', 666, hivemind.get_dht_time() + 10)
-    assert dht.get(b'unknown_key', latest=True) is None
+    assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
+    assert dht.get('unknown_key', latest=True) is None