
Extra debug prints

Max Ryabinin 3 years ago
parent
commit
01273b8241
2 changed files with 3 additions and 1 deletion
  1. hivemind/moe/client/balancer.py (+2 −0)
  2. hivemind/moe/server/task_pool.py (+1 −1)

+ 2 - 0
hivemind/moe/client/balancer.py

@@ -110,6 +110,8 @@ class ExpertBalancer:
                     self.uid_to_queue.pop(uid, None)
                     self.throughputs.pop(uid, None)
                 if self.uid_to_queue.get(uid) != heap_entry:
+                    logger.debug(f"Skipping expert {uid} "
+                                 f"(uid_to_queue={self.uid_to_queue.get(uid)}, entry={heap_entry})")
                     continue  # skip uids that are banned or expired
 
                 if self.throughputs[uid].num_updates != 0:
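
For context, the `continue` guarded by the new debug print implements a lazy-deletion heap: banned or expired experts are not removed from the queue eagerly; instead, `uid_to_queue` records the single entry currently considered valid for each uid, and any popped entry that no longer matches it is skipped. A minimal standalone sketch of that pattern (not hivemind's actual code; the names and heap keys are illustrative):

import heapq

heap, uid_to_queue = [], {}

def push(uid, throughput):
    entry = (throughput, uid)
    uid_to_queue[uid] = entry          # newest entry wins; older copies go stale
    heapq.heappush(heap, entry)

def pop_best():
    while heap:
        entry = heapq.heappop(heap)
        _, uid = entry
        if uid_to_queue.get(uid) != entry:
            continue                   # stale entry: banned, expired, or re-pushed
        return uid
    return None

push("expert_a", 0.5)
push("expert_b", 0.2)
push("expert_a", 0.1)                  # the first expert_a entry is now stale
assert pop_best() == "expert_a"        # the fresh (0.1) entry wins
assert pop_best() == "expert_b"
assert pop_best() is None              # the stale (0.5) entry was skipped

The added log line makes exactly this skip observable: when a popped `heap_entry` disagrees with `uid_to_queue`, the message records both values so stale-entry churn can be diagnosed.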

+ 1 - 1
hivemind/moe/server/task_pool.py

@@ -179,7 +179,7 @@ class TaskPool(TaskPoolBase):
             # save batch futures, _output_loop will deliver on them later
             pending_batches[batch_index] = batch_tasks
 
-            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
+            logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
             batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
             batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
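
For context, the lines surrounding the fixed log message aggregate per-task tensors into one batch and move it to shared memory so a worker process can consume it without copying. A minimal runnable sketch of that step, with `Task` as a hypothetical stand-in for hivemind's task objects:

import torch
from collections import namedtuple

# Hypothetical stand-in for hivemind's task objects; only .args matters here.
Task = namedtuple("Task", ["args"])

batch_tasks = [Task(args=(torch.randn(1, 4),)) for _ in range(3)]

# Concatenate the i-th argument of every task along the batch dimension.
batch_inputs = [
    torch.cat([task.args[i] for task in batch_tasks])
    for i in range(len(batch_tasks[0].args))
]
# Detach from any autograd graph (preserving requires_grad) and place the
# tensors in shared memory so another process can read them in place.
batch_inputs = [
    inp.detach().requires_grad_(inp.requires_grad).share_memory_()
    for inp in batch_inputs
]
print(batch_inputs[0].shape)        # torch.Size([3, 4])
print(batch_inputs[0].is_shared())  # True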