
switch to global thread pool

justheuristic 5 years ago
parent commit 5ad014ce43

tesseract/client/moe.py: +8 −5

@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from .expert import RemoteExpert, _RemoteModuleCall
-from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
+from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
 from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
@@ -33,17 +33,20 @@ class RemoteMixtureOfExperts(nn.Module):
     :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
     :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
+    :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
+     allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again;
+     allow_broadcasting=False will raise an error
     """
     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
                  k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
-                 uid_prefix='', expert_padding=None):
+                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
+        self.allow_broadcasting = allow_broadcasting
 
-        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
 
         # grab some expert to set ensemble output shape
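The new allow_broadcasting flag added in the hunk above is only documented here, not exercised; a minimal sketch of the flatten/un-flatten behavior the docstring describes might look like the following (the wrapper function and its signature are illustrative assumptions, not code from this commit):

```python
import torch

def forward_with_broadcasting(moe, input: torch.Tensor) -> torch.Tensor:
    # Hypothetical illustration of allow_broadcasting:
    # collapse all leading dimensions into one batch dimension,
    # run the 2d mixture-of-experts, then restore the original shape.
    if input.ndim > 2:
        if not moe.allow_broadcasting:
            raise ValueError("expected a 2d input; set allow_broadcasting=True to flatten leading dimensions")
        leading_shape = input.shape[:-1]
        flat_output = moe(input.reshape(-1, input.shape[-1]))
        return flat_output.reshape(*leading_shape, flat_output.shape[-1])
    return moe(input)
```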
@@ -121,13 +124,13 @@ class RemoteMixtureOfExperts(nn.Module):
             # select k best candidates according to scores but only those that are still active
             new_order = np.argsort(- new_scores, axis=-1)
             top_alive_lookups = [
-                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
+                run_in_background(self.network.first_k_active, cands[order], k_best, **kwargs)
                 for cands, order in zip(new_candidates, new_order)]
 
             batch_cand_to_score = [
                 dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
 
-            top_alive_prefixes = [result.get() for result in top_alive_lookups]
+            top_alive_prefixes = [result.result() for result in top_alive_lookups]
             top_alive_scores = [list(map(cand_to_score.get, top_cands))
                                 for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]