浏览代码

Merge branch 'master' into decentralized_lr_scheduler

justheuristic 4 年之前
父节点
当前提交
13a1dd4e9d

+ 8 - 1
examples/albert/run_trainer.py

@@ -108,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.samples = 0
         self.steps = 0
         self.loss = 0
+        self.total_samples_processed = 0
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
@@ -127,7 +128,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
-
+                self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                 statistics = metrics_utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
@@ -135,12 +136,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps)
+                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                if self.steps:
+                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+
                 self.loss = 0
                 self.steps = 0
                 self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                                subkey=self.local_public_key, value=statistics.dict(),
                                expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
+
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
         return control

+ 27 - 27
hivemind/client/averaging/__init__.py

@@ -171,35 +171,34 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
-
-        async def _run():
-            grpc.aio.init_grpc_aio()
-
-            if self.listen:
-                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                found_port = server.add_insecure_port(self.listen_on)
-                assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                self._port.value = found_port
-                await server.start()
-            else:
-                logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                grpc.aio.init_grpc_aio()
+
+                if self.listen:
+                    server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                    found_port = server.add_insecure_port(self.listen_on)
+                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
+                    self._port.value = found_port
+                    await server.start()
+                else:
+                    logger.info(f"The averager running in an experimental client mode, please report any bugs.")
 
-            self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
-                                            client_mode=not self.listen)
-            if self.listen:
-                asyncio.create_task(self._declare_for_download_periodically())
+                self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
+                                                client_mode=not self.listen)
+                if self.listen:
+                    asyncio.create_task(self._declare_for_download_periodically())
 
-            self._pending_group_assembled = asyncio.Event()
-            self._pending_group_assembled.set()
-            self.ready.set()
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+                self.ready.set()
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(*args, **kwargs))
 
-        loop.run_until_complete(_run())
+            loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -255,7 +254,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout,
+                                                                        data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
                     group_id = group_info.group_id
@@ -294,7 +294,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
             weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints,  map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
 
             # compute optimal part sizes from peer throughputs
             incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]

+ 7 - 4
hivemind/client/moe.py

