|
@@ -9,7 +9,7 @@ import torch
|
|
|
|
|
|
from hivemind.averaging import DecentralizedAverager
|
|
|
from hivemind.utils import nested_flatten, nested_pack, get_logger
|
|
|
-
|
|
|
+from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
@@ -35,7 +35,7 @@ class TrainingAverager(DecentralizedAverager):
|
|
|
|
|
|
def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
|
|
|
average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
|
|
|
- initialize_optimizer: bool = True, **kwargs):
|
|
|
+ initialize_optimizer: bool = True, compression_type=None, **kwargs):
|
|
|
|
|
|
self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
|
|
|
self.opt_statistics = tuple(average_opt_statistics)
|
|
@@ -47,7 +47,18 @@ class TrainingAverager(DecentralizedAverager):
|
|
|
|
|
|
with torch.no_grad():
|
|
|
averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
|
|
|
- super().__init__(averaged_tensors=averaged_tensors, **kwargs)
|
|
|
+
|
|
|
+ assert average_parameters and average_gradients and not average_opt_statistics
|
|
|
+ params = averaged_tensors[:len(averaged_tensors) // 2]
|
|
|
+ grads = averaged_tensors[len(averaged_tensors) // 2:]
|
|
|
+ compression_type = [CompressionType.FLOAT16 for p in params]
|
|
|
+ compression_type.extend(
|
|
|
+ [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in grads])
|
|
|
+
|
|
|
+ for g in grads:
|
|
|
+ print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')
|
|
|
+
|
|
|
+ super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)
|
|
|
|
|
|
def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
|
|
|
"""
|