@@ -108,14 +108,17 @@ class TPUManager(mp.Process):
             xm.rendezvous('before_step')
             if xm.is_master_ordinal():
                 self.step_triggered.clear()
-
+            print("BEGAN STEP")
             if bool(self.should_load_parameters.value):
+                print("LOADING FROM HOST")
                 with self.lock if xm.is_master_ordinal() else nullcontext():
                     self._synchronizer.send_params_to_device(model)
                     self.should_load_parameters.value = False
+                print("LOADED FROM HOST")
 
             ### compute loss and gradients
             loss = 0.0
+            print("BEFORE LOSS")
             for i in range(self.grad_accumulation_steps):
                 inputs = next(data_loader_iter)
                 outputs = model(**inputs)
@@ -126,12 +129,15 @@ class TPUManager(mp.Process):
                 del inputs, outputs, loss_i
 
             ### aggregate gradients from TPUs
+            print("aggregate gradients from TPUs")
             with self.lock if xm.is_master_ordinal() else nullcontext():
                 self._synchronizer.aggregate_grads_on_host(model, add=True)
+            print("zero gradients")
             # clear aggregated gradients from all devices
             model.zero_grad()
 
             ### accumulate statistics to host
+            print("stats to host")
             loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
             xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))
 
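A hedged aside on the debug output itself: with several TPU processes running this loop, bare print calls interleave in the console with no indication of which ordinal emitted them, and unflushed output can show up well after the event it describes. A minimal sketch of an ordinal-tagged alternative, assuming only the standard torch_xla xm API (xm.get_ordinal, xm.master_print); the debug helper name is hypothetical:

    import torch_xla.core.xla_model as xm

    def debug(msg):
        # Tag each message with the emitting TPU ordinal and flush
        # immediately so interleaved output stays attributable.
        print(f"[ordinal {xm.get_ordinal()}] {msg}", flush=True)

    # e.g. debug("BEGAN STEP") in place of print("BEGAN STEP"),
    # or xm.master_print("BEGAN STEP", flush=True) to emit the
    # message from the master process only.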