소스 검색

move to notes

justheuristic 5 년 전
부모
커밋
dfa9dfaae2
2개의 변경된 파일에 80개의 추가작업 그리고 22개의 삭제
  1. 25 22
      tests/test_moe.py
  2. 55 0
      tests/test_utils/run_server.py

+ 25 - 22
tests/test_moe.py

@@ -1,5 +1,6 @@
 import torch
 import tesseract
+from test_utils.run_server import background_server
 
 
 def test_remote_module_call():
@@ -7,28 +8,30 @@ def test_remote_module_call():
     xx = torch.randn(32, 1024, requires_grad=True)
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
-    # TODO somehow start server on some port
-    experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=6565) for i in range(8)]
-
-    k_min = 1
-    timeout_after_k_min = None
-    backward_k_min = 1
-    timeout_total = None
-    backward_timeout = None
-    moe_output, = tesseract.client.moe._RemoteMoECall.apply(
-        logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
-        [(None,), {}], xx)
-
-    grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
-    grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
-
-    # reference outputs: call all experts manually and average their outputs with softmax probabilities
-    probs = torch.softmax(logits, 0)
-    outs = [expert(xx) for expert in experts[:3]]
-    manual_output = sum(p * x for p, x in zip(probs, outs))
-    grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-    grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-    grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
+    num_experts = 8
+
+    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True, verbose=True) as server:
+        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server.port) for i in range(num_experts)]
+
+        k_min = 1
+        timeout_after_k_min = None
+        backward_k_min = 1
+        timeout_total = None
+        backward_timeout = None
+        moe_output, = tesseract.client.moe._RemoteMoECall.apply(
+            logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
+            [(None,), {}], xx)
+
+        grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
+        grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
+
+        # reference outputs: call all experts manually and average their outputs with softmax probabilities
+        probs = torch.softmax(logits, 0)
+        outs = [expert(xx) for expert in experts[:3]]
+        manual_output = sum(p * x for p, x in zip(probs, outs))
+        grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+        grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+        grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
 
     assert torch.allclose(moe_output, manual_output), "_RemoteMoECall returned incorrect output"
     assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \

+ 55 - 0
tests/test_utils/run_server.py

@@ -0,0 +1,55 @@
+import torch
+
+import tesseract
+from .layers import name_to_block
+from contextlib import contextmanager
+
+
@contextmanager
def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                      no_network=False, initial_peers=(), network_port=None, verbose=False, **kwargs
                      ) -> tesseract.TesseractServer:
    """
    A context manager that creates a TesseractServer in a background thread, awaits .ready on entry
    and shuts the server down on exit.

    :param host: network interface the server binds to
    :param port: server port; a free port is chosen automatically if None
    :param num_experts: number of expert modules to host
    :param expert_cls: key into name_to_block selecting the expert architecture
    :param hidden_dim: input/output feature dimension of each expert
    :param num_handlers: number of connection handler processes (defaults to num_experts * 8)
    :param expert_prefix: prefix for expert uids; already contains its trailing separator ('expert.')
    :param expert_offset: index of the first expert; uids are f'{expert_prefix}{i + expert_offset}'
    :param max_batch_size: maximum number of samples an expert backend processes per batch
    :param device: torch device for experts; defaults to cuda if available, else cpu
    :param no_optimizer: if True, use SGD with lr=0.0 (effectively no training) instead of Adam
    :param no_network: if True, do not start a TesseractNetwork node
    :param initial_peers: peers for the network node — either an iterable, or its string repr to be parsed
    :param network_port: port for the network node; a free port is chosen automatically if None
    :param verbose: if True, print progress messages
    :param kwargs: any extra keyword arguments are ignored (reported when verbose)
    """
    if verbose and len(kwargs) != 0:  # bugfix: original tested == 0, so ignored kwargs were never reported
        print("Ignored kwargs:", kwargs)
    # bugfix: original line was a bare expression ("expert_cls in name_to_block") with no effect
    assert expert_cls in name_to_block, f"Unknown expert_cls: {expert_cls}"
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize network
    network = None
    if not no_network:
        if isinstance(initial_peers, str):
            # bugfix: unconditional eval() raised TypeError on the default tuple value; parse strings only.
            # SECURITY NOTE(review): eval on externally supplied text is unsafe — consider ast.literal_eval
            initial_peers = eval(initial_peers)
        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(), start=True)
        if verbose:
            print("Parsed initial peers:", initial_peers)
            print(f"Running network node on port {network.port}")

    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        # bugfix: expert_prefix already ends with '.', the extra '.' produced uids like 'expert..0'
        # that never matched the f'expert.{i}' uids requested by RemoteExpert clients in test_moe.py
        expert_uid = f'{expert_prefix}{i + expert_offset}'
        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
    # start server
    server = tesseract.TesseractServer(
        network, experts, addr=host, port=port or tesseract.find_open_port(),
        conn_handler_processes=num_handlers, device=device)
    try:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Running server at {server.addr}:{server.port}")
            print(f"Active experts of type {expert_cls}: {list(experts.keys())}")
        yield server
    finally:
        # always shut the server down, even if the managed block raised
        if verbose:
            print("Shutting down server...")
        server.shutdown()
        if verbose:
            print("Server shut down successfully.")