@@ -4,6 +4,7 @@ import faulthandler
 import math
 import torch
 import multiprocessing as mp
+import numpy as np
 
 from typing import Any, Iterable, Optional, Sequence
 
@@ -56,20 +57,26 @@ class PowerEFGradientAverager(GradientAverager):
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        grad_extra_tensors: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         self.rank = averager_rank
         self.parameters = tuple(parameters)
-        self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
-        self._gs = list(
-            torch.zeros_like(grad, device="cpu")
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
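+        # keep a gradient uncompressed if it is one-dimensional or if its rank-r factors
+        # (r * (n + m) values) would save less than a min_compression_ratio fraction of its full n x m size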
+        self._uncompressed_gradients = set(
+            i for i, grad in enumerate(self._grads_from_parameters())
+            if len(tuple(grad.size())) == 1 or
+            (self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_compression_ratio)
         )
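+        # error-feedback residuals: the part of each local gradient not yet reflected in the averaged state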
+        self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
         self._qs = list(
             torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
-        for tensor in (self._qs + self._gs):
+        for tensor in (self._qs + self._gradient_rests):
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
@@ -82,6 +89,7 @@ class PowerEFGradientAverager(GradientAverager):
             accumulate_grads_on=accumulate_grads_on,
             client_mode=client_mode,
             warn=warn,
+            grad_extra_tensors=grad_extra_tensors,
             **kwargs
         )
 
@@ -104,16 +112,11 @@ class PowerEFGradientAverager(GradientAverager):
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
-        async def _dump_later():
-            await asyncio.sleep(15.0)
-            print([*map(asyncio.Task.print_stack, asyncio.Task.all_tasks())])
-        # task = asyncio.create_task(_dump_later())
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
-            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
             download_bandwidths = [
                 thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
             ]
@@ -123,14 +126,14 @@ class PowerEFGradientAverager(GradientAverager):
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 compressed_tensors = [lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
-                cs = [torch.zeros_like(grad, device="cpu") for grad in compressed_tensors]
-                for c, g, cg in zip(cs, self._gs, compressed_tensors):
-                    torch.sub(cg, g, out=c)
 
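+                # the persistent residuals double as the matrices to be factorized this round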
+                cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
                 ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
-                for p, q, c in zip(ps, self._qs, cs):
-                    torch.matmul(c.reshape(-1, q.size(0)), q, out=p)
-                first_all_reduced = ps + [lt for idx, lt in enumerate(local_tensors) if idx in self._uncompressed_gradients]
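+                # project each residual onto its Q factor to obtain P (one power-iteration step)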
+                for p, q, rest in zip(ps, self._qs, cs):
+                    torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
+                first_all_reduced = ps + [rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients]
                 allreduce1 = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -152,7 +155,7 @@ class PowerEFGradientAverager(GradientAverager):
                 else:
                     async for _ in allreduce1:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
-
+
                 # orth ps
                 for p in ps:
                     orthogonalize(p)
@@ -190,8 +193,9 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-                for c, g in zip(cs, self._gs):
-                    torch.add(g, c, out=g)
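+                # apply the averaged updates held in the residuals to the local tensors in place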
+                for rest, lt in zip(self._gradient_rests, local_tensors):
+                    torch.add(lt, rest, out=lt)
 
                 return allreduce1.gathered
         except BaseException as e:
@@ -200,27 +204,15 @@ class PowerEFGradientAverager(GradientAverager):
         finally:
             pass
 
-    @contextlib.contextmanager
     @torch.no_grad()
-    def use_averaged_gradients(self):
-        self._new_averaged_grads = False
+    def load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
-            compressed_tensors = [lt for idx, lt in enumerate(averaged_grads) if idx not in self._uncompressed_gradients]
-            old_averaged = [torch.zeros_like(lt) for lt in compressed_tensors]
-            for g, cg, oag in zip(self._gs, compressed_tensors, old_averaged):
-                oag.copy_(cg)
-                cg.copy_(g)
-            try:
-                assert len(averaged_grads) == len(self.parameters)
-                old_grads = [param.grad for param in self.parameters]
-                for param, new_grad in zip(self.parameters, averaged_grads):
-                    param.grad.copy_(new_grad)
-                yield
-            finally:
-                for param, old_grad in zip(self.parameters, old_grads):
-                    param.grad.copy_(old_grad)
-                for cg, oag in zip(compressed_tensors, old_averaged):
-                    cg.copy_(oag)
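+            # the new residual is the freshly accumulated local gradient minus the current averaged state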
+            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_rests):
+                torch.sub(grad_acc * grad_scale, averaged_grad, out=rest)
 
 
 @torch.jit.script