@@ -85,18 +85,19 @@ class TrainingAverager(DecentralizedAverager):
                     if len(averaged_tensors) != len(local_tensors):
                         raise RuntimeError("The number of optimized parameters should not change.")
 
+                    self.update = []
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
                         # losing local updates that might have occurred during averaging
                         for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
                                                                                    old_local_tensors):
-                            self.update = averaged_tensor.to(dtype=local_tensor.dtype,
+                            self.update.append(averaged_tensor.to(dtype=local_tensor.dtype,
                                                              device=local_tensor.device) - \
                                           old_local_tensor.to(dtype=local_tensor.dtype,
-                                                              device=local_tensor.device)
+                                                              device=local_tensor.device))
                     else:
                         for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                            self.update = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                            self.update.append(averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device))
 
                     self.local_step += 1
                     self.averaging_ready_event.set()
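
Note on the hunk above: before this change, `self.update` was reassigned on every iteration, so only the update for the last parameter tensor survived; initializing `self.update = []` and appending keeps one update per optimized parameter, in the same order as `local_tensors`. A minimal standalone sketch of the fixed accumulation pattern, assuming parallel lists of `torch.Tensor` (the helper name `compute_updates` is illustrative, not part of the patch):

import torch

def compute_updates(averaged_tensors, local_tensors, old_local_tensors):
    # Hypothetical helper mirroring the patched loop: accumulate one update per
    # parameter instead of overwriting a single value on each iteration.
    updates = []
    for averaged, local, old_local in zip(averaged_tensors, local_tensors, old_local_tensors):
        averaged = averaged.to(dtype=local.dtype, device=local.device)
        old_local = old_local.to(dtype=local.dtype, device=local.device)
        updates.append(averaged - old_local)  # update preserved for every tensor
    return updates

# usage: three parameters yield three updates (the pre-patch code kept only the last one)
params = [torch.zeros(2), torch.zeros(3), torch.zeros(4)]
averaged = [p + 1.0 for p in params]
assert len(compute_updates(averaged, params, params)) == 3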