|
@@ -21,6 +21,7 @@ from hivemind.compression import (
|
|
|
NoCompression,
|
|
|
deserialize_torch_tensor,
|
|
|
serialize_torch_tensor,
|
|
|
+ TensorRole,
|
|
|
)
|
|
|
from hivemind.dht import DHT, DHTID
|
|
|
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
|
|
@@ -217,7 +218,7 @@ class PowerSGDGradientAverager(GradientAverager):
|
|
|
with torch.no_grad(), self.lock_averaged_tensors:
|
|
|
grad_averager_buffers = list(q for q in self._qs)
|
|
|
grad_averager_buffers_infos = [
|
|
|
- CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.PARAMETER)
|
|
|
+ CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
|
|
|
for buffer, key in zip(grad_averager_buffers, range(len(grad_averager_buffers)))
|
|
|
]
|
|
|
|