justheuristic 4 years ago
parent
commit
416177f4ae
1 changed file with 14 additions and 1 deletion

+ 14 - 1
hivemind/client/averaging/training.py

@@ -10,6 +10,8 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
+from hivemind.proto.runtime_pb2 import CompressionType
+
 
 
 class TrainingAverager(DecentralizedAverager):
@@ -45,7 +47,18 @@ class TrainingAverager(DecentralizedAverager):
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        assert average_parameters and average_gradients and not average_opt_statistics
+        params = averaged_tensors[:len(averaged_tensors) // 2]
+        grads = averaged_tensors[len(averaged_tensors) // 2:]
+        compression_type = [CompressionType.FLOAT16 for p in params]
+        compression_type.extend(
+            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in grads])
+
+        for g in grads:
+            print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')
+
+        super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)
 
     @torch.no_grad()
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
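
For context, below is a self-contained sketch of the compression schedule this diff hard-codes into TrainingAverager.__init__. The helper name choose_compression and the constant SMALL_TENSOR_NUMEL are illustrative, not part of hivemind; the CompressionType values come from hivemind.proto.runtime_pb2 exactly as imported in the diff, and the equal-halves split assumes the [parameters..., gradients...] tensor layout that the assert in the diff pins down.

    # A minimal sketch of the per-tensor compression choice made in this commit.
    # Assumption: averaged_tensors is laid out as [parameters..., gradients...]
    # in equal halves (enforced by the assert in the diff above).
    import torch
    from hivemind.proto.runtime_pb2 import CompressionType

    SMALL_TENSOR_NUMEL = 2 ** 16  # threshold used in the commit: 65536 elements

    def choose_compression(averaged_tensors):
        """Return one CompressionType per tensor: FLOAT16 for all parameters
        and small gradients, UNIFORM_8BIT for large gradient tensors."""
        half = len(averaged_tensors) // 2
        params, grads = averaged_tensors[:half], averaged_tensors[half:]
        schedule = [CompressionType.FLOAT16 for _ in params]
        schedule += [CompressionType.FLOAT16 if g.numel() <= SMALL_TENSOR_NUMEL
                     else CompressionType.UNIFORM_8BIT for g in grads]
        return schedule

    # Toy usage: two "parameters" followed by their two "gradients".
    tensors = [torch.zeros(10, 10), torch.zeros(1000, 1000),
               torch.zeros(10, 10), torch.zeros(1000, 1000)]
    print([CompressionType.Name(c) for c in choose_compression(tensors)])
    # -> ['FLOAT16', 'FLOAT16', 'FLOAT16', 'UNIFORM_8BIT']

The design rationale implied by the diff: large gradient tensors dominate communication volume, so quantizing them to 8 bits roughly halves traffic relative to float16, while parameters and small tensors keep float16, where the extra precision is cheap.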