Anton Sinitsin 4 years ago
commit 285549443f

+ 4 - 3
hivemind/client/averaging/training.py

@@ -85,18 +85,19 @@ class TrainingAverager(DecentralizedAverager):
                     if len(averaged_tensors) != len(local_tensors):
                         raise RuntimeError("The number of optimized parameters should not change.")
 
+                    self.update = []
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
                         # losing local updates that might have occurred during averaging
                         for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
                                                                                    old_local_tensors):
-                            self.update = averaged_tensor.to(dtype=local_tensor.dtype,
+                            self.update.append(averaged_tensor.to(dtype=local_tensor.dtype,
                                                                     device=local_tensor.device) - \
                                                  old_local_tensor.to(dtype=local_tensor.dtype,
-                                                                     device=local_tensor.device)
+                                                                     device=local_tensor.device))
                     else:
                         for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                             self.update = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                            self.update.append(averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device))
 
             self.local_step += 1
             self.averaging_ready_event.set()
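
For context, a minimal standalone sketch of the behavior this commit fixes (the function name and test tensors below are illustrative, not part of hivemind's API): the old code assigned to self.update inside the loop, so only the update for the last parameter survived; the fix collects one update per parameter in a list.

import torch

def compute_updates(averaged_tensors, local_tensors, old_local_tensors=None):
    # Fixed logic: accumulate one update per parameter instead of
    # overwriting a single attribute on every loop iteration.
    updates = []
    if old_local_tensors is not None:
        # Subtract the pre-averaging value so that local progress made
        # during averaging is preserved: new_param = local + (averaged - old_local)
        for averaged, local, old_local in zip(averaged_tensors, local_tensors, old_local_tensors):
            updates.append(averaged.to(dtype=local.dtype, device=local.device)
                           - old_local.to(dtype=local.dtype, device=local.device))
    else:
        for averaged, local in zip(averaged_tensors, local_tensors):
            updates.append(averaged.to(dtype=local.dtype, device=local.device))
    return updates

# With two parameters, the fixed version yields two updates; the pre-fix
# code would have kept only the update for the last tensor.
params = [torch.ones(2), torch.zeros(3)]
averaged = [torch.full((2,), 0.5), torch.full((3,), 0.25)]
updates = compute_updates(averaged, params)
assert len(updates) == len(params)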