Michael Diskin, 3 years ago
Commit ed3ff0cfb0

2 changed files with 13 additions and 2 deletions:
  1. hivemind/averaging/partition.py (+2 -2)
  2. hivemind/optim/collaborative.py (+11 -0)

hivemind/averaging/partition.py (+2 -2)

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 16
+DEFAULT_PART_SIZE_BYTES = 2 ** 19
 
 
 class TensorPartContainer:
@@ -35,7 +35,7 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
-        prefetch: int = 5,
+        prefetch: int = 1,
     ):
         if tensor_infos is None:
             tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
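
For context, a back-of-the-envelope sketch of what the new defaults imply. This is not part of the commit; the float32 element size and the reading of prefetch as "parts prepared ahead of the one currently in flight" are assumptions:

# Sketch (not from the commit): effect of the new partition defaults,
# assuming float32 tensors (4 bytes per element).
OLD_PART_SIZE_BYTES = 2 ** 16   # 64 KiB per part
NEW_PART_SIZE_BYTES = 2 ** 19   # 512 KiB per part, i.e. 8x fewer parts

FLOAT32_BYTES = 4
elements_per_part = NEW_PART_SIZE_BYTES // FLOAT32_BYTES   # 131072 elements

# If prefetch=1 means at most one part is prepared ahead of the part
# currently being sent, staged data stays near two parts' worth:
approx_staging_bytes = 2 * NEW_PART_SIZE_BYTES              # ~1 MiB
print(elements_per_part, approx_staging_bytes)

Under that reading, the larger parts cut per-part overhead during averaging, while the smaller prefetch window caps how much compressed data sits in memory at once.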

hivemind/optim/collaborative.py (+11 -0)

@@ -210,6 +210,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"]["collaborative_step"]
+            del state_dict["state"]["collaborative_step"]
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
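
The added methods piggyback the collaboration step counter on the standard optimizer state_dict so it survives checkpointing. Saving reads self.local_step while loading writes self.averager.local_step, which suggests local_step is a read-only view of the averager's counter. Below is a self-contained sketch of the same pattern; the _Toy* classes are illustrative stand-ins, not hivemind code:

# Sketch of the checkpointing pattern this commit adds: an extra counter
# rides along inside the standard optimizer state_dict.
import torch

class _ToyAverager:
    def __init__(self):
        self.local_step = 0

class _ToyCollaborativeOptimizer(torch.optim.SGD):
    def __init__(self, params):
        super().__init__(params, lr=0.1)
        self.averager = _ToyAverager()

    @property
    def local_step(self) -> int:
        # read-only view of the averager's counter, mirroring hivemind
        return self.averager.local_step

    def state_dict(self) -> dict:
        state_dict = super().state_dict()
        state_dict["state"]["collaborative_step"] = self.local_step
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if "collaborative_step" in state_dict["state"]:
            self.averager.local_step = state_dict["state"]["collaborative_step"]
            del state_dict["state"]["collaborative_step"]
        return super().load_state_dict(state_dict)

opt = _ToyCollaborativeOptimizer([torch.nn.Parameter(torch.zeros(1))])
opt.averager.local_step = 42
checkpoint = opt.state_dict()          # carries state["collaborative_step"] == 42

restored = _ToyCollaborativeOptimizer([torch.nn.Parameter(torch.zeros(1))])
restored.load_state_dict(checkpoint)   # pops the counter, then defers to SGD
assert restored.local_step == 42

Like the commit's version, load_state_dict pops the extra key out of the caller's dict before deferring to the base class, so the base optimizer never sees the non-parameter entry.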