justheuristic 3 年 前
コミット
4e4769db70
2 ファイル変更、8 行追加、2 行削除
  1. +7 -1
      lib/training/tpu.py
  2. +1 -1
      task.py

+ 7 - 1
lib/training/tpu.py

@@ -108,14 +108,17 @@ class TPUManager(mp.Process):
             xm.rendezvous('before_step')
             if xm.is_master_ordinal():
                 self.step_triggered.clear()
-
+            print("BEGAN STEP")
             if bool(self.should_load_parameters.value):
+                print("LOADING FROM HOST")
                 with self.lock if xm.is_master_ordinal() else nullcontext():
                     self._synchronizer.send_params_to_device(model)
                     self.should_load_parameters.value = False
+                    print("LOADED FROM HOST")
 
             ### compute loss and gradients
             loss = 0.0
+            print("BEFORE LOSS")
             for i in range(self.grad_accumulation_steps):
                 inputs = next(data_loader_iter)
                 outputs = model(**inputs)
@@ -126,12 +129,15 @@ class TPUManager(mp.Process):
                 del inputs, outputs, loss_i
 
             ### aggregate gradients from TPUs
+            print("aggregate gradients from TPUs")
             with self.lock if xm.is_master_ordinal() else nullcontext():
                 self._synchronizer.aggregate_grads_on_host(model, add=True)
+            print("zero gradients")
             # clear aggregated gradients from all devices
             model.zero_grad()
 
             ### accumulate statistics to host
+            print("stats to host")
             loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
             xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))
 

+ 1 - 1
task.py

@@ -67,7 +67,7 @@ class TrainingTask:
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
 
-            depth = 16#TODO
+            depth = 64
             attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
             attn_types.append('conv_like')
             shared_layer_ids = list(islice(cycle(range(4)), depth - 1))