
uint8 uniform grad compression

justheuristic 4 years ago
parent commit fa172efeb3
1 changed file with 14 additions and 3 deletions
  1. hivemind/averaging/training.py

hivemind/averaging/training.py +14 −3

@@ -9,7 +9,7 @@ import torch
 
 from hivemind.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger
-
+from hivemind.proto.runtime_pb2 import CompressionType
 logger = get_logger(__name__)
 
 
@@ -35,7 +35,7 @@ class TrainingAverager(DecentralizedAverager):
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
                  average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
-                 initialize_optimizer: bool = True, **kwargs):
+                 initialize_optimizer: bool = True, compression_type=None, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
@@ -47,7 +47,18 @@ class TrainingAverager(DecentralizedAverager):
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        assert average_parameters and average_gradients and not average_opt_statistics
+        params = averaged_tensors[:len(averaged_tensors) // 2]
+        grads = averaged_tensors[len(averaged_tensors) // 2:]
+        compression_type = [CompressionType.FLOAT16 for p in params]
+        compression_type.extend(
+            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in grads])
+
+        for g in grads:
+            print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')
+
+        super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)
 
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
         """