@@ -120,8 +120,11 @@ class RemoteMixtureOfExperts(nn.Module):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
+
+        device = grid_scores[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -133,11 +136,11 @@ class RemoteMixtureOfExperts(nn.Module):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
 

+ 7 - 4
hivemind/client/switch_moe.py

@@ -156,8 +156,11 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]
+
+        device = grid_probs[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -169,10 +172,10 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
         flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores

+ 15 - 17
hivemind/dht/__init__.py

@@ -69,25 +69,23 @@ class DHT(mp.Process):
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         loop = switch_to_uvloop()
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
 
-        async def _run():
-            node = await DHTNode.create(
-                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-                num_workers=self.max_workers or 1, record_validator=self._record_validator,
-                **self.kwargs)
-            if node.port is not None:
-                self._port.value = node.port
-            self.ready.set()
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                node = await DHTNode.create(
+                    initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                    num_workers=self.max_workers or 1, record_validator=self._record_validator,
+                    **self.kwargs)
+                if node.port is not None:
+                    self._port.value = node.port
+                self.ready.set()
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
-        try:
-            loop.run_until_complete(_run())
-        except KeyboardInterrupt:
-            logger.debug("Caught KeyboardInterrupt, shutting down")
+            coro = _run()
+            loop.run_until_complete(coro)
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -96,7 +94,7 @@ class DHT(mp.Process):
         """
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
         """ Shut down a running dht process """

+ 13 - 49
hivemind/dht/crypto.py

@@ -1,13 +1,10 @@
-import base64
 import dataclasses
 import re
-
-from cryptography.exceptions import InvalidSignature
-from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from typing import Optional
 
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 
 
 logger = get_logger(__name__)
@@ -31,26 +28,13 @@ class RSASignatureValidator(RecordValidatorBase):
 
     _cached_private_key = None
 
-    def __init__(self, *, ignore_cached_key=False):
-        if self._cached_private_key is None or ignore_cached_key:
-            # Since generating a private key takes ~100 ms, we cache it for future validator
-            # instances in the same process (unless ignore_cached_key=True)
-            self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
-            if not ignore_cached_key:
-                RSASignatureValidator._cached_private_key = self._private_key
-        else:
-            self._private_key = RSASignatureValidator._cached_private_key
-
-        serialized_public_key = self._private_key.public_key().public_bytes(
-            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
-        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
-
-        self._init_signature_params()
+    def __init__(self, private_key: Optional[RSAPrivateKey]=None):
+        if private_key is None:
+            private_key = RSAPrivateKey.process_wide()
+        self._private_key = private_key
 
-    def _init_signature_params(self) -> None:
-        self._padding = padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
-                                    salt_length=padding.PSS.MAX_LENGTH)
-        self._hash_algorithm = hashes.SHA256()
+        serialized_public_key = private_key.get_public_key().to_bytes()
+        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
 
     @property
     def local_public_key(self) -> bytes:
@@ -66,31 +50,25 @@ class RSASignatureValidator(RecordValidatorBase):
         if len(set(public_keys)) > 1:
             logger.debug(f"Key and subkey can't contain different public keys in {record}")
             return False
-        public_key = serialization.load_ssh_public_key(public_keys[0])
+        public_key = RSAPublicKey.from_bytes(public_keys[0])
 
         signatures = self._SIGNATURE_RE.findall(record.value)
         if len(signatures) != 1:
             logger.debug(f"Record should have exactly one signature in {record}")
             return False
-        signature = base64.b64decode(signatures[0])
+        signature = signatures[0]
 
         stripped_record = dataclasses.replace(record, value=self.strip_value(record))
-        try:
-            # verify() returns None iff the signature is correct
-            public_key.verify(signature, self._serialize_record(stripped_record),
-                              self._padding, self._hash_algorithm)
-            return True
-        except InvalidSignature:
+        if not public_key.verify(self._serialize_record(stripped_record), signature):
             logger.debug(f'Signature is invalid in {record}')
             return False
+        return True
 
     def sign_value(self, record: DHTRecord) -> bytes:
         if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
             return record.value
 
-        signature = self._private_key.sign(self._serialize_record(record),
-                                           self._padding, self._hash_algorithm)
-        signature = base64.b64encode(signature)
+        signature = self._private_key.sign(self._serialize_record(record))
         return record.value + self.SIGNATURE_FORMAT.replace(b'_value_', signature)
 
     def strip_value(self, record: DHTRecord) -> bytes:
@@ -112,17 +90,3 @@ class RSASignatureValidator(RecordValidatorBase):
         # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
         # instances of this class) and report successful merge
         return True
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        # Serializes the private key to make the class instances picklable
-        state['_private_key'] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption())
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
-        self._init_signature_params()

+ 7 - 2
hivemind/dht/protocol.py

@@ -13,6 +13,7 @@ from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
 
 logger = get_logger(__name__)
 
@@ -34,6 +35,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
             listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
             record_validator: Optional[RecordValidatorBase] = None,
+            authorizer: Optional[AuthorizerBase] = None,
             channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
@@ -54,11 +56,13 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.record_validator = record_validator
+        self.authorizer = authorizer
 
         if listen:  # set up server to process incoming rpc requests
             grpc.aio.init_grpc_aio()
             self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            dht_grpc.add_DHTServicer_to_server(self, self.server)
+            servicer = AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer)
+            dht_grpc.add_DHTServicer_to_server(servicer, self.server)
 
             self.port = self.server.add_insecure_port(listen_on)
             assert self.port != 0, f"Failed to listen to {listen_on}"
@@ -89,7 +93,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
     def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        stub = ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
 
     async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """

+ 3 - 1
hivemind/hivemind_cli/run_server.py

@@ -32,7 +32,9 @@ def main():
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all expert operations')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')

+ 0 - 3
hivemind/optim/collaborative.py

@@ -127,7 +127,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.samples_processed = 0
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -192,7 +191,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.samples_processed += batch_size
             self.performance_ema.update(num_processed=self.batch_size_per_step)
             self.should_report_progress.set()
 
@@ -235,7 +233,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.update_scheduler()
 
             logger.log(self.status_loglevel, f"Optimizer step: done!")
-            logger.info(f"Your current contribution: {self.samples_processed} samples")
 
             return group_info
 

+ 22 - 0
hivemind/proto/auth.proto

