justheuristic 3 anos atrás
pai
commit
53cf4b6b8b
1 arquivos alterados com 9 adições e 4 exclusões
  1. 9 4
      tests/test_optimizer.py

+ 9 - 4
tests/test_optimizer.py

@@ -168,12 +168,17 @@ def test_load_state_from_peers(dpu: bool):
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, extra_tensors=extras1,
-        **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        extra_tensors=extras1,
+        **common_kwargs,
     )
 
-    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True,  extra_tensors=extras2,
-                                  **common_kwargs)
+    avgr2 = TrainingStateAverager(
+        dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2, **common_kwargs
+    )
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42