@@ -57,6 +57,7 @@ class ExpertBackend:
         num_warmup_steps: int = None,
         num_total_steps: int = None,
         clip_grad_norm: float = None,
+        target_batch_size: int = None,
         **kwargs,
     ):
         super().__init__()
@@ -98,6 +99,7 @@ class ExpertBackend:
 
         self.update_count = 0
         self.examples_processed = 0
+        self.target_batch_size = target_batch_size
 
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
@@ -182,21 +184,23 @@ class ExpertBackend:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
+        self.examples_processed += batch_size
+
         if self.clip_grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
 
         if isinstance(self.optimizer, hivemind.CollaborativeOptimizer):
             self.optimizer.step(batch_size)
         else:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if self.target_batch_size is None or self.examples_processed % self.target_batch_size == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
 
-            if self.scheduler is not None:
-                self.scheduler.step()
+                if self.scheduler is not None:
+                    self.scheduler.step()
 
         # TODO update_count is not always incremented if CollaborativeOptimizer is used
         self.update_count += 1
-        self.examples_processed += batch_size
 
     def get_stats(self) -> Dict:
         """
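
A minimal sketch of how the new target_batch_size gate behaves with a plain local optimizer (the collaborative branch is unchanged). The attribute and method names examples_processed, target_batch_size, and apply_gradients come from the patch above; the TinyBackend class and its toy model are illustrative only, not the actual ExpertBackend API. Gradients keep accumulating in .grad across calls, and optimizer.step() fires only when the example counter lands on a multiple of target_batch_size:

    import torch

    # Toy stand-in for the gated update added to apply_gradients above.
    class TinyBackend:
        def __init__(self, target_batch_size=None):
            self.expert = torch.nn.Linear(4, 4)
            self.optimizer = torch.optim.SGD(self.expert.parameters(), lr=0.1)
            self.examples_processed = 0
            self.target_batch_size = target_batch_size

        def apply_gradients(self, batch_size):
            self.examples_processed += batch_size
            # Step only once enough examples have accumulated; otherwise leave
            # the gradients in place so the next backward pass adds to them.
            if self.target_batch_size is None or self.examples_processed % self.target_batch_size == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

    backend = TinyBackend(target_batch_size=64)
    for _ in range(4):
        loss = backend.expert(torch.randn(16, 4)).sum()
        loss.backward()
        backend.apply_gradients(batch_size=16)  # steps only on the 4th call (64 examples)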