@@ -0,0 +1,22 @@
+syntax = "proto3";
+
+message AccessToken {
+    string username = 1;
+    bytes public_key = 2;
+    string expiration_time = 3;
+    bytes signature = 4;
+}
+
+message RequestAuthInfo {
+    AccessToken client_access_token = 1;
+    bytes service_public_key = 2;
+    double time = 3;
+    bytes nonce = 4;
+    bytes signature = 5;
+}
+
+message ResponseAuthInfo {
+    AccessToken service_access_token = 1;
+    bytes nonce = 2;
+    bytes signature = 3;
+}

+ 25 - 20
hivemind/proto/dht.proto

@@ -1,4 +1,5 @@
 syntax = "proto3";
+import "auth.proto";
 
 // this protocol defines how Hivemind nodes form a distributed hash table.
 // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
@@ -23,35 +24,40 @@ message NodeInfo {
 }
 
 message PingRequest {
-  NodeInfo peer = 1;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
-  bool validate = 2;                   // set to True if sender wants to validate that he is accessible and synchronized
+  RequestAuthInfo auth = 1;
+  NodeInfo peer = 2;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  bool validate = 3;                   // set to True if sender wants to validate that he is accessible and synchronized
 }
 
 message PingResponse {
-  NodeInfo peer = 1;                   // respondent's node id, for you to update routing table
-  string sender_endpoint = 2;          // echo sender's visible endpoint - used to infer his ip address
-  double dht_time = 3;                 // recipient's local DHT time - used to soft-synchronize peers
-  bool available = 4;                  // if validate = True, this flag asserts that the sender is available for ping
+  ResponseAuthInfo auth = 1;
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  string sender_endpoint = 3;          // echo sender's visible endpoint - used to infer his ip address
+  double dht_time = 4;                 // recipient's local DHT time - used to soft-synchronize peers
+  bool available = 5;                  // if validate = True, this flag asserts that the sender is available for ping
 }
 
 message StoreRequest {
+  RequestAuthInfo auth = 1;
   // three lists of the same length representing dht keys, dht values and expiration
-  repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes subkeys = 2;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
-  repeated bytes values = 3;           // serialized value for i-th key
-  repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated bytes keys = 2;             // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes subkeys = 3;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
+  repeated bytes values = 4;           // serialized value for i-th key
+  repeated double expiration_time = 5; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 6;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 7;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }
 
 message StoreResponse {
-  repeated bool store_ok = 1;          // for every key, True means store accepted, False means store rejected/failed
-  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  ResponseAuthInfo auth = 1;
+  repeated bool store_ok = 2;          // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 3;                   // respondent's node id, for you to update routing table
 }
 
 message FindRequest {
-  repeated bytes keys = 1;             // a list of DHTID search keys encoded as bytes
-  NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
+  RequestAuthInfo auth = 1;
+  repeated bytes keys = 2;             // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 3;                   // optional, same behavior as in DHT.ping
 }
 
 enum ResultType {NOT_FOUND = 0; FOUND_REGULAR = 1; FOUND_DICTIONARY = 2;}
@@ -66,9 +72,8 @@ message FindResult {
   repeated string nearest_endpoints = 5;    // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
 }
 
-
 message FindResponse {
-  repeated FindResult results = 1;       // for each item, return value/expiration (if found) and nearest peers
-  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  ResponseAuthInfo auth = 1;
+  repeated FindResult results = 2;       // for each item, return value/expiration (if found) and nearest peers
+  NodeInfo peer = 3;                   // respondent's node id, for you to update routing table
 }
-

+ 30 - 22
hivemind/server/__init__.py

@@ -65,16 +65,20 @@ class Server(threading.Thread):
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
 
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
+                                                       update_period=self.update_period)
+
         if start:
             self.run_in_background(await_ready=True)
 
     @classmethod
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
-               *, start: bool, **kwargs) -> Server:
+               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
+               max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
+               checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
+               stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,6 +89,7 @@ class Server(threading.Thread):
         :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be greater than this value
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
 
