
switch to global thread pool

justheuristic, 5 years ago
commit 5ad014ce43

1 changed file with 8 additions and 5 deletions:
    tesseract/client/moe.py

tesseract/client/moe.py (+8, -5)

@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from .expert import RemoteExpert, _RemoteModuleCall
-from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
+from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
 from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
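The new run_in_background import replaces the per-module thread pool created in __init__ below, in line with the commit message. As a rough mental model of such a helper (a minimal sketch assuming one process-wide concurrent.futures pool; the pool size and names here are assumptions, and the real run_in_background in tesseract/utils may differ):

    from concurrent.futures import Future, ThreadPoolExecutor

    # one pool shared by the whole process instead of one pool per module (assumed design)
    GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=128)  # pool size is an assumption

    def run_in_background(func, *args, **kwargs) -> Future:
        """Schedule func(*args, **kwargs) on the shared pool and return a Future."""
        return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)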
@@ -33,17 +33,20 @@ class RemoteMixtureOfExperts(nn.Module):
     :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
     :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
+    :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
+     allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again;
+     allow_broadcasting=False will raise an error
     """
     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
                  k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
-                 uid_prefix='', expert_padding=None):
+                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
+        self.allow_broadcasting = allow_broadcasting
 
-        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
 
         # grab some expert to set ensemble output shape
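The allow_broadcasting behavior described in the docstring amounts to the standard flatten-apply-unflatten trick. A hedged sketch of the idea (apply_with_broadcasting is a hypothetical helper for illustration, not code from this commit):

    import torch

    def apply_with_broadcasting(module, input: torch.Tensor) -> torch.Tensor:
        """Merge the first d-1 dimensions, apply module, then restore the original shape."""
        if input.ndim <= 2:
            return module(input)
        flat_output = module(input.reshape(-1, input.shape[-1]))  # flatten first d-1 dims
        return flat_output.reshape(*input.shape[:-1], flat_output.shape[-1])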
@@ -121,13 +124,13 @@ class RemoteMixtureOfExperts(nn.Module):
             # select k best candidates according to scores but only those that are still active
             new_order = np.argsort(- new_scores, axis=-1)
             top_alive_lookups = [
-                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
+                run_in_background(self.network.first_k_active, cands[order], k_best, **kwargs)
                 for cands, order in zip(new_candidates, new_order)]
 
             batch_cand_to_score = [
                 dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
 
-            top_alive_prefixes = [result.get() for result in top_alive_lookups]
+            top_alive_prefixes = [result.result() for result in top_alive_lookups]
             top_alive_scores = [list(map(cand_to_score.get, top_cands))
                                 for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
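The call-site change also swaps result APIs: multiprocessing's AsyncResult.get() becomes .result(), consistent with run_in_background returning a concurrent.futures-style Future. A minimal side-by-side illustration of the two standard-library APIs (lookup is a stand-in for network.first_k_active):

    from concurrent.futures import ThreadPoolExecutor
    from multiprocessing.pool import ThreadPool

    def lookup(candidates, k):
        return candidates[:k]  # stand-in for network.first_k_active

    # before: per-module ThreadPool returning an AsyncResult, awaited with .get()
    with ThreadPool(4) as pool:
        print(pool.apply_async(lookup, args=(['a', 'b', 'c'], 2)).get())

    # after: shared executor returning a Future, awaited with .result()
    with ThreadPoolExecutor(4) as executor:
        print(executor.submit(lookup, ['a', 'b', 'c'], 2).result())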