@@ -1,6 +1,7 @@
import asyncio
import contextlib
import multiprocessing as mp
+from enum import Enum
from typing import Any, Iterable, Optional, Sequence

import torch
@@ -21,6 +22,11 @@ GatheredData = Any
logger = get_logger(__name__)


+class AllReducePhases(Enum):
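+    """Two consecutive all-reduce rounds used by PowerSGD: P matrices are averaged first, then Q matrices together with the uncompressed gradients."""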
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
class PowerSGDGradientAverager(GradientAverager):
    """
    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
@@ -97,8 +103,6 @@ class PowerSGDGradientAverager(GradientAverager):
            if idx not in self._uncompressed_gradients_indexes
        ]

-        self.all_reduce_phases = (b".phase_p", b".phase_q")
-
        super().__init__(
            self.parameters,
            dht=dht,
@@ -107,23 +111,23 @@ class PowerSGDGradientAverager(GradientAverager):
            accumulate_grads_on=accumulate_grads_on,
            client_mode=client_mode,
            warn=warn,
-            averaged_grads=None,
+            averaged_grads=averaged_grads,
            **kwargs,
        )

    @contextlib.contextmanager
    def _register_allreduce_group(self, group_info: GroupInfo):
-        """registers a given all-reduce runner to listen for incoming connections"""
+        """Register a given group for one or more all-reduce rounds"""
        try:
-            for phase in self.all_reduce_phases:
-                self._running_groups[group_info.group_id + phase] = asyncio.Future()
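+            # one Future per phase; each is resolved with that phase's AllReduceRunner in _run_allreduce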
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
            self._pending_groups_registered.set()
            yield
        finally:
-            for phase in self.all_reduce_phases:
-                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
                if maybe_future and not maybe_future.done():
-                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
            self._pending_groups_registered.set()

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
@@ -149,17 +153,17 @@ class PowerSGDGradientAverager(GradientAverager):

                ps = [
                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
-                    for idx, grad in enumerate(averaged_grad_via_sgd)
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
                ]
                for p, q, m in zip(ps, self._qs, self._ms):
                    # we use reshape for all matrixes because PowerSGD works only with 2d tensors
                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)

-                allreduce_p_phase = AllReduceRunner(
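+                # Phase P: average the P matrices (each p is m @ q) across peers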
+                allreduce_phase_p = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
-                    group_id=group_info.group_id + self.all_reduce_phases[0],
+                    group_id=group_info.group_id + AllReducePhases.PHASE_P.name.encode(),
                    tensors=ps,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
@@ -167,14 +171,14 @@ class PowerSGDGradientAverager(GradientAverager):
                    modes=modes,
                    **kwargs,
                )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce_p_phase)
+                self._running_groups[group_info.group_id + AllReducePhases.PHASE_P.name.encode()].set_result(allreduce_phase_p)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce_p_phase):
+                    async for tensor, update in azip(as_aiter(*ps), allreduce_phase_p):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                else:
-                    async for _ in allreduce_p_phase:  # trigger all-reduce by iterating
+                    async for _ in allreduce_phase_p:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                for p in ps:
@@ -187,11 +191,11 @@ class PowerSGDGradientAverager(GradientAverager):
                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
                ]

-                allreduce_q_phase = AllReduceRunner(
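+                # Phase Q: average the Q matrices together with the gradients that bypass compression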
+                allreduce_phase_q = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
-                    group_id=group_info.group_id + self.all_reduce_phases[1],
+                    group_id=group_info.group_id + AllReducePhases.PHASE_Q.name.encode(),
                    tensors=self._qs + averaged_grad_wo_sgd,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
@@ -199,28 +203,31 @@ class PowerSGDGradientAverager(GradientAverager):
                    modes=modes,
                    **kwargs,
                )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce_q_phase)
+                self._running_groups[group_info.group_id + AllReducePhases.PHASE_Q.name.encode()].set_result(allreduce_phase_q)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_q_phase):
+                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_phase_q):
                        tensor.add_(update, alpha=self._averaging_alpha)
                        self.last_updated = get_dht_time()
                        self._state_updated.set()
                else:
-                    async for _ in allreduce_q_phase:
+                    async for _ in allreduce_phase_q:
                        raise ValueError("aux peers should not receive averaged tensors")

-                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grad_via_sgd):
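+                # error feedback: keep the residual m - p @ q^T in self._ms and write the low-rank average into the gradient buffer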
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
                    new_m = torch.matmul(p, q.t()).reshape(m.size())
                    m.sub_(new_m)
                    grad.copy_(new_m)

-                return allreduce1.gathered
+                return allreduce_phase_p.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")

    def get_current_state(self):
+        """
+        Get current gradient averager state to be sent to a newbie peer when requested.
+        """
        with torch.no_grad(), self.lock_averaged_tensors:
            grad_averager_buffers = [q for q in self._qs]
            grad_averager_buffers_infos = [
@@ -232,6 +239,10 @@ class PowerSGDGradientAverager(GradientAverager):
        return metadata, grad_averager_buffers, grad_averager_buffers_infos

    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
        loaded_state = super().load_state_from_peers(**kwargs)
        if loaded_state is None:
            return