@@ -112,9 +117,6 @@ class Server(threading.Thread):
         """
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
-
-        if len(kwargs) != 0:
-            logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
 
         if no_dht:
@@ -172,6 +174,7 @@ class Server(threading.Thread):
                                                          num_warmup_steps=num_warmup_steps,
                                                          num_total_steps=num_total_steps,
                                                          clip_grad_norm=clip_grad_norm,
+                                                         min_batch_size=min_batch_size,
                                                          max_batch_size=max_batch_size)
 
         if checkpoint_dir is not None:
@@ -196,9 +199,7 @@ class Server(threading.Thread):
                 self.dht.run_in_background(await_ready=True)
 
             if self.experts:
-                dht_handler_thread = DHTHandlerThread(
-                    experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
-                dht_handler_thread.start()
+                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
@@ -207,16 +208,10 @@ class Server(threading.Thread):
                 process.start()
             process.ready.wait()
 
-        self.runtime.run()
-
-        for process in self.conn_handlers:
-            process.join()
-        if self.dht and self.experts:
-            dht_handler_thread.stop.set()
-            dht_handler_thread.join()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -242,19 +237,32 @@ class Server(threading.Thread):
 
     def shutdown(self):
         """
-        Gracefully terminate a hivemind server, process-safe.
+        Gracefully terminate the server, process-safe.
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         self.ready.clear()
+
         for process in self.conn_handlers:
             process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
 
         if self.dht is not None:
             self.dht.shutdown()
             self.dht.join()
 
-        self.runtime.shutdown()
+        logger.debug(f"Shutting down runtime")
+        self.runtime.stop.set()
+        logger.info("Server shutdown succesfully")
 
 
 @contextmanager

+ 4 - 1
hivemind/server/connection_handler.py

@@ -52,7 +52,10 @@ class ConnectionHandler(mp.context.ForkProcess):
             await server.wait_for_termination()
             logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
 
-        loop.run_until_complete(_run())
+        try:
+            loop.run_until_complete(_run())
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

+ 2 - 2
hivemind/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
         self.grad_inputs_schema = self.forward_schema  # outputs from backward
-        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
-        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
+        self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
+        self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)
 
         self.update_count = 0
         self.examples_processed = 0

+ 2 - 2
hivemind/server/expert_uid.py

@@ -62,8 +62,8 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                     uid.append(str(random.randint(slice_start, slice_end - 1)))
                 else:
                     raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
+            except KeyboardInterrupt:
+                raise
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
         return UID_DELIMITER.join(uid)

+ 17 - 20
hivemind/server/runtime.py

@@ -48,8 +48,8 @@ class Runtime(threading.Thread):
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
-        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+        self.stop = threading.Event()
 
         self.stats_report_interval = stats_report_interval
         if self.stats_report_interval is not None:
@@ -72,62 +72,59 @@ class Runtime(threading.Thread):
 
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
-                    logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
                     start = time()
                     outputs = pool.process_func(*batch)
                     batch_processing_time = time() - start
 
                     batch_size = outputs[0].size(0)
-                    logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
+                    logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
 
                     if self.stats_report_interval is not None:
-                        self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+                        self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
 
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
-                logger.info("Shutting down")
-
-                if self.stats_report_interval is not None:
-                    self.stats_reporter.stop.set()
-                    self.stats_reporter.join()
-
                 self.shutdown()
 
-    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
-
     def shutdown(self):
         """ Gracefully terminate a running runtime. """
-        self.ready.clear()
-        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
+        logger.info("Shutting down")
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        self.stop.set()  # trigger background thread to shutdown
+
+        logger.debug("Terminating pools")
         for pool in self.pools:
             if pool.is_alive():
                 pool.terminate()
                 pool.join()
+        logger.debug("Pools terminated")
 
     def iterate_minibatches_from_pools(self, timeout=None):
         """
         Chooses pool according to priority, then copies exposed batch and frees the buffer
         """
         with DefaultSelector() as selector:
-            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
 
-            while True:
+            while not self.stop.is_set():
                 # wait until at least one batch_receiver becomes available
                 logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
-                if self.SHUTDOWN_TRIGGER in ready_objects:
-                    break  # someone asked us to shutdown, break from the loop
 
                 logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
 
