
Add DHT listening options to server

Max Ryabinin committed 4 years ago
commit 05c4355f8a

+ 4 - 0
hivemind/hivemind_cli/run_server.py

@@ -57,6 +57,10 @@ def main():
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
+
+    parser.add_argument('--dht_port', type=int)
+    parser.add_argument('--dht_listen_on', type=str)
+
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',

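The two new flags are optional and default to None, matching the dht_port and dht_listen_on keyword defaults added to Server further down. A minimal parsing sketch, not part of the commit, with an illustrative port and IPv6 address:

from argparse import ArgumentParser

# Mirror of the two arguments added above; both fall back to None when omitted
parser = ArgumentParser()
parser.add_argument('--dht_port', type=int)
parser.add_argument('--dht_listen_on', type=str)

args = parser.parse_args(['--dht_port', '31337', '--dht_listen_on', '2001:db8::1'])
assert args.dht_port == 31337 and args.dht_listen_on == '2001:db8::1'
assert parser.parse_args([]).dht_port is None
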
+ 1 - 1
hivemind/moe/client/moe.py

@@ -15,7 +15,7 @@ from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_map, nested_pack, nested_compare
+from hivemind.utils import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 18 - 1
hivemind/moe/server/__init__.py

@@ -120,6 +120,8 @@ class Server(threading.Thread):
         reuse_grad_buffers=True,
         device=None,
         no_dht=False,
+        dht_port=None,
+        dht_listen_on=None,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
@@ -175,7 +177,22 @@ class Server(threading.Thread):
         if no_dht:
             dht = None
         else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True, identity_path=identity_path)
+            dht_port = dht_port or hivemind.get_free_port()
+            host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}"]
+            announce_maddrs = []
+
+            if dht_listen_on is not None:
+                dht_maddr = f"/ip6/{dht_listen_on}/tcp/{dht_port}"
+                host_maddrs.append(dht_maddr)
+                announce_maddrs.append(dht_maddr)
+
+            dht = hivemind.DHT(
+                initial_peers=initial_peers,
+                start=True,
+                identity_path=identity_path,
+                host_maddrs=host_maddrs,
+                announce_maddrs=announce_maddrs,
+            )
             visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 

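The core of the change is the multiaddr construction: the DHT node always listens on every IPv4 interface at dht_port (falling back to a free port when none is given), and when dht_listen_on is set it additionally listens on and announces that address as IPv6. A standalone sketch of that logic with an illustrative port and address; the helper name build_dht_maddrs is hypothetical:

from typing import List, Optional, Tuple

def build_dht_maddrs(dht_port: int, dht_listen_on: Optional[str]) -> Tuple[List[str], List[str]]:
    # Listen on all IPv4 interfaces at the chosen port
    host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}"]
    announce_maddrs = []
    if dht_listen_on is not None:
        # The extra address is treated as IPv6 and is both listened on and announced
        dht_maddr = f"/ip6/{dht_listen_on}/tcp/{dht_port}"
        host_maddrs.append(dht_maddr)
        announce_maddrs.append(dht_maddr)
    return host_maddrs, announce_maddrs

host, announce = build_dht_maddrs(31337, "2001:db8::1")
# host     == ['/ip4/0.0.0.0/tcp/31337', '/ip6/2001:db8::1/tcp/31337']
# announce == ['/ip6/2001:db8::1/tcp/31337']
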
+ 57 - 59
hivemind/moe/server/layers/albert.py

@@ -61,12 +61,12 @@ class LeanFFN(nn.Module):
     """
 
     def __init__(
-            self,
-            hidden_size: int,
-            intermediate_size: int,
-            activation=F.gelu,
-            gated: bool = False,
-            layer_norm_eps: float = 1e-12,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        activation=F.gelu,
+        gated: bool = False,
+        layer_norm_eps: float = 1e-12,
     ):
         super().__init__()
         self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
@@ -103,17 +103,17 @@ class _LeanFFN(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-            ctx,
-            input,
-            ln_weight,
-            ln_bias,
-            i2h_weight,
-            i2h_bias,
-            h2o_weight,
-            h2o_bias,
-            activation,
-            training,
-            ln_eps,
+        ctx,
+        input,
+        ln_weight,
+        ln_bias,
+        i2h_weight,
+        i2h_bias,
+        h2o_weight,
+        h2o_bias,
+        activation,
+        training,
+        ln_eps,
     ):
         ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
         ctx._cpu_rng_state = torch.get_rng_state()
@@ -179,17 +179,17 @@ class _LeanFFN(torch.autograd.Function):
             grad_h2o_bias = grad_output_2d.sum(0)
 
         return (
-                grad_input,
-                grad_ln_weight,
-                grad_ln_bias,
-                grad_i2h_weight,
-                grad_i2h_bias,
-                grad_h2o_weight,
-                grad_h2o_bias,
-                None,
-                None,
-                None,
-                None,
+            grad_input,
+            grad_ln_weight,
+            grad_ln_bias,
+            grad_i2h_weight,
+            grad_i2h_bias,
+            grad_h2o_weight,
+            grad_h2o_bias,
+            None,
+            None,
+            None,
+            None,
         )
 
 
@@ -212,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)
 
-        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
+        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])
 
 
 @torch.no_grad()
@@ -243,13 +243,13 @@ def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tenso
 
 class LeanSelfAttention(nn.Module):
     def __init__(
-            self,
-            hidden_size: int,
-            num_attention_heads: int,
-            max_positions: int,
-            attention_core: Optional[nn.Module] = None,
-            layer_norm_eps: float = 1e-12,
-            **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        attention_core: Optional[nn.Module] = None,
+        layer_norm_eps: float = 1e-12,
+        **kwargs,
     ):
         """Attention layer that does not hog GPU memory"""
         super().__init__()
@@ -311,7 +311,7 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = attention_scores / math.sqrt(query.shape[-1])
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
         attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))
 
         if attention_mask is not None:
@@ -330,12 +330,12 @@ class RotaryAttentionCore(SimpleAttentionCore):
     """Attention core that applies rotary embeddings to queries and keys before computing dot products"""
 
     def __init__(
-            self,
-            hidden_size: int,
-            num_attention_heads: int,
-            max_positions: int,
-            rotary_emb: Optional[RotaryEmbeddings] = None,
-            **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        rotary_emb: Optional[RotaryEmbeddings] = None,
+        **kwargs,
     ):
         super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
         if rotary_emb is None:
@@ -393,7 +393,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-            self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -413,7 +413,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings
 
@@ -458,8 +458,7 @@ class LeanAlbertLayerGroup(AlbertLayerGroup):
         self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])
 
     def forward(
-            self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False,
-            output_hidden_states=False
+        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
     ):
         if head_mask is not None and any(head_mask):
             raise NotImplementedError(f"head mask was provided, but it is not supported")
@@ -496,13 +495,13 @@ class LeanAlbertTransformer(AlbertTransformer):
         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
 
     def forward(
-            self,
-            hidden_states,
-            attention_mask=None,
-            head_mask=None,
-            output_attentions=False,
-            output_hidden_states=False,
-            return_dict=True,
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
     ):
         # TODO this should entire be replaced with inheritance and post_layer_norm
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
@@ -585,7 +584,7 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 from hivemind.moe.server.layers.custom_experts import register_expert_class
 
 head_sample_input = lambda batch_size, hid_dim: (
-        torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
 )
 
 
@@ -602,7 +601,7 @@ class HeadExpert(nn.Module):
 
     def forward(self, input_ids):
         embedding_output = self.embeddings(input_ids)
-        encoder_outputs, = self.encoder(embedding_output, return_dict=False)
+        (encoder_outputs,) = self.encoder(embedding_output, return_dict=False)
 
         return encoder_outputs
 
@@ -644,9 +643,8 @@ class BodyExpert(nn.Module):
 
 
 tail_sample_input = lambda batch_size, hid_dim: (
-
-        torch.empty((batch_size, 512, hid_dim)),
-        torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
+    torch.empty((batch_size, 512, hid_dim)),
+    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
 )
 
 

+ 2 - 1
hivemind/utils/nested.py

@@ -1,6 +1,7 @@
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
 import torch
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 
 
 def nested_compare(t, u):