Ver código fonte

get output schema lazily

justheuristic 5 anos atrás
pai
commit
640177786f
2 arquivos alterados com 34 adições e 27 exclusões
  1. 14 6
      tesseract/client/moe.py
  2. 20 21
      tests/test_moe.py

+ 14 - 6
tesseract/client/moe.py

@@ -48,10 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-
-        # grab some expert to set ensemble output shape
-        dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
-        self.output_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+        self._outputs_schema = None
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -142,12 +139,15 @@ class RemoteMixtureOfExperts(nn.Module):
 
         unique_experts = self.network.get_experts(list(set(
             uid for row in beam for uid in row if uid != self.expert_padding)))
+        if self._outputs_schema is None:
+            self._output_schema = next(iter(unique_experts)).info['output_schema']
         unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
 
         return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
 
-    def compute_expert_scores(self, grid_scores: List[torch.Tensor],
-                              batch_experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
+    def compute_expert_scores(
+            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        """ TODO docstring here """
         expert_counts = list(map(len, batch_experts))
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
@@ -173,6 +173,14 @@ class RemoteMixtureOfExperts(nn.Module):
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
 
+    @property
+    def output_schema(self):
+        if self._outputs_schema is None:
+            # grab some expert to set ensemble output shape
+            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
+            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+        return self._outputs_schema
+
 
 class _RemoteMoECall(torch.autograd.Function):
     """

+ 20 - 21
tests/test_moe.py

@@ -44,28 +44,27 @@ def test_remote_module_call():
 
 
 def test_compute_expert_scores():
-    with background_server(device='cpu') as (server_addr, server_port, network_port):
-        try:
-            dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
-            moe = tesseract.client.moe.RemoteMixtureOfExperts(
-                network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
-                uid_prefix='expert')
-            gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
-            ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
-            jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
-            batch_experts = [
-                [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
-                for b in range(len(ii))
-            ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
-            logits = moe.compute_expert_scores([gx, gy], batch_experts)
-            torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
-            assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
+    try:
+        dht = tesseract.TesseractNetwork(port=tesseract.find_open_port(), start=True)
+        moe = tesseract.client.moe.RemoteMixtureOfExperts(
+            network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+            uid_prefix='expert')
+        gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
+        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
+        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
+        batch_experts = [
+            [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
+            for b in range(len(ii))
+        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
+        logits = moe.compute_expert_scores([gx, gy], batch_experts)
+        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
+        assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
 
-            for b in range(len(ii)):
-                for e in range(len(ii[b])):
-                    assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
-        finally:
-            dht.shutdown()
+        for b in range(len(ii)):
+            for e in range(len(ii[b])):
+                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+    finally:
+        dht.shutdown()
 
 
 if __name__ == '__main__':