Ver código fonte

background_server: return network_port as well

justheuristic 5 anos atrás
pai
commit
edcfe67bf2
3 arquivos alterados com 2 adições e 3 exclusões
  1. 0 1
      tesseract/client/moe.py
  2. 1 1
      tests/test_moe.py
  3. 1 1
      tests/test_utils/run_server.py

+ 0 - 1
tesseract/client/moe.py

@@ -169,7 +169,6 @@ class RemoteMixtureOfExperts(nn.Module):
             for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
-
         scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores

+ 1 - 1
tests/test_moe.py

@@ -19,7 +19,7 @@ def test_remote_module_call():
     random_proj = torch.randn_like(xx)
 
     with background_server(num_experts=num_experts,  device='cpu',
-                           no_optimizer=True, no_network=True) as (localhost, server_port):
+                           no_optimizer=True, no_network=True) as (localhost, server_port, network_port):
         experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = tesseract.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,

+ 1 - 1
tests/test_utils/run_server.py

@@ -71,7 +71,7 @@ def background_server(*args, verbose=True, **kwargs):
     def server_runner():
         try:
             server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
-            send_addr.send((server.addr, server.port))
+            send_addr.send((server.addr, server.port, server.network.port))
             trigger_shutdown.wait()
         finally:
             if verbose: