justheuristic 4 жил өмнө
parent
commit
565ddbe25e

+ 2 - 2
hivemind/averaging/training.py

@@ -50,8 +50,8 @@ class TrainingAverager(DecentralizedAverager):
 
         assert average_parameters and average_gradients and not average_opt_statistics
 
-        compression_type.extend(
-            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in averaged_tensors])
+        compression_type = [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT
+                            for g in averaged_tensors])
 
         for g in averaged_tensors:
             print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')