ソースを参照

instructions to test distributed inference

justheuristic 3 年 前
コミット
2d55e6e4fe
1 ファイル変更5 行追加4 行削除
  1. 5 4
      README.md

+ 5 - 4
README.md

@@ -57,14 +57,15 @@ dht = hivemind.DHT(
     client_mode=True, start=True,
 )
 
-m, = get_remote_module(dht, ['bloom6b3.0'])
+layer0, layer1 = get_remote_module(dht, ['bloom6b3.0', 'bloom6b3.1'])
 
-# test forward/backward, one block
-outputs = m(torch.randn(1, 128, 4096))
+# test forward/backward, two blocks
+outputs, = layer1(*layer0(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
-with m.begin_inference_session() as sess:
+# test inference, one block
+with layer0.begin_inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```