|
@@ -82,7 +82,7 @@ class LoadBalancedExpert(nn.Module):
|
|
|
# * * load, _, expert = queue.heappop_min()
|
|
|
# * * expert_throughput_ema, expert_expiration_time = get ema from dict
|
|
|
# * * task_complexity = batch_size * 1.5 if forward else 2.5 # if backward
|
|
|
- # * * queue.heappush (load + load_coefficient / expert_throughput_ema, new_rng, expert)
|
|
|
+ # * * queue.heappush (load + task_complexity / expert_throughput_ema, new_rng, expert)
|
|
|
# * * try:
|
|
|
# * * * with measure_ema(start=now, batch_size=batch_size) as measured_ema:
|
|
|
# * * * * outputs = call_forward_or_backward()
|