Parcourir la source

Prefetch while reading rpc_aggregate_part() outputs (#370)

* Prefetch while reading rpc_aggregate_part() outputs
* Set default prefetch = 5
Alexander Borzunov il y a 4 ans
Parent
commit
a504a0784f
2 fichiers modifiés avec 8 ajouts et 4 suppressions
  1. 7 3
      hivemind/averaging/allreduce.py
  2. 1 1
      hivemind/averaging/partition.py

+ 7 - 3
hivemind/averaging/allreduce.py

@@ -153,13 +153,17 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            loop = asyncio.get_event_loop()
             code = None
             stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, msg in aenumerate(stream):
+            async for part_index, (averaged_part_delta, msg) in aenumerate(
+                amap_in_executor(
+                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
+                    stream,
+                    max_prefetch=self.tensor_part_container.prefetch,
+                )
+            ):
                 if code is None:
                     code = msg.code
-                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
 
             if code != averaging_pb2.AVERAGED_PART:

+ 1 - 1
hivemind/averaging/partition.py

@@ -33,7 +33,7 @@ class TensorPartContainer:
         peer_fractions: Sequence[float],
         compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
-        prefetch: int = 1,
+        prefetch: int = 5,
     ):
         if not isinstance(compression_type, Sequence):
             compression_type = [compression_type] * len(tensors)