|
@@ -51,7 +51,7 @@ class TrainingAverager(DecentralizedAverager):
|
|
|
assert average_parameters and average_gradients and not average_opt_statistics
|
|
|
|
|
|
compression_type = [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT
|
|
|
- for g in averaged_tensors])
|
|
|
+ for g in averaged_tensors]
|
|
|
|
|
|
for g in averaged_tensors:
|
|
|
print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')
|