瀏覽代碼

Prefetch on reading rpc_aggregate_part() outputs

Aleksandr Borzunov 4 年之前
父節點
當前提交
87bd95e971
共有 1 個文件被更改,包括 12 次插入和 8 次删除
  1. 12 8
      hivemind/averaging/allreduce.py

+ 12 - 8
hivemind/averaging/allreduce.py

@@ -153,18 +153,22 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            loop = asyncio.get_event_loop()
-            code = None
+            last_code = None
             stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, msg in aenumerate(stream):
-                if code is None:
-                    code = msg.code
-                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+            async for part_index, (averaged_part_delta, part_code) in aenumerate(
+                amap_in_executor(
+                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.code),
+                    stream,
+                    max_prefetch=self.tensor_part_container.prefetch,
+                )
+            ):
+                if last_code is None:
+                    last_code = part_code
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
 
-            if code != averaging_pb2.AVERAGED_PART:
+            if last_code != averaging_pb2.AVERAGED_PART:
                 raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(last_code)} "
                     f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                     f", allreduce failed"
                 )