@@ -49,13 +49,11 @@ class TrainingAverager(DecentralizedAverager):
         averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]

         assert average_parameters and average_gradients and not average_opt_statistics
-        params = averaged_tensors[:len(averaged_tensors) // 2]
-        grads = averaged_tensors[len(averaged_tensors) // 2:]
-        compression_type = [CompressionType.FLOAT16 for p in params]
+        compression_type = []  # start empty; one compression choice per tensor is appended below
         compression_type.extend(
-            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in grads])
+            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in averaged_tensors])

-        for g in grads:
+        for g in averaged_tensors:
             print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')

         super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)
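
For reference, the rule the updated comprehension applies is purely size-based: every averaged tensor with at most 2 ** 16 elements is kept in FLOAT16, while anything larger falls back to 8-bit uniform quantization. Below is a minimal, self-contained sketch of that rule; it uses plain torch tensors and string stand-ins for hivemind's CompressionType values, so FLOAT16/UNIFORM_8BIT here are illustrative labels rather than the real enum members.

import torch

# Stand-in labels; the actual code uses hivemind's CompressionType enum members.
FLOAT16, UNIFORM_8BIT = "FLOAT16", "UNIFORM_8BIT"

def choose_compression(tensors):
    # Same size-based rule as in the diff: small tensors stay FLOAT16,
    # anything above 2 ** 16 elements is quantized to 8 bits.
    return [FLOAT16 if t.numel() <= 2 ** 16 else UNIFORM_8BIT for t in tensors]

# A small bias vector stays FLOAT16; a large weight matrix becomes UNIFORM_8BIT.
tensors = [torch.zeros(1024), torch.zeros(50_000, 768)]
print(choose_compression(tensors))  # ['FLOAT16', 'UNIFORM_8BIT']

The threshold presumably trades accuracy for bandwidth: small tensors such as biases and norms are cheap to send at 16-bit precision, while the bulky weight matrices dominate traffic and tolerate 8-bit quantization better.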