소스 검색

Fix TensorRole

Artem Chumachenko 3 년 전
부모
커밋
e7dc9e0da3
1개의 변경된 파일2개의 추가작업 그리고 1개의 파일을 삭제
  1. 2 1
      hivemind/optim/power_sgd_averager.py

+ 2 - 1
hivemind/optim/power_sgd_averager.py

@@ -21,6 +21,7 @@ from hivemind.compression import (
     NoCompression,
     deserialize_torch_tensor,
     serialize_torch_tensor,
+    TensorRole,
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
@@ -217,7 +218,7 @@ class PowerSGDGradientAverager(GradientAverager):
         with torch.no_grad(), self.lock_averaged_tensors:
             grad_averager_buffers = list(q for q in self._qs)
             grad_averager_buffers_infos = [
-                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.PARAMETER)
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
                 for buffer, key in zip(grad_averager_buffers, range(len(grad_averager_buffers)))
             ]