
Add DHT listening options to server

Max Ryabinin 4 years ago
parent
commit
05c4355f8a

+ 4 - 0
hivemind/hivemind_cli/run_server.py

@@ -57,6 +57,10 @@ def main():
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
+
+    parser.add_argument('--dht_port', type=int)
+    parser.add_argument('--dht_listen_on', type=str)
+
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',
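
A note on the two new options: both are optional and default to None. As the server-side change below shows, a missing --dht_port makes the server pick a free port, and --dht_listen_on is inserted into an /ip6/ multiaddr, so it is apparently expected to be an IPv6 address. A minimal sketch of the parsing behaviour (not part of the commit; the port and address values are placeholders):

import argparse

# Reproduce only the two options added in this commit
parser = argparse.ArgumentParser()
parser.add_argument('--dht_port', type=int)       # DHT TCP port; None means "pick a free port"
parser.add_argument('--dht_listen_on', type=str)  # optional IPv6 address to listen on and announce

args = parser.parse_args(['--dht_port', '31337', '--dht_listen_on', '2001:db8::1'])
assert args.dht_port == 31337 and args.dht_listen_on == '2001:db8::1'

args = parser.parse_args([])  # both flags omitted
assert args.dht_port is None and args.dht_listen_on is None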

+ 1 - 1
hivemind/moe/client/moe.py

@@ -15,7 +15,7 @@ from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_map, nested_pack, nested_compare
+from hivemind.utils import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)

+ 18 - 1
hivemind/moe/server/__init__.py

@@ -120,6 +120,8 @@ class Server(threading.Thread):
         reuse_grad_buffers=True,
         device=None,
         no_dht=False,
+        dht_port=None,
+        dht_listen_on=None,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
@@ -175,7 +177,22 @@ class Server(threading.Thread):
         if no_dht:
             dht = None
         else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True, identity_path=identity_path)
+            dht_port = dht_port or hivemind.get_free_port()
+            host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}"]
+            announce_maddrs = []
+
+            if dht_listen_on is not None:
+                dht_maddr = f"/ip6/{dht_listen_on}/tcp/{dht_port}"
+                host_maddrs.append(dht_maddr)
+                announce_maddrs.append(dht_maddr)
+
+            dht = hivemind.DHT(
+                initial_peers=initial_peers,
+                start=True,
+                identity_path=identity_path,
+                host_maddrs=host_maddrs,
+                announce_maddrs=announce_maddrs,
+            )
             visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")


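
In short, the DHT node now always listens on all IPv4 interfaces at dht_port (a free port if none was given), and, when dht_listen_on is set, it also listens on and announces an /ip6/ multiaddr built from that address and the same port. A standalone sketch of the same logic, assuming hivemind is importable; the two input values are placeholders:

import hivemind

dht_port = None        # e.g. args.dht_port from run_server.py
dht_listen_on = "::"   # e.g. args.dht_listen_on; None skips the extra IPv6 multiaddr

dht_port = dht_port or hivemind.get_free_port()
host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}"]  # always listen on all IPv4 interfaces
announce_maddrs = []                            # announced only if an explicit address was given

if dht_listen_on is not None:
    dht_maddr = f"/ip6/{dht_listen_on}/tcp/{dht_port}"
    host_maddrs.append(dht_maddr)
    announce_maddrs.append(dht_maddr)

dht = hivemind.DHT(start=True, host_maddrs=host_maddrs, announce_maddrs=announce_maddrs)
print([str(addr) for addr in dht.get_visible_maddrs()])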
+ 57 - 59
hivemind/moe/server/layers/albert.py

@@ -61,12 +61,12 @@ class LeanFFN(nn.Module):
     """

     def __init__(
-            self,
-            hidden_size: int,
-            intermediate_size: int,
-            activation=F.gelu,
-            gated: bool = False,
-            layer_norm_eps: float = 1e-12,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        activation=F.gelu,
+        gated: bool = False,
+        layer_norm_eps: float = 1e-12,
     ):
         super().__init__()
         self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
@@ -103,17 +103,17 @@ class _LeanFFN(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-            ctx,
-            input,
-            ln_weight,
-            ln_bias,
-            i2h_weight,
-            i2h_bias,
-            h2o_weight,
-            h2o_bias,
-            activation,
-            training,
-            ln_eps,
+        ctx,
+        input,
+        ln_weight,
+        ln_bias,
+        i2h_weight,
+        i2h_bias,
+        h2o_weight,
+        h2o_bias,
+        activation,
+        training,
+        ln_eps,
     ):
         ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
         ctx._cpu_rng_state = torch.get_rng_state()
@@ -179,17 +179,17 @@ class _LeanFFN(torch.autograd.Function):
             grad_h2o_bias = grad_output_2d.sum(0)

         return (
-                grad_input,
-                grad_ln_weight,
-                grad_ln_bias,
-                grad_i2h_weight,
-                grad_i2h_bias,
-                grad_h2o_weight,
-                grad_h2o_bias,
-                None,
-                None,
-                None,
-                None,
+            grad_input,
+            grad_ln_weight,
+            grad_ln_bias,
+            grad_i2h_weight,
+            grad_i2h_bias,
+            grad_h2o_weight,
+            grad_h2o_bias,
+            None,
+            None,
+            None,
+            None,
         )


@@ -212,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)

-        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
+        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])


 @torch.no_grad()
@@ -243,13 +243,13 @@ def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tenso
 
 class LeanSelfAttention(nn.Module):
     def __init__(
-            self,
-            hidden_size: int,
-            num_attention_heads: int,
-            max_positions: int,
-            attention_core: Optional[nn.Module] = None,
-            layer_norm_eps: float = 1e-12,
-            **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        attention_core: Optional[nn.Module] = None,
+        layer_norm_eps: float = 1e-12,
+        **kwargs,
     ):
         """Attention layer that does not hog GPU memory"""
         super().__init__()
@@ -311,7 +311,7 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = attention_scores / math.sqrt(query.shape[-1])

         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
         attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))

         if attention_mask is not None:
@@ -330,12 +330,12 @@ class RotaryAttentionCore(SimpleAttentionCore):
     """Attention core that applies rotary embeddings to queries and keys before computing dot products"""

     def __init__(
-            self,
-            hidden_size: int,
-            num_attention_heads: int,
-            max_positions: int,
-            rotary_emb: Optional[RotaryEmbeddings] = None,
-            **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        rotary_emb: Optional[RotaryEmbeddings] = None,
+        **kwargs,
     ):
         super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
         if rotary_emb is None:
@@ -393,7 +393,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-            self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -413,7 +413,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings

@@ -458,8 +458,7 @@ class LeanAlbertLayerGroup(AlbertLayerGroup):
         self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])

     def forward(
-            self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False,
-            output_hidden_states=False
+        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
     ):
         if head_mask is not None and any(head_mask):
             raise NotImplementedError(f"head mask was provided, but it is not supported")
@@ -496,13 +495,13 @@ class LeanAlbertTransformer(AlbertTransformer):
         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)

     def forward(
-            self,
-            hidden_states,
-            attention_mask=None,
-            head_mask=None,
-            output_attentions=False,
-            output_hidden_states=False,
-            return_dict=True,
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
     ):
         # TODO this should entire be replaced with inheritance and post_layer_norm
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
@@ -585,7 +584,7 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 from hivemind.moe.server.layers.custom_experts import register_expert_class

 head_sample_input = lambda batch_size, hid_dim: (
-        torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
 )


@@ -602,7 +601,7 @@ class HeadExpert(nn.Module):
 
     def forward(self, input_ids):
         embedding_output = self.embeddings(input_ids)
-        encoder_outputs, = self.encoder(embedding_output, return_dict=False)
+        (encoder_outputs,) = self.encoder(embedding_output, return_dict=False)

         return encoder_outputs

@@ -644,9 +643,8 @@ class BodyExpert(nn.Module):
 
 
 tail_sample_input = lambda batch_size, hid_dim: (
-
-        torch.empty((batch_size, 512, hid_dim)),
-        torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
+    torch.empty((batch_size, 512, hid_dim)),
+    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
 )



+ 2 - 1
hivemind/utils/nested.py

@@ -1,6 +1,7 @@
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
 import torch
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor


 def nested_compare(t, u):