
Merge branch 'master' into decentralized_lr_scheduler

justheuristic · 4 years ago
commit 13a1dd4e9d

+ 8 - 1
examples/albert/run_trainer.py

@@ -108,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.samples = 0
         self.steps = 0
         self.loss = 0
+        self.total_samples_processed = 0

     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
@@ -127,7 +128,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
-
+                self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                 statistics = metrics_utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
@@ -135,12 +136,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps)
+                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                if self.steps:
+                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+
                 self.loss = 0
                 self.steps = 0
                 self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                                subkey=self.local_public_key, value=statistics.dict(),
                                expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
+
         self.samples = self.collaborative_optimizer.local_samples_accumulated

         return control
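For context, a minimal standalone sketch of the bookkeeping this hunk introduces (the class and method names below are illustrative, not part of the repository): the per-round counters reset on every collaboration step, while `total_samples_processed` only grows, so the new log line reports a lifetime contribution.

```python
class ContributionTracker:
    def __init__(self):
        self.samples = 0                  # samples accumulated since the last collaboration step
        self.steps = 0                    # local mini-steps since the last collaboration step
        self.loss = 0.0                   # summed loss over those mini-steps
        self.total_samples_processed = 0  # lifetime total, never reset

    def on_local_step(self, batch_size: int, loss: float):
        self.samples += batch_size
        self.steps += 1
        self.loss += loss

    def on_collaboration_step(self, step: int):
        self.total_samples_processed += self.samples
        print(f"Step {step}")
        print(f"Your current contribution: {self.total_samples_processed} samples")
        if self.steps:
            print(f"Loss of your model: {self.loss / self.steps}")
        self.loss, self.steps = 0.0, 0    # per-round accumulators start over
```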

+ 27 - 27
hivemind/client/averaging/__init__.py

@@ -171,35 +171,34 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         # initialize asyncio synchronization primitives in this event loop
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
-
-        async def _run():
-            grpc.aio.init_grpc_aio()
-
-            if self.listen:
-                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                found_port = server.add_insecure_port(self.listen_on)
-                assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                self._port.value = found_port
-                await server.start()
-            else:
-                logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                grpc.aio.init_grpc_aio()
+
+                if self.listen:
+                    server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                    found_port = server.add_insecure_port(self.listen_on)
+                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
+                    self._port.value = found_port
+                    await server.start()
+                else:
+                    logger.info(f"The averager running in an experimental client mode, please report any bugs.")
 
 
-            self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
-                                            client_mode=not self.listen)
-            if self.listen:
-                asyncio.create_task(self._declare_for_download_periodically())
+                self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
+                                                client_mode=not self.listen)
+                if self.listen:
+                    asyncio.create_task(self._declare_for_download_periodically())
 
 
-            self._pending_group_assembled = asyncio.Event()
-            self._pending_group_assembled.set()
-            self.ready.set()
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+                self.ready.set()
 
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(*args, **kwargs))
 
 
-        loop.run_until_complete(_run())
+            loop.run_until_complete(_run())
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -255,7 +254,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout,
+                                                                        data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
                     group_id = group_info.group_id
@@ -294,7 +294,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
         try:
             weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints,  map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
 
 
             # compute optimal part sizes from peer throughputs
             # compute optimal part sizes from peer throughputs
             incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
             incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
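The main change above is re-indenting `_run()` under `with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:`, so the pipe-awaiting executor is shut down even if the serving coroutine raises. A standalone sketch of that pattern (simplified, not hivemind code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def serve(pipe_recv):
    loop = asyncio.new_event_loop()
    with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
        async def _run():
            while True:
                # the blocking recv() runs in the worker thread, keeping the event loop responsive
                message = await loop.run_in_executor(pipe_awaiter, pipe_recv)
                if message is None:
                    break

        loop.run_until_complete(_run())
    # leaving the `with` block (normally or via an exception) shuts the executor down
```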

+ 7 - 4
hivemind/client/moe.py

@@ -120,8 +120,11 @@ class RemoteMixtureOfExperts(nn.Module):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
+
+        device = grid_scores[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -133,11 +136,11 @@ class RemoteMixtureOfExperts(nn.Module):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
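The reason for caching `device` and passing it to the empty fallback: `torch.stack` refuses to mix CPU and CUDA tensors, so when a sample has no experts the zero-length placeholder must already live on the same device as the grid scores. Quick standalone check:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
per_dim = [torch.zeros(0, device=device) for _ in range(3)]   # empty fallbacks on the right device
flat_scores = torch.sum(torch.stack(per_dim, dim=0), dim=0)
print(flat_scores.device)  # same device as the grid scores, so no CPU/CUDA mismatch inside torch.stack
```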
 
 

+ 7 - 4
hivemind/client/switch_moe.py

@@ -156,8 +156,11 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         total_num_experts = sum(expert_counts)
-        expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
-        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]
+
+        device = grid_probs[0].device
+
+        expert_index_in_batch = torch.arange(total_num_experts, device=device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
         flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
@@ -169,10 +172,10 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

         scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
             for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
         flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)

-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
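Beyond the identical device fix, the one functional difference from `RemoteMixtureOfExperts` is visible above: the Switch variant combines per-dimension probabilities multiplicatively (`torch.prod`) rather than summing scores. Tiny illustration:

```python
import torch

probs = torch.tensor([[0.9, 0.5],
                      [0.2, 0.8]])        # rows = grid dimensions, columns = candidate experts
print(torch.prod(probs, dim=0))           # tensor([0.1800, 0.4000]) -- joint probability per expert
print(torch.sum(torch.log(probs), dim=0)) # same ranking in log-space (log of a product is a sum of logs)
```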

+ 15 - 17
hivemind/dht/__init__.py

@@ -69,25 +69,23 @@ class DHT(mp.Process):
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         loop = switch_to_uvloop()
-        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
 
 
-        async def _run():
-            node = await DHTNode.create(
-                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-                num_workers=self.max_workers or 1, record_validator=self._record_validator,
-                **self.kwargs)
-            if node.port is not None:
-                self._port.value = node.port
-            self.ready.set()
+        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+            async def _run():
+                node = await DHTNode.create(
+                    initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                    num_workers=self.max_workers or 1, record_validator=self._record_validator,
+                    **self.kwargs)
+                if node.port is not None:
+                    self._port.value = node.port
+                self.ready.set()
 
 
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                while True:
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
 
-        try:
-            loop.run_until_complete(_run())
-        except KeyboardInterrupt:
-            logger.debug("Caught KeyboardInterrupt, shutting down")
+            coro = _run()
+            loop.run_until_complete(coro)
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -96,7 +94,7 @@ class DHT(mp.Process):
         """
         """
         self.start()
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
 
 
     def shutdown(self) -> None:
         """ Shut down a running dht process """

+ 13 - 49
hivemind/dht/crypto.py

@@ -1,13 +1,10 @@
-import base64
 import dataclasses
 import re
-
-from cryptography.exceptions import InvalidSignature
-from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from typing import Optional
 
 
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 
 
 
 
 logger = get_logger(__name__)
@@ -31,26 +28,13 @@ class RSASignatureValidator(RecordValidatorBase):
 
 
     _cached_private_key = None
 
 
-    def __init__(self, *, ignore_cached_key=False):
-        if self._cached_private_key is None or ignore_cached_key:
-            # Since generating a private key takes ~100 ms, we cache it for future validator
-            # instances in the same process (unless ignore_cached_key=True)
-            self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
-            if not ignore_cached_key:
-                RSASignatureValidator._cached_private_key = self._private_key
-        else:
-            self._private_key = RSASignatureValidator._cached_private_key
-
-        serialized_public_key = self._private_key.public_key().public_bytes(
-            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
-        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
-
-        self._init_signature_params()
+    def __init__(self, private_key: Optional[RSAPrivateKey]=None):
+        if private_key is None:
+            private_key = RSAPrivateKey.process_wide()
+        self._private_key = private_key
 
 
-    def _init_signature_params(self) -> None:
-        self._padding = padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
-                                    salt_length=padding.PSS.MAX_LENGTH)
-        self._hash_algorithm = hashes.SHA256()
+        serialized_public_key = private_key.get_public_key().to_bytes()
+        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
 
 
     @property
     def local_public_key(self) -> bytes:
@@ -66,31 +50,25 @@ class RSASignatureValidator(RecordValidatorBase):
         if len(set(public_keys)) > 1:
             logger.debug(f"Key and subkey can't contain different public keys in {record}")
             return False
-        public_key = serialization.load_ssh_public_key(public_keys[0])
+        public_key = RSAPublicKey.from_bytes(public_keys[0])

         signatures = self._SIGNATURE_RE.findall(record.value)
         if len(signatures) != 1:
             logger.debug(f"Record should have exactly one signature in {record}")
             return False
-        signature = base64.b64decode(signatures[0])
+        signature = signatures[0]

         stripped_record = dataclasses.replace(record, value=self.strip_value(record))
-        try:
-            # verify() returns None iff the signature is correct
-            public_key.verify(signature, self._serialize_record(stripped_record),
-                              self._padding, self._hash_algorithm)
-            return True
-        except InvalidSignature:
+        if not public_key.verify(self._serialize_record(stripped_record), signature):
             logger.debug(f'Signature is invalid in {record}')
             return False
+        return True
 
 
     def sign_value(self, record: DHTRecord) -> bytes:
         if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
             return record.value

-        signature = self._private_key.sign(self._serialize_record(record),
-                                           self._padding, self._hash_algorithm)
-        signature = base64.b64encode(signature)
+        signature = self._private_key.sign(self._serialize_record(record))
         return record.value + self.SIGNATURE_FORMAT.replace(b'_value_', signature)

     def strip_value(self, record: DHTRecord) -> bytes:
@@ -112,17 +90,3 @@ class RSASignatureValidator(RecordValidatorBase):
         # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
         # instances of this class) and report successful merge
         return True
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        # Serializes the private key to make the class instances picklable
-        state['_private_key'] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption())
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
-        self._init_signature_params()
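Usage sketch of the slimmed-down validator. `RSAPrivateKey` comes from `hivemind.utils.crypto` as imported above; constructing it with no arguments and the `process_wide()` cache are read off this diff rather than verified against that module. The removed `__getstate__`/`__setstate__` suggest the key objects now handle their own pickling.

```python
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.utils.crypto import RSAPrivateKey

default_validator = RSASignatureValidator()                            # falls back to RSAPrivateKey.process_wide()
custom_validator = RSASignatureValidator(private_key=RSAPrivateKey())  # or bring your own key
assert default_validator.local_public_key != custom_validator.local_public_key
```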

+ 7 - 2
hivemind/dht/protocol.py

@@ -13,6 +13,7 @@ from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
 
 
 logger = get_logger(__name__)

@@ -34,6 +35,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
             listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
             record_validator: Optional[RecordValidatorBase] = None,
+            authorizer: Optional[AuthorizerBase] = None,
             channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
@@ -54,11 +56,13 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.record_validator = record_validator
+        self.authorizer = authorizer
 
 
         if listen:  # set up server to process incoming rpc requests
             grpc.aio.init_grpc_aio()
             self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            dht_grpc.add_DHTServicer_to_server(self, self.server)
+            servicer = AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer)
+            dht_grpc.add_DHTServicer_to_server(servicer, self.server)

             self.port = self.server.add_insecure_port(listen_on)
             assert self.port != 0, f"Failed to listen to {listen_on}"
@@ -89,7 +93,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
 
     def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        stub = ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)

     async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """

+ 3 - 1
hivemind/hivemind_cli/run_server.py

@@ -32,7 +32,9 @@ def main():
 
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all expert operations')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')

+ 0 - 3
hivemind/optim/collaborative.py

@@ -127,7 +127,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.samples_processed = 0
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None

@@ -192,7 +191,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.samples_processed += batch_size
             self.performance_ema.update(num_processed=self.batch_size_per_step)
             self.should_report_progress.set()

@@ -235,7 +233,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.update_scheduler()

             logger.log(self.status_loglevel, f"Optimizer step: done!")
-            logger.info(f"Your current contribution: {self.samples_processed} samples")

             return group_info


+ 22 - 0
hivemind/proto/auth.proto

@@ -0,0 +1,22 @@
+syntax = "proto3";
+
+message AccessToken {
+    string username = 1;
+    bytes public_key = 2;
+    string expiration_time = 3;
+    bytes signature = 4;
+}
+
+message RequestAuthInfo {
+    AccessToken client_access_token = 1;
+    bytes service_public_key = 2;
+    double time = 3;
+    bytes nonce = 4;
+    bytes signature = 5;
+}
+
+message ResponseAuthInfo {
+    AccessToken service_access_token = 1;
+    bytes nonce = 2;
+    bytes signature = 3;
+}
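Assuming the file is compiled into `hivemind.proto.auth_pb2` (the module name is inferred from the repository layout, not shown in this diff), the new messages are built like any other protobuf:

```python
from hivemind.proto import auth_pb2

token = auth_pb2.AccessToken(username="alice", public_key=b"...", expiration_time="2021-12-31T00:00:00Z")
request_auth = auth_pb2.RequestAuthInfo(client_access_token=token, nonce=b"\x00" * 8)
print(len(request_auth.SerializeToString()))
```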

+ 25 - 20
hivemind/proto/dht.proto

@@ -1,4 +1,5 @@
 syntax = "proto3";
 syntax = "proto3";
+import "auth.proto";
 
 
 // this protocol defines how Hivemind nodes form a distributed hash table.
 // this protocol defines how Hivemind nodes form a distributed hash table.
 // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
 // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
@@ -23,35 +24,40 @@ message NodeInfo {
 }

 message PingRequest {
-  NodeInfo peer = 1;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
-  bool validate = 2;                   // set to True if sender wants to validate that he is accessible and synchronized
+  RequestAuthInfo auth = 1;
+  NodeInfo peer = 2;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  bool validate = 3;                   // set to True if sender wants to validate that he is accessible and synchronized
 }

 message PingResponse {
-  NodeInfo peer = 1;                   // respondent's node id, for you to update routing table
-  string sender_endpoint = 2;          // echo sender's visible endpoint - used to infer his ip address
-  double dht_time = 3;                 // recipient's local DHT time - used to soft-synchronize peers
-  bool available = 4;                  // if validate = True, this flag asserts that the sender is available for ping
+  ResponseAuthInfo auth = 1;
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  string sender_endpoint = 3;          // echo sender's visible endpoint - used to infer his ip address
+  double dht_time = 4;                 // recipient's local DHT time - used to soft-synchronize peers
+  bool available = 5;                  // if validate = True, this flag asserts that the sender is available for ping
 }

 message StoreRequest {
+  RequestAuthInfo auth = 1;
   // three lists of the same length representing dht keys, dht values and expiration
-  repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes subkeys = 2;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
-  repeated bytes values = 3;           // serialized value for i-th key
-  repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated bytes keys = 2;             // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes subkeys = 3;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
+  repeated bytes values = 4;           // serialized value for i-th key
+  repeated double expiration_time = 5; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 6;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 7;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }

 message StoreResponse {
-  repeated bool store_ok = 1;          // for every key, True means store accepted, False means store rejected/failed
-  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  ResponseAuthInfo auth = 1;
+  repeated bool store_ok = 2;          // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 3;                   // respondent's node id, for you to update routing table
 }

 message FindRequest {
-  repeated bytes keys = 1;             // a list of DHTID search keys encoded as bytes
-  NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
+  RequestAuthInfo auth = 1;
+  repeated bytes keys = 2;             // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 3;                   // optional, same behavior as in DHT.ping
 }

 enum ResultType {NOT_FOUND = 0; FOUND_REGULAR = 1; FOUND_DICTIONARY = 2;}
@@ -66,9 +72,8 @@ message FindResult {
   repeated string nearest_endpoints = 5;    // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
 }

-
 message FindResponse {
-  repeated FindResult results = 1;       // for each item, return value/expiration (if found) and nearest peers
-  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
+  ResponseAuthInfo auth = 1;
+  repeated FindResult results = 2;       // for each item, return value/expiration (if found) and nearest peers
+  NodeInfo peer = 3;                   // respondent's node id, for you to update routing table
 }
-
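Note that inserting `auth` as field 1 renumbers every existing field, which breaks wire compatibility with peers built from the old .proto. A quick check against the regenerated stubs (assuming they live in `hivemind.proto.dht_pb2`):

```python
from hivemind.proto import auth_pb2, dht_pb2

ping = dht_pb2.PingRequest(auth=auth_pb2.RequestAuthInfo(), validate=True)
for descriptor, _ in ping.ListFields():
    print(descriptor.name, descriptor.number)  # auth -> 1, validate -> 3
```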

+ 30 - 22
hivemind/server/__init__.py

@@ -65,16 +65,20 @@ class Server(threading.Thread):
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)

+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
+                                                       update_period=self.update_period)
+
         if start:
             self.run_in_background(await_ready=True)

     @classmethod
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
-               *, start: bool, **kwargs) -> Server:
+               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
+               max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
+               checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
+               stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
         """
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,6 +89,7 @@ class Server(threading.Thread):
         :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be greater than this value
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu

@@ -112,9 +117,6 @@ class Server(threading.Thread):
         """
         """
         if custom_module_path is not None:
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
             add_custom_models_from_file(custom_module_path)
-
-        if len(kwargs) != 0:
-            logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
         assert expert_cls in name_to_block
 
 
         if no_dht:
         if no_dht:
@@ -172,6 +174,7 @@ class Server(threading.Thread):
                                                          num_warmup_steps=num_warmup_steps,
                                                          num_total_steps=num_total_steps,
                                                          clip_grad_norm=clip_grad_norm,
+                                                         min_batch_size=min_batch_size,
                                                          max_batch_size=max_batch_size)

         if checkpoint_dir is not None:
@@ -196,9 +199,7 @@ class Server(threading.Thread):
                 self.dht.run_in_background(await_ready=True)

             if self.experts:
-                dht_handler_thread = DHTHandlerThread(
-                    experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
-                dht_handler_thread.start()
+                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()

@@ -207,16 +208,10 @@ class Server(threading.Thread):
                 process.start()
             process.ready.wait()

-        self.runtime.run()
-
-        for process in self.conn_handlers:
-            process.join()
-        if self.dht and self.experts:
-            dht_handler_thread.stop.set()
-            dht_handler_thread.join()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
 
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -242,19 +237,32 @@ class Server(threading.Thread):
 
 
     def shutdown(self):
         """
-        Gracefully terminate a hivemind server, process-safe.
+        Gracefully terminate the server, process-safe.
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         self.ready.clear()
+
         for process in self.conn_handlers:
             process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
 
 
         if self.dht is not None:
             self.dht.shutdown()
             self.dht.join()
 
 
-        self.runtime.shutdown()
+        logger.debug(f"Shutting down runtime")
+        self.runtime.stop.set()
+        logger.info("Server shutdown succesfully")
 
 
 
 
 @contextmanager
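Putting the server-side changes together, a usage sketch (argument values are illustrative): `create()` now forwards `min_batch_size` to every `ExpertBackend`, and `run()` wraps the runtime in try/finally so `shutdown()` always joins the connection handlers, the DHT handler thread, the checkpoint saver, and the DHT.

```python
from hivemind.server import Server

server = Server.create(listen_on='0.0.0.0:*', num_experts=2, expert_cls='ffn', hidden_dim=1024,
                       min_batch_size=1, max_batch_size=4096, start=True)
try:
    pass  # serve requests here
finally:
    server.shutdown()  # safe to call explicitly; run() also invokes it if the runtime fails
```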

+ 4 - 1
hivemind/server/connection_handler.py

@@ -52,7 +52,10 @@ class ConnectionHandler(mp.context.ForkProcess):
             await server.wait_for_termination()
             logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

-        loop.run_until_complete(_run())
+        try:
+            loop.run_until_complete(_run())
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
 
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

+ 2 - 2
hivemind/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
 
         self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
         self.grad_inputs_schema = self.forward_schema  # outputs from backward
-        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
-        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
+        self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
+        self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)
 
 
         self.update_count = 0
         self.examples_processed = 0
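The `uid` keyword is gone: `TaskPool` now takes `name`, which is also passed to the underlying process, so the same identifier appears in the pool's log messages. Minimal sketch (the toy `process_func` is illustrative):

```python
from hivemind.server.task_pool import TaskPool

def square(batch):
    return (batch ** 2,)

pool = TaskPool(square, max_batch_size=4096, name='demo_forward', start=False)
print(pool.name)  # 'demo_forward'; previously this identifier lived in pool.uid
```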

+ 2 - 2
hivemind/server/expert_uid.py

@@ -62,8 +62,8 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                     uid.append(str(random.randint(slice_start, slice_end - 1)))
                 else:
                     raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
+            except KeyboardInterrupt:
+                raise
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
         return UID_DELIMITER.join(uid)
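A small aside on the `except KeyboardInterrupt: raise` change: in Python 3 both forms keep the original traceback, but the bare `raise` re-raises the active exception without rebinding it to a name and is the idiomatic way to say "propagate unchanged". The same pattern in isolation:

```python
def parse_block(block: str) -> int:
    try:
        return int(block)
    except KeyboardInterrupt:
        raise  # propagate Ctrl+C untouched
    except Exception as e:
        raise ValueError(f"invalid block {block}: {e}")
```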

+ 17 - 20
hivemind/server/runtime.py

@@ -48,8 +48,8 @@ class Runtime(threading.Thread):
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
-        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+        self.stop = threading.Event()
 
 
         self.stats_report_interval = stats_report_interval
         if self.stats_report_interval is not None:
@@ -72,62 +72,59 @@ class Runtime(threading.Thread):
 
 
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
-                    logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
 
                     start = time()
                     outputs = pool.process_func(*batch)
                     batch_processing_time = time() - start

                     batch_size = outputs[0].size(0)
-                    logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
+                    logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
 
 
                     if self.stats_report_interval is not None:
-                        self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+                        self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
-                logger.info("Shutting down")
-
-                if self.stats_report_interval is not None:
-                    self.stats_reporter.stop.set()
-                    self.stats_reporter.join()
-
                 self.shutdown()
 
 
-    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
-
     def shutdown(self):
         """ Gracefully terminate a running runtime. """
-        self.ready.clear()
-        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
+        logger.info("Shutting down")
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        self.stop.set()  # trigger background thread to shutdown
+
+        logger.debug("Terminating pools")
         for pool in self.pools:
             if pool.is_alive():
                 pool.terminate()
                 pool.join()
                 pool.join()
+        logger.debug("Pools terminated")
 
 
     def iterate_minibatches_from_pools(self, timeout=None):
     def iterate_minibatches_from_pools(self, timeout=None):
         """
         """
         Chooses pool according to priority, then copies exposed batch and frees the buffer
         Chooses pool according to priority, then copies exposed batch and frees the buffer
         """
         """
         with DefaultSelector() as selector:
         with DefaultSelector() as selector:
-            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
             for pool in self.pools:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
 
 
-            while True:
+            while not self.stop.is_set():
                 # wait until at least one batch_receiver becomes available
                 # wait until at least one batch_receiver becomes available
                 logger.debug("Waiting for inputs from task pools")
                 logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
                 ready_objects = {key.data for (key, events) in ready_fds}
-                if self.SHUTDOWN_TRIGGER in ready_objects:
-                    break  # someone asked us to shutdown, break from the loop
 
 
                 logger.debug("Choosing the pool with highest priority")
                 logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
                 pool = max(ready_objects, key=lambda pool: pool.priority)
 
 
-                logger.debug(f"Loading batch from {pool.uid}")
+                logger.debug(f"Loading batch from {pool.name}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
-                logger.debug(f"Loaded batch from {pool.uid}")
+                logger.debug(f"Loaded batch from {pool.name}")
                 yield pool, batch_index, batch_tensors
                 yield pool, batch_index, batch_tensors
 
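The selector no longer multiplexes a dedicated shutdown pipe; the loop condition now checks `self.stop`, a `threading.Event`. A standalone sketch of that mechanism (not hivemind code); note that a blocking `selector.select()` only lets the loop observe the flag once something becomes readable.

```python
import threading
import time

stop = threading.Event()

def consume_batches():
    while not stop.is_set():   # same check as `while not self.stop.is_set()` above
        time.sleep(0.1)        # stand-in for selector.select() + load_batch_to_runtime()
    print("consumer observed the stop flag and exited")

worker = threading.Thread(target=consume_batches)
worker.start()
stop.set()                     # what Runtime.shutdown() does via self.stop.set()
worker.join()
```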
 
 
 

+ 60 - 70
hivemind/server/task_pool.py

@@ -6,7 +6,6 @@ import multiprocessing as mp
 import os
 import threading
 import time
-import uuid
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
@@ -24,8 +23,8 @@ Task = namedtuple("Task", ("future", "args"))
 class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """

-    def __init__(self, process_func: callable, daemon=True):
-        super().__init__(daemon=daemon)
+    def __init__(self, process_func: callable, daemon=True, **kwargs):
+        super().__init__(daemon=daemon, **kwargs)
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
 
@@ -63,19 +62,18 @@ class TaskPool(TaskPoolBase):
     :param process_func: function to be applied to every formed batch; called by Runtime
         Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
     :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
+    :param name: pool name
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param timeout: wait for a subsequent task for at most this many seconds
     :param pool_size: store at most this many unprocessed tasks in a queue
     :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
-    :param uid: pool identifier used for shared array allocation
     :param start: if True, start automatically at the end of __init__
     """

-    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
-                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
-        super().__init__(process_func, daemon=daemon)
+    def __init__(self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1,
+                 timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False):
+        super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
-        self.uid = uid or uuid.uuid4()
         self.prefetch_batches = prefetch_batches
         self.prefetch_batches = prefetch_batches
 
 
         # interaction with ConnectionHandlers
         # interaction with ConnectionHandlers
@@ -112,7 +110,7 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 batch = []
                 total_size = 0
                 total_size = 0
             try:
             try:
-                logger.debug(f"{self.uid} getting next task")
+                logger.debug(f"{self.name} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
@@ -134,80 +132,72 @@ class TaskPool(TaskPoolBase):
 
 
     def run(self, *args, **kwargs):
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
         torch.set_num_threads(1)
-        logger.info(f'{self.uid} starting, pid={os.getpid()}')
+        logger.info(f'{self.name} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch_index, List[Task]] for each batch currently in runtime
         pending_batches = {}  # Dict[batch_index, List[Task]] for each batch currently in runtime
+
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}_output')
+                                         name=f'{self.name}_output')
+
         try:
         try:
             output_thread.start()
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
             self._pool_input_loop(pending_batches, *args, **kwargs)
-        except BaseException as e:
-            # terminate output loop
-            self.outputs_sender.send(e)
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
+        finally:
             output_thread.join()
             output_thread.join()
-            raise e
 
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
-        try:
-            prev_num_tasks = 0  # number of tasks currently in shared buffer
-            batch_index = max(pending_batches.keys(), default=0)
-            batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-            while True:
-                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-                # assumes that tasks are processed in the same order as they are created
-                for skip_i in range(prev_num_tasks):
-                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
-                    if skip_i == prev_num_tasks - 1:
-                        self.priority = finished_task_timestamp
-
-                logger.debug(f"{self.uid} getting next batch")
-                batch_tasks = next(batch_iterator)
-                # save batch futures, _output_loop will deliver on them later
-                pending_batches[batch_index] = batch_tasks
-
-                logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
-                # find or create shared arrays for current batch size
-                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
-                                range(len(batch_tasks[0].args))]
-                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
-                self.batch_sender.send((batch_index, batch_inputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
-                prev_num_tasks = len(batch_tasks)
-                batch_index += 1
-        except KeyboardInterrupt:
-            logger.debug('Caught KeyboardInterrupt, shutting down')
+
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+        while True:
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = finished_task_timestamp
+
+            logger.debug(f"{self.name} getting next batch")
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
+            # find or create shared arrays for current batch size
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
+                            range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
+            self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
 
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
 
-        try:
-            while True:
-                logger.debug(f"{self.uid} waiting for results from runtime")
-                payload = self.outputs_receiver.recv()
-                if isinstance(payload, BaseException):
-                    raise payload
-                else:
-                    batch_index, batch_outputs = payload
-                logger.debug(f"{self.uid}, batch {batch_index}: got results")
-
-                # split batch into partitions for individual tasks
-                batch_tasks = pending_batches.pop(batch_index)
-                task_sizes = [self.get_task_size(task) for task in batch_tasks]
-                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
-
-                # dispatch results to futures
-                for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                    try:
-                        task.future.set_result(tuple(task_outputs))
-                    except FutureStateError as e:
-                        logger.debug(f"Failed to send task result due to an exception: {e}")
-        except KeyboardInterrupt:
-            logger.debug(f"Caught KeyboardInterrupt, shutting down")
+        while True:
+            logger.debug(f"{self.name} waiting for results from runtime")
+            batch_index, batch_outputs = self.outputs_receiver.recv()
+            logger.debug(f"{self.name}, batch {batch_index}: got results")
+
+            # split batch into partitions for individual tasks
+            batch_tasks = pending_batches.pop(batch_index)
+            task_sizes = [self.get_task_size(task) for task in batch_tasks]
+            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
+
+            # dispatch results to futures
+            for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                try:
+                    task.future.set_result(tuple(task_outputs))
+                except FutureStateError as e:
+                    logger.debug(f"Failed to send task result due to an exception: {e}")
 
 
     @property
     @property
     def empty(self):
     def empty(self):
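
With the random uid gone, a pool is now identified by the standard multiprocessing name, which also shows up in its log messages. A hedged usage sketch follows; add_one and the pool name are made up, and submit_task belongs to the part of TaskPool not shown in this diff, so treat that call as an assumption:

import torch
from hivemind.server.task_pool import TaskPool

def add_one(x: torch.Tensor):
    # process_func receives positional tensors and must return a flat tuple of tensors
    return (x + 1,)

pool = TaskPool(add_one, max_batch_size=16, name="toy_forward_pool")
print(pool.name, pool.priority)  # the name doubles as the mp.Process name

# without a running Runtime the task is merely enqueued; its future resolves
# once a Runtime consumes the batch and sends results back
future = pool.submit_task(torch.zeros(1, 3))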

+ 215 - 0
hivemind/utils/auth.py

@@ -0,0 +1,215 @@
+import asyncio
+import functools
+import secrets
+import threading
+import time
+from abc import ABC, abstractmethod
+from enum import Enum
+from datetime import timedelta
+from typing import Optional, Tuple
+
+import grpc
+
+from hivemind.proto.auth_pb2 import AccessToken, RequestAuthInfo, ResponseAuthInfo
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.logging import get_logger
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time
+
+
+logger = get_logger(__name__)
+
+
+class AuthorizedRequestBase:
+    """
+    Interface for protobufs with the ``RequestAuthInfo auth`` field. Used for type annotations only.
+    """
+
+    auth: RequestAuthInfo
+
+
+class AuthorizedResponseBase:
+    """
+    Interface for protobufs with the ``ResponseAuthInfo auth`` field. Used for type annotations only.
+    """
+
+    auth: ResponseAuthInfo
+
+
+class AuthorizerBase(ABC):
+    @abstractmethod
+    async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
+        ...
+
+    @abstractmethod
+    async def validate_request(self, request: AuthorizedRequestBase) -> bool:
+        ...
+
+    @abstractmethod
+    async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
+        ...
+
+    @abstractmethod
+    async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
+        ...
+
+
+class TokenAuthorizerBase(AuthorizerBase):
+    """
+    Implements the authorization protocol for a moderated Hivemind network.
+    See https://github.com/learning-at-home/hivemind/issues/253
+    """
+
+    def __init__(self, local_private_key: Optional[RSAPrivateKey] = None):
+        if local_private_key is None:
+            local_private_key = RSAPrivateKey.process_wide()
+        self._local_private_key = local_private_key
+        self._local_public_key = local_private_key.get_public_key()
+
+        self._local_access_token = None
+        self._refresh_lock = asyncio.Lock()
+
+        self._recent_nonces = TimedStorage()
+
+    @abstractmethod
+    async def get_token(self) -> AccessToken:
+        ...
+
+    @abstractmethod
+    def is_token_valid(self, access_token: AccessToken) -> bool:
+        ...
+
+    @abstractmethod
+    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
+        ...
+
+    async def refresh_token_if_needed(self) -> None:
+        if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
+            async with self._refresh_lock:
+                if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
+                    self._local_access_token = await self.get_token()
+                    assert self.is_token_valid(self._local_access_token)
+
+    @property
+    def local_public_key(self) -> RSAPublicKey:
+        return self._local_public_key
+
+    async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
+        await self.refresh_token_if_needed()
+        auth = request.auth
+
+        auth.client_access_token.CopyFrom(self._local_access_token)
+
+        if service_public_key is not None:
+            auth.service_public_key = service_public_key.to_bytes()
+        auth.time = get_dht_time()
+        auth.nonce = secrets.token_bytes(8)
+
+        assert auth.signature == b''
+        auth.signature = self._local_private_key.sign(request.SerializeToString())
+
+    _MAX_CLIENT_SERVICER_TIME_DIFF = timedelta(minutes=1)
+
+    async def validate_request(self, request: AuthorizedRequestBase) -> bool:
+        await self.refresh_token_if_needed()
+        auth = request.auth
+
+        if not self.is_token_valid(auth.client_access_token):
+            logger.debug('Client failed to prove that it (still) has access to the network')
+            return False
+
+        client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
+        signature = auth.signature
+        auth.signature = b''
+        if not client_public_key.verify(request.SerializeToString(), signature):
+            logger.debug('Request has invalid signature')
+            return False
+
+        if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
+            logger.debug('Request is generated for a peer with another public key')
+            return False
+
+        with self._recent_nonces.freeze():
+            current_time = get_dht_time()
+            if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
+                logger.debug('Clocks are not synchronized or a previous request is being replayed')
+                return False
+            if auth.nonce in self._recent_nonces:
+                logger.debug('A previous request is being replayed')
+                return False
+
+        self._recent_nonces.store(auth.nonce, None,
+                                  current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3)
+        return True
+
+    async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
+        await self.refresh_token_if_needed()
+        auth = response.auth
+
+        auth.service_access_token.CopyFrom(self._local_access_token)
+        auth.nonce = request.auth.nonce
+
+        assert auth.signature == b''
+        auth.signature = self._local_private_key.sign(response.SerializeToString())
+
+    async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
+        await self.refresh_token_if_needed()
+        auth = response.auth
+
+        if not self.is_token_valid(auth.service_access_token):
+            logger.debug('Service failed to prove that it (still) has access to the network')
+            return False
+
+        service_public_key = RSAPublicKey.from_bytes(auth.service_access_token.public_key)
+        signature = auth.signature
+        auth.signature = b''
+        if not service_public_key.verify(response.SerializeToString(), signature):
+            logger.debug('Response has invalid signature')
+            return False
+
+        if auth.nonce != request.auth.nonce:
+            logger.debug('Response is generated for another request')
+            return False
+
+        return True
+
+
+class AuthRole(Enum):
+    CLIENT = 0
+    SERVICER = 1
+
+
+class AuthRPCWrapper:
+    def __init__(self, stub, role: AuthRole,
+                 authorizer: Optional[AuthorizerBase], service_public_key: Optional[RSAPublicKey] = None):
+        self._stub = stub
+        self._role = role
+        self._authorizer = authorizer
+        self._service_public_key = service_public_key
+
+    def __getattribute__(self, name: str):
+        if not name.startswith('rpc_'):
+            return object.__getattribute__(self, name)
+
+        method = getattr(self._stub, name)
+
+        @functools.wraps(method)
+        async def wrapped_rpc(request: AuthorizedRequestBase, *args, **kwargs):
+            if self._authorizer is not None:
+                if self._role == AuthRole.CLIENT:
+                    await self._authorizer.sign_request(request, self._service_public_key)
+                elif self._role == AuthRole.SERVICER:
+                    if not await self._authorizer.validate_request(request):
+                        return None
+
+            response = await method(request, *args, **kwargs)
+
+            if self._authorizer is not None:
+                if self._role == AuthRole.SERVICER:
+                    await self._authorizer.sign_response(response, request)
+                elif self._role == AuthRole.CLIENT:
+                    if not await self._authorizer.validate_response(response, request):
+                        return None
+
+            return response
+
+        return wrapped_rpc
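
AuthRPCWrapper intercepts only attributes whose names start with rpc_: with a CLIENT role it signs outgoing requests and validates responses, with a SERVICER role it validates requests and signs responses, and with authorizer=None it is a transparent pass-through. A minimal sketch of the wrapping mechanics (ToyStub and rpc_ping are hypothetical; in hivemind the wrapped object would be a gRPC stub or servicer):

import asyncio
from hivemind.proto import dht_pb2
from hivemind.utils.auth import AuthRPCWrapper, AuthRole

class ToyStub:
    # only attributes starting with "rpc_" are wrapped; everything else is returned as-is
    async def rpc_ping(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
        return dht_pb2.PingResponse(sender_endpoint='127.0.0.1:2222')

async def main():
    # passing a TokenAuthorizerBase subclass instead of None would add signing/validation
    client = AuthRPCWrapper(ToyStub(), AuthRole.CLIENT, authorizer=None)
    response = await client.rpc_ping(dht_pb2.PingRequest())
    print(response.sender_endpoint)

asyncio.run(main())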

+ 102 - 0
hivemind/utils/crypto.py

@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import base64
+import threading
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from cryptography import exceptions
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+
+class PrivateKey(ABC):
+    @abstractmethod
+    def sign(self, data: bytes) -> bytes:
+        ...
+
+    @abstractmethod
+    def get_public_key(self) -> PublicKey:
+        ...
+
+
+class PublicKey(ABC):
+    @abstractmethod
+    def verify(self, data: bytes, signature: bytes) -> bool:
+        ...
+
+    @abstractmethod
+    def to_bytes(self) -> bytes:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def from_bytes(cls, key: bytes) -> PublicKey:
+        ...
+
+
+_RSA_PADDING = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
+_RSA_HASH_ALGORITHM = hashes.SHA256()
+
+
+class RSAPrivateKey(PrivateKey):
+    def __init__(self):
+        self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+    _process_wide_key = None
+    _process_wide_key_lock = threading.RLock()
+
+    @classmethod
+    def process_wide(cls) -> RSAPrivateKey:
+        if cls._process_wide_key is None:
+            with cls._process_wide_key_lock:
+                if cls._process_wide_key is None:
+                    cls._process_wide_key = cls()
+        return cls._process_wide_key
+
+    def sign(self, data: bytes) -> bytes:
+        signature = self._private_key.sign(data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
+        return base64.b64encode(signature)
+
+    def get_public_key(self) -> RSAPublicKey:
+        return RSAPublicKey(self._private_key.public_key())
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Serializes the private key to make the class instances picklable
+        state['_private_key'] = self._private_key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.OpenSSH,
+            encryption_algorithm=serialization.NoEncryption())
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+
+
+class RSAPublicKey(PublicKey):
+    def __init__(self, public_key: rsa.RSAPublicKey):
+        self._public_key = public_key
+
+    def verify(self, data: bytes, signature: bytes) -> bool:
+        try:
+            signature = base64.b64decode(signature)
+
+            # Returns None if the signature is correct, raises an exception otherwise
+            self._public_key.verify(signature, data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
+
+            return True
+        except (ValueError, exceptions.InvalidSignature):
+            return False
+
+    def to_bytes(self) -> bytes:
+        return self._public_key.public_bytes(
+            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
+
+    @classmethod
+    def from_bytes(cls, key: bytes) -> RSAPublicKey:
+        key = serialization.load_ssh_public_key(key)
+        if not isinstance(key, rsa.RSAPublicKey):
+            raise ValueError(f'Expected an RSA public key, got {key}')
+        return cls(key)
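
The keys above are meant to round-trip both through bytes (public keys) and through pickling (private keys are re-serialized in OpenSSH format by __getstate__/__setstate__). A minimal sign-and-verify sketch:

import pickle
from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey

private_key = RSAPrivateKey()
public_key = private_key.get_public_key()

message = b'some signed payload'
signature = private_key.sign(message)  # base64-encoded PSS/SHA-256 signature
assert public_key.verify(message, signature)
assert not public_key.verify(b'tampered payload', signature)

# keys can cross process boundaries: public keys via to_bytes/from_bytes, private keys via pickle
restored_public = RSAPublicKey.from_bytes(public_key.to_bytes())
restored_private = pickle.loads(pickle.dumps(private_key))
assert restored_public.verify(message, restored_private.sign(message))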

+ 161 - 0
tests/test_auth.py

@@ -0,0 +1,161 @@
+from datetime import datetime, timedelta
+from typing import Optional, Tuple
+
+import pytest
+
+from hivemind.proto import dht_pb2
+from hivemind.proto.auth_pb2 import AccessToken
+from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class MockAuthorizer(TokenAuthorizerBase):
+    _authority_private_key = None
+    _authority_public_key = None
+
+    def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str = 'mock'):
+        super().__init__(local_private_key)
+
+        self._username = username
+        self._authority_public_key = None
+
+    async def get_token(self) -> AccessToken:
+        if MockAuthorizer._authority_private_key is None:
+            MockAuthorizer._authority_private_key = RSAPrivateKey()
+
+        self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()
+
+        token = AccessToken(username=self._username,
+                            public_key=self.local_public_key.to_bytes(),
+                            expiration_time=str(datetime.utcnow() + timedelta(minutes=1)))
+        token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token))
+        return token
+
+    def is_token_valid(self, access_token: AccessToken) -> bool:
+        data = self._token_to_bytes(access_token)
+        if not self._authority_public_key.verify(data, access_token.signature):
+            logger.exception('Access token has invalid signature')
+            return False
+
+        try:
+            expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        except ValueError:
+            logger.exception(
+                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time.tzinfo is not None:
+            logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
+            return False
+        if expiration_time < datetime.utcnow():
+            logger.exception('Access token has expired')
+            return False
+
+        return True
+
+    _MAX_LATENCY = timedelta(minutes=1)
+
+    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
+        expiration_time = datetime.fromisoformat(access_token.expiration_time)
+        return expiration_time < datetime.utcnow() + self._MAX_LATENCY
+
+    @staticmethod
+    def _token_to_bytes(access_token: AccessToken) -> bytes:
+        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
+
+
+@pytest.mark.asyncio
+async def test_valid_request_and_response():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+    assert await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+    assert await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_invalid_access_token():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+
+    # Break the access token signature
+    request.auth.client_access_token.signature = b'broken'
+
+    assert not await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+
+    # Break the access token signature
+    response.auth.service_access_token.signature = b'broken'
+
+    assert not await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_invalid_signatures():
+    client_authorizer = MockAuthorizer(RSAPrivateKey())
+    service_authorizer = MockAuthorizer(RSAPrivateKey())
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:7777'
+    await client_authorizer.sign_request(request, service_authorizer.local_public_key)
+
+    # A man-in-the-middle attacker changes the request content
+    request.peer.endpoint = '127.0.0.2:7777'
+
+    assert not await service_authorizer.validate_request(request)
+
+    response = dht_pb2.PingResponse()
+    response.sender_endpoint = '127.0.0.1:31337'
+    await service_authorizer.sign_response(response, request)
+
+    # A man-in-the-middle attacker changes the response content
+    response.sender_endpoint = '127.0.0.2:31337'
+
+    assert not await client_authorizer.validate_response(response, request)
+
+
+@pytest.mark.asyncio
+async def test_auth_rpc_wrapper():
+    class Servicer:
+        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
+            assert request.peer.endpoint == '127.0.0.1:1111'
+            assert request.auth.client_access_token.username == 'alice'
+
+            response = dht_pb2.PingResponse()
+            response.sender_endpoint = '127.0.0.1:2222'
+            return response
+
+    class Client:
+        def __init__(self, servicer: Servicer):
+            self._servicer = servicer
+
+        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
+            return await self._servicer.rpc_increment(request)
+
+    servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), 'bob'))
+    client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))
+
+    request = dht_pb2.PingRequest()
+    request.peer.endpoint = '127.0.0.1:1111'
+
+    response = await client.rpc_increment(request)
+
+    assert response.sender_endpoint == '127.0.0.1:2222'
+    assert response.auth.service_access_token.username == 'bob'
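
Note that when authorization fails, AuthRPCWrapper does not raise: the SERVICER branch returns None without invoking the RPC body, and the CLIENT branch returns None instead of the response, so callers should treat None as "unauthorized". A hedged sketch of that behavior reusing MockAuthorizer from the tests above (the import path is an assumption; in a plain checkout you may need to copy the class instead):

import asyncio
from hivemind.proto import dht_pb2
from hivemind.utils.auth import AuthRPCWrapper, AuthRole
from hivemind.utils.crypto import RSAPrivateKey
from tests.test_auth import MockAuthorizer  # assumption: tests are importable from the repo root

class Servicer:
    async def rpc_ping(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
        return dht_pb2.PingResponse(sender_endpoint='127.0.0.1:2222')

async def main():
    servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey()))
    unsigned_request = dht_pb2.PingRequest()  # never passed through a CLIENT-side wrapper, so it is unsigned
    assert await servicer.rpc_ping(unsigned_request) is None  # validation failed, the RPC body never ran

asyncio.run(main())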

+ 6 - 5
tests/test_dht_crypto.py

@@ -9,12 +9,13 @@ from hivemind.dht import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import LOCALHOST
 from hivemind.dht.node import LOCALHOST
 from hivemind.dht.validation import DHTRecord
 from hivemind.dht.validation import DHTRecord
+from hivemind.utils.crypto import RSAPrivateKey
 
 
 
 
 def test_rsa_signature_validator():
 def test_rsa_signature_validator():
     receiver_validator = RSASignatureValidator()
     receiver_validator = RSASignatureValidator()
-    sender_validator = RSASignatureValidator(ignore_cached_key=True)
-    mallory_validator = RSASignatureValidator(ignore_cached_key=True)
+    sender_validator = RSASignatureValidator(RSAPrivateKey())
+    mallory_validator = RSASignatureValidator(RSAPrivateKey())
 
 
     plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
     plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                              expiration_time=get_dht_time() + 10)
                              expiration_time=get_dht_time() + 10)
@@ -52,7 +53,7 @@ def test_cached_key():
     second_validator = RSASignatureValidator()
     second_validator = RSASignatureValidator()
     assert first_validator.local_public_key == second_validator.local_public_key
     assert first_validator.local_public_key == second_validator.local_public_key
 
 
-    third_validator = RSASignatureValidator(ignore_cached_key=True)
+    third_validator = RSASignatureValidator(RSAPrivateKey())
     assert first_validator.local_public_key != third_validator.local_public_key
     assert first_validator.local_public_key != third_validator.local_public_key
 
 
 
 
@@ -105,10 +106,10 @@ def test_signing_in_different_process():
 async def test_dhtnode_signatures():
 async def test_dhtnode_signatures():
     alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
     alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
     bob = await hivemind.DHTNode.create(
     bob = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
     mallory = await hivemind.DHTNode.create(
     mallory = await hivemind.DHTNode.create(
-        record_validator=RSASignatureValidator(ignore_cached_key=True),
+        record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
 
 
     key = b'key'
     key = b'key'