Forráskód Böngészése

uint8 uniform param compression

justheuristic 4 éve
szülő
commit
0b6a8e82a7
1 módosított fájl, 3 hozzáadás és 5 törlés
  1. 3 5
      hivemind/averaging/training.py

+ 3 - 5
hivemind/averaging/training.py

@@ -49,13 +49,11 @@ class TrainingAverager(DecentralizedAverager):
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
 
         assert average_parameters and average_gradients and not average_opt_statistics
-        params = averaged_tensors[:len(averaged_tensors) // 2]
-        grads = averaged_tensors[len(averaged_tensors) // 2:]
-        compression_type = [CompressionType.FLOAT16 for p in params]
+
         compression_type.extend(
-            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in grads])
+            [CompressionType.FLOAT16 if g.numel() <= 2 ** 16 else CompressionType.UNIFORM_8BIT for g in averaged_tensors])
 
-        for g in grads:
+        for g in averaged_tensors:
             print('COMPRESSION', g.shape, '->', 'FLOAT16' if g.numel() <= 2 ** 16 else 'UINT8')
 
         super().__init__(averaged_tensors=averaged_tensors, compression_type=compression_type, **kwargs)