-                logger.debug(f"Loading batch from {pool.uid}")
+                logger.debug(f"Loading batch from {pool.name}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
-                logger.debug(f"Loaded batch from {pool.uid}")
+                logger.debug(f"Loaded batch from {pool.name}")
                 yield pool, batch_index, batch_tensors
 
 

+ 60 - 70
hivemind/server/task_pool.py

@@ -6,7 +6,6 @@ import multiprocessing as mp
 import os
 import threading
 import time
-import uuid
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
@@ -24,8 +23,8 @@ Task = namedtuple("Task", ("future", "args"))
 class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
 
-    def __init__(self, process_func: callable, daemon=True):
-        super().__init__(daemon=daemon)
+    def __init__(self, process_func: callable, daemon=True, **kwargs):
+        super().__init__(daemon=daemon, **kwargs)
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
@@ -63,19 +62,18 @@ class TaskPool(TaskPoolBase):
     :param process_func: function to be applied to every formed batch; called by Runtime
         Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
     :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
+    :param name: pool name
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param timeout: wait for a subsequent task for at most this many seconds
     :param pool_size: store at most this many unprocessed tasks in a queue
     :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
-    :param uid: pool identifier used for shared array allocation
     :param start: if True, start automatically at the end of __init__
     """
 
-    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
-                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
-        super().__init__(process_func, daemon=daemon)
+    def __init__(self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1,
+                 timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False):
+        super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
-        self.uid = uid or uuid.uuid4()
         self.prefetch_batches = prefetch_batches
 
         # interaction with ConnectionHandlers
@@ -112,7 +110,7 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 total_size = 0
             try:
-                logger.debug(f"{self.uid} getting next task")
+                logger.debug(f"{self.name} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
@@ -134,80 +132,72 @@ class TaskPool(TaskPoolBase):
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
-        logger.info(f'{self.uid} starting, pid={os.getpid()}')
+        logger.info(f'{self.name} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
+
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}_output')
+                                         name=f'{self.name}_output')
+
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
-        except BaseException as e:
-            # terminate output loop
-            self.outputs_sender.send(e)
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
+        finally:
             output_thread.join()
-            raise e
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
-        try:
-            prev_num_tasks = 0  # number of tasks currently in shared buffer
-            batch_index = max(pending_batches.keys(), default=0)
-            batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-            while True:
-                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-                # assumes that tasks are processed in the same order as they are created
-                for skip_i in range(prev_num_tasks):
-                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
-                    if skip_i == prev_num_tasks - 1:
-                        self.priority = finished_task_timestamp
-
-                logger.debug(f"{self.uid} getting next batch")
-                batch_tasks = next(batch_iterator)
-                # save batch futures, _output_loop will deliver on them later
-                pending_batches[batch_index] = batch_tasks
-
-                logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
-                # find or create shared arrays for current batch size
-                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
-                                range(len(batch_tasks[0].args))]
-                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
-                self.batch_sender.send((batch_index, batch_inputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
-                prev_num_tasks = len(batch_tasks)
-                batch_index += 1
-        except KeyboardInterrupt:
-            logger.debug('Caught KeyboardInterrupt, shutting down')
+
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+        while True:
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = finished_task_timestamp
+
+            logger.debug(f"{self.name} getting next batch")
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
+            # find or create shared arrays for current batch size
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
+                            range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
+            self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
-        try:
-            while True:
-                logger.debug(f"{self.uid} waiting for results from runtime")
-                payload = self.outputs_receiver.recv()
-                if isinstance(payload, BaseException):
-                    raise payload
-                else:
-                    batch_index, batch_outputs = payload
-                logger.debug(f"{self.uid}, batch {batch_index}: got results")
-
-                # split batch into partitions for individual tasks
-                batch_tasks = pending_batches.pop(batch_index)
-                task_sizes = [self.get_task_size(task) for task in batch_tasks]
-                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
-
-                # dispatch results to futures
-                for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                    try:
-                        task.future.set_result(tuple(task_outputs))
-                    except FutureStateError as e:
-                        logger.debug(f"Failed to send task result due to an exception: {e}")
-        except KeyboardInterrupt:
-            logger.debug(f"Caught KeyboardInterrupt, shutting down")
+        while True:
+            logger.debug(f"{self.name} waiting for results from runtime")
+            batch_index, batch_outputs = self.outputs_receiver.recv()
+            logger.debug(f"{self.name}, batch {batch_index}: got results")
+
+            # split batch into partitions for individual tasks
+            batch_tasks = pending_batches.pop(batch_index)
+            task_sizes = [self.get_task_size(task) for task in batch_tasks]
+            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
+
+            # dispatch results to futures
+            for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                try:
+                    task.future.set_result(tuple(task_outputs))
+                except FutureStateError as e:
+                    logger.debug(f"Failed to send task result due to an exception: {e}")
 
     @property
     def empty(self):

+ 215 - 0
hivemind/utils/auth.py

@@ -0,0 +1,215 @@
+import asyncio
+import functools
+import secrets
+import threading
+import time
+from abc import ABC, abstractmethod
+from enum import Enum
+from datetime import timedelta
+from typing import Optional, Tuple
+
+import grpc
+
+from hivemind.proto.auth_pb2 import AccessToken, RequestAuthInfo, ResponseAuthInfo
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.logging import get_logger
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time
+
+
+logger = get_logger(__name__)
+
+
+class AuthorizedRequestBase:
+    """
+    Interface for protobufs with the ``RequestAuthInfo auth`` field. Used for type annotations only.
+    """
+
+    auth: RequestAuthInfo
+
+
+class AuthorizedResponseBase:
+    """
+    Interface for protobufs with the ``ResponseAuthInfo auth`` field. Used for type annotations only.
+    """
+
+    auth: ResponseAuthInfo
+
+
+class AuthorizerBase(ABC):
+    @abstractmethod
+    async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
+        ...
+
+    @abstractmethod
+    async def validate_request(self, request: AuthorizedRequestBase) -> bool:
+        ...
+
+    @abstractmethod
+    async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
+        ...
+
+    @abstractmethod
+    async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
+        ...
+
+
+class TokenAuthorizerBase(AuthorizerBase):
+    """
+    Implements the authorization protocol for a moderated Hivemind network.
+    See https://github.com/learning-at-home/hivemind/issues/253
+    """
+
+    def __init__(self, local_private_key: Optional[RSAPrivateKey]=None):
+        if local_private_key is None:
+            local_private_key = RSAPrivateKey.process_wide()
+        self._local_private_key = local_private_key
+        self._local_public_key = local_private_key.get_public_key()
+
+        self._local_access_token = None
+        self._refresh_lock = asyncio.Lock()
+
+        self._recent_nonces = TimedStorage()
+
+    @abstractmethod
+    async def get_token(self) -> AccessToken:
+        ...
+
+    @abstractmethod
+    def is_token_valid(self, access_token: AccessToken) -> bool:
+        ...
+
+    @abstractmethod
+    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
+        ...
+
+    async def refresh_token_if_needed(self) -> None:
+        if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
+            async with self._refresh_lock:
+                if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
+                    self._local_access_token = await self.get_token()
+                    assert self.is_token_valid(self._local_access_token)
+
+    @property
+    def local_public_key(self) -> RSAPublicKey:
+        return self._local_public_key
+
+    async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
+        await self.refresh_token_if_needed()
+        auth = request.auth
+
+        auth.client_access_token.CopyFrom(self._local_access_token)
+
+        if service_public_key is not None:
+            auth.service_public_key = service_public_key.to_bytes()
+        auth.time = get_dht_time()
+        auth.nonce = secrets.token_bytes(8)
+
+        assert auth.signature == b''
+        auth.signature = self._local_private_key.sign(request.SerializeToString())
+
+    _MAX_CLIENT_SERVICER_TIME_DIFF = timedelta(minutes=1)
+
+    async def validate_request(self, request: AuthorizedRequestBase) -> bool:
+        await self.refresh_token_if_needed()
+        auth = request.auth
+
+        if not self.is_token_valid(auth.client_access_token):
+            logger.debug('Client failed to prove that it (still) has access to the network')
+            return False
+
+        client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
+        signature = auth.signature
+        auth.signature = b''
+        if not client_public_key.verify(request.SerializeToString(), signature):
+            logger.debug('Request has invalid signature')
+            return False
+
+        if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
+            logger.debug('Request is generated for a peer with another public key')
+            return False
+
+        with self._recent_nonces.freeze():
+            current_time = get_dht_time()
+            if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
+                logger.debug('Clocks are not synchronized or a previous request is replayed again')
+                return False
+            if auth.nonce in self._recent_nonces:
+                logger.debug('Previous request is replayed again')
+                return False
+
+        self._recent_nonces.store(auth.nonce, None,
+                                  current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3)
+        return True
+
+    async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
+        await self.refresh_token_if_needed()
+        auth = response.auth
+
+        auth.service_access_token.CopyFrom(self._local_access_token)
+        auth.nonce = request.auth.nonce
+
+        assert auth.signature == b''
+        auth.signature = self._local_private_key.sign(response.SerializeToString())
+
+    async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
+        await self.refresh_token_if_needed()
+        auth = response.auth
+
+        if not self.is_token_valid(auth.service_access_token):
+            logger.debug('Service failed to prove that it (still) has access to the network')
+            return False
+
+        service_public_key = RSAPublicKey.from_bytes(auth.service_access_token.public_key)
+        signature = auth.signature
+        auth.signature = b''
+        if not service_public_key.verify(response.SerializeToString(), signature):
+            logger.debug('Response has invalid signature')
+            return False
+
+        if auth.nonce != request.auth.nonce:
+            logger.debug('Response is generated for another request')
+            return False
+
+        return True
+
+
+class AuthRole(Enum):
+    CLIENT = 0
+    SERVICER = 1
+
+
+class AuthRPCWrapper:
+    def __init__(self, stub, role: AuthRole,
+                 authorizer: Optional[AuthorizerBase], service_public_key: Optional[RSAPublicKey]=None):
+        self._stub = stub
+        self._role = role
+        self._authorizer = authorizer
+        self._service_public_key = service_public_key
+
+    def __getattribute__(self, name: str):
+        if not name.startswith('rpc_'):
+            return object.__getattribute__(self, name)
+
+        method = getattr(self._stub, name)
+
+        @functools.wraps(method)
+        async def wrapped_rpc(request: AuthorizedRequestBase, *args, **kwargs):
+            if self._authorizer is not None:
+                if self._role == AuthRole.CLIENT:
+                    await self._authorizer.sign_request(request, self._service_public_key)
+                elif self._role == AuthRole.SERVICER:
+                    if not await self._authorizer.validate_request(request):
+                        return None
+
+            response = await method(request, *args, **kwargs)
+
+            if self._authorizer is not None:
+                if self._role == AuthRole.SERVICER:
+                    await self._authorizer.sign_response(response, request)
+                elif self._role == AuthRole.CLIENT:
+                    if not await self._authorizer.validate_response(response, request):
+                        return None
+
+            return response
+
+        return wrapped_rpc

+ 102 - 0
hivemind/utils/crypto.py

@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import base64
+import threading
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from cryptography import exceptions
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+
+class PrivateKey(ABC):
+    @abstractmethod
+    def sign(self, data: bytes) -> bytes:
+        ...
+
+    @abstractmethod
+    def get_public_key(self) -> PublicKey:
+        ...
+
+
+class PublicKey(ABC):
+    @abstractmethod
+    def verify(self, data: bytes, signature: bytes) -> bool:
+        ...
+
+    @abstractmethod
+    def to_bytes(self) -> bytes:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def from_bytes(cls, key: bytes) -> bytes:
+        ...
+
+
+_RSA_PADDING = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
+_RSA_HASH_ALGORITHM = hashes.SHA256()
+
+
+class RSAPrivateKey(PrivateKey):
+    def __init__(self):
+        self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+    _process_wide_key = None
+    _process_wide_key_lock = threading.RLock()
+
+    @classmethod
+    def process_wide(cls) -> RSAPrivateKey:
+        if cls._process_wide_key is None:
+            with cls._process_wide_key_lock:
+                if cls._process_wide_key is None:
+                    cls._process_wide_key = cls()
+        return cls._process_wide_key
+
+    def sign(self, data: bytes) -> bytes:
+        signature = self._private_key.sign(data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
+        return base64.b64encode(signature)
+
+    def get_public_key(self) -> RSAPublicKey:
+        return RSAPublicKey(self._private_key.public_key())
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Serializes the private key to make the class instances picklable
+        state['_private_key'] = self._private_key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.OpenSSH,
+            encryption_algorithm=serialization.NoEncryption())
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+
+
+class RSAPublicKey(PublicKey):
+    def __init__(self, public_key: rsa.RSAPublicKey):
+        self._public_key = public_key
+
+    def verify(self, data: bytes, signature: bytes) -> bool:
+        try:
+            signature = base64.b64decode(signature)
+
+            # Returns None if the signature is correct, raises an exception otherwise
+            self._public_key.verify(signature, data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
+
+            return True
+        except (ValueError, exceptions.InvalidSignature):
+            return False
+
+    def to_bytes(self) -> bytes:
+        return self._public_key.public_bytes(
+            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
+
+    @classmethod
+    def from_bytes(cls, key: bytes) -> RSAPublicKey:
+        key = serialization.load_ssh_public_key(key)
+        if not isinstance(key, rsa.RSAPublicKey):
+            raise ValueError(f'Expected an RSA public key, got {key}')
+        return cls(key)

+ 161 - 0
tests/test_auth.py

@@ -0,0 +1,161 @@
+from datetime import datetime, timedelta
+from typing import Optional, Tuple
+
+import pytest
+
+from hivemind.proto import dht_pb2
+from hivemind.proto.auth_pb2 import AccessToken
+from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class MockAuthorizer(TokenAuthorizerBase):
+    _authority_private_key = None
+    _authority_public_key = None
+
+    def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str='mock'):
+        super().__init__(local_private_key)
+
+        self._username = username
+        self._authority_public_key = None
+
+    async def get_token(self) -> AccessToken:
+        if MockAuthorizer._authority_private_key is None:
+            MockAuthorizer._authority_private_key = RSAPrivateKey()
+
+        self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()
+
+        token = AccessToken(username=self._username,
+                            public_key=self.local_public_key.to_bytes(),
+                            expiration_time=str(datetime.utcnow() + timedelta(minutes=1)))
+        token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token))
+        return token
+
+    def is_token_valid(self, access_token: AccessToken) -> bool:
+        data = self._token_to_bytes(access_token)
+        if not self._authority_public_key.verify(data, access_token.signature):
+            logger.exception('Access token has invalid signature')
+            return False
+
+        try:
+            expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        except ValueError:
+            logger.exception(
+                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time.tzinfo is not None:
+            logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time < datetime.utcnow():
+            logger.exception('Access token has expired')
+            return False
+
+        return True
+
+    _MAX_LATENCY = timedelta(minutes=1)
+
+    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
+        expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        return expiration_time < datetime.utcnow() + self._MAX_LATENCY
+
+    @staticmethod
+    def _token_to_bytes(access_token: AccessToken) -> bytes:
+        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
+
+
+@pytest.mark.asyncio
+async def test_valid_request_and_response():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+    assert await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+    assert await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_invalid_access_token():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+
+    # Break the access token signature
+    request.auth.client_access_token.signature = b'broken'
+
+    assert not await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+
+    # Break the access token signature
+    response.auth.service_access_token.signature = b'broken'
+
+    assert not await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_invalid_signatures():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+
+    # A man-in-the-middle attacker changes the request content
+    request.peer.endpoint = '127.0.0.2:7777'
+
+    assert not await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+
+    # A man-in-the-middle attacker changes the response content
+    response.sender_endpoint = '127.0.0.2:31337'
+
+    assert not await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_auth_rpc_wrapper():
+    class Servicer:
+        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
+            assert request.peer.endpoint == '127.0.0.1:1111'
+            assert request.auth.client_access_token.username == 'alice'
+
+            response = dht_pb2.PingResponse()
+            response.sender_endpoint = '127.0.0.1:2222'
+            return response
+
+    class Client:
+        def __init__(self, servicer: Servicer):
+            self._servicer = servicer
+
+        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
+            return await self._servicer.rpc_increment(request)
+
+    servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), 'bob'))
+    client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:1111'
+
+    response = await client.rpc_increment(request)
+
+    assert response.sender_endpoint == '127.0.0.1:2222'
+    assert response.auth.service_access_token.username == 'bob'

+ 6 - 5
tests/test_dht_crypto.py

@@ -9,12 +9,13 @@ from hivemind.dht import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import LOCALHOST
 from hivemind.dht.validation import DHTRecord
+from hivemind.utils.crypto import RSAPrivateKey
 
 
 def test_rsa_signature_validator():
     receiver_validator = RSASignatureValidator()
-    sender_validator = RSASignatureValidator(ignore_cached_key=True)
-    mallory_validator = RSASignatureValidator(ignore_cached_key=True)
+    sender_validator = RSASignatureValidator(RSAPrivateKey())
+    mallory_validator = RSASignatureValidator(RSAPrivateKey())
 
     plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                              expiration_time=get_dht_time() + 10)
@@ -52,7 +53,7 @@ def test_cached_key():
     second_validator = RSASignatureValidator()
     assert first_validator.local_public_key == second_validator.local_public_key
 
-    third_validator = RSASignatureValidator(ignore_cached_key=True)
+    third_validator = RSASignatureValidator(RSAPrivateKey())
     assert first_validator.local_public_key != third_validator.local_public_key
 
 
@@ -105,10 +106,10 @@ def test_signing_in_different_process():
 async def test_dhtnode_signatures():
     alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
     bob = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
     mallory = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
 
     key = b'key'