|
@@ -5,19 +5,19 @@ from test_utils.run_server import background_server
|
|
|
|
|
|
def test_remote_module_call():
|
|
|
""" Check that remote_module_call returns correct outputs and gradients if called directly """
|
|
|
+ num_experts = 8
|
|
|
+ k_min = 1
|
|
|
+ timeout_after_k_min = None
|
|
|
+ backward_k_min = 1
|
|
|
+ timeout_total = None
|
|
|
+ backward_timeout = None
|
|
|
+
|
|
|
xx = torch.randn(32, 1024, requires_grad=True)
|
|
|
logits = torch.randn(3, requires_grad=True)
|
|
|
random_proj = torch.randn_like(xx)
|
|
|
- num_experts = 8
|
|
|
-
|
|
|
- with background_server(num_experts=num_experts, no_optimizer=True, no_network=True, verbose=True) as server:
|
|
|
- experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server.port) for i in range(num_experts)]
|
|
|
|
|
|
- k_min = 1
|
|
|
- timeout_after_k_min = None
|
|
|
- backward_k_min = 1
|
|
|
- timeout_total = None
|
|
|
- backward_timeout = None
|
|
|
+ with background_server(num_experts=num_experts, no_optimizer=True, no_network=True) as (localhost, server_port):
|
|
|
+ experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
|
|
|
moe_output, = tesseract.client.moe._RemoteMoECall.apply(
|
|
|
logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
|
|
|
[(None,), {}], xx)
|