justheuristic 4 年之前
父節點
當前提交
e067f02cb9
共有 1 個文件被更改,包括 9 次插入9 次删除
  1. 9 9
      hivemind/moe/client/balanced_expert.py

+ 9 - 9
hivemind/moe/client/balanced_expert.py

@@ -59,7 +59,7 @@ class LoadBalancedExpert(nn.Module):
             # * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
             # * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
             # * maintain blacklist: Dict[uid -> expiration] - experts banned until expiration for a non-response
             # * maintain blacklist: Dict[uid -> expiration] - experts banned until expiration for a non-response
             # * maintain a min-heap queue of (load, rng, expert) tuples
             # * maintain a min-heap queue of (load, rng, expert) tuples
-            # * update_triggered, update_finished: threading.Event
+            # * update_triggered, update_finished: threading.Event; lock: threading.Lock
             #
             #
             # update experts in background, while True:
             # update experts in background, while True:
             # * wait for 30s or for update_triggered, whichever comes first
             # * wait for 30s or for update_triggered, whichever comes first
@@ -83,14 +83,14 @@ class LoadBalancedExpert(nn.Module):
             # * * expert_throughput_ema, expert_expiration_time = get ema from dict
             # * * expert_throughput_ema, expert_expiration_time = get ema from dict
             # * * task_complexity = batch_size * 1.5 if forward else 2.5 # if backward
             # * * task_complexity = batch_size * 1.5 if forward else 2.5 # if backward
             # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
             # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
-            # * * try:
-            # * * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
-            # * * * * outputs = call_forward_or_backward()
-            # * * * expert_throughput_ema.update(measured_ema)
-            # * * * return outputs      # <--------- this is the desired exit point
-            # * * except DidNotRespondCorrectly:
-            # * * * banned_experts[expert] = expert_expiration_time
-            # * * continue # try again
+            # * try:
+            # * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
+            # * * * outputs = call_forward_or_backward()
+            # * * expert_throughput_ema.update(measured_ema)
+            # * * return outputs      # <--------- this is the desired exit point
+            # * except DidNotRespondCorrectly:
+            # * * banned_experts[expert] = expert_expiration_time
+            # * continue # try again
 
 
     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
         """