|
@@ -59,7 +59,7 @@ class LoadBalancedExpert(nn.Module):
|
|
|
# * available_experts: Dict[uid -> (EMA, expiration)] - experts that take part in load balancing
|
|
|
# * maintain blacklist (banned_experts): Dict[uid -> expiration] - experts banned until expiration after failing to respond
|
|
|
# * maintain a min-heap queue of (load, rng, expert) tuples
|
|
|
- # * update_triggered, update_finished: threading.Event
|
|
|
+ # * update_triggered, update_finished: threading.Event; lock: threading.Lock
|
|
|
#
|
|
|
# update experts in background, while True:
|
|
|
# * wait for 30s or for update_triggered, whichever comes first
|
|
@@ -83,14 +83,14 @@ class LoadBalancedExpert(nn.Module):
|
|
|
# * * expert_throughput_ema, expert_expiration_time = get ema from dict
|
|
|
# * * task_complexity = batch_size * (1.5 if forward else 2.5)  # 1.5 for forward, 2.5 for backward
|
|
|
# * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
|
|
|
- # * * try:
|
|
|
- # * * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
|
|
|
- # * * * * outputs = call_forward_or_backward()
|
|
|
- # * * * expert_throughput_ema.update(measured_ema)
|
|
|
- # * * * return outputs # <--------- this is the desired exit point
|
|
|
- # * * except DidNotRespondCorrectly:
|
|
|
- # * * * banned_experts[expert] = expert_expiration_time
|
|
|
- # * * continue # try again
|
|
|
+ # * try:
|
|
|
+ # * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
|
|
|
+ # * * * outputs = call_forward_or_backward()
|
|
|
+ # * * expert_throughput_ema.update(measured_ema)
|
|
|
+ # * * return outputs # <--------- this is the desired exit point
|
|
|
+ # * except DidNotRespondCorrectly:
|
|
|
+ # * * banned_experts[expert] = expert_expiration_time
|
|
|
+ # * continue # try again
|
|
|
|
|
|
def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
|
|
|
"""
|