|
@@ -57,14 +57,15 @@ dht = hivemind.DHT(
|
|
client_mode=True, start=True,
|
|
client_mode=True, start=True,
|
|
)
|
|
)
|
|
|
|
|
|
-m, = get_remote_module(dht, ['bloom6b3.0'])
|
|
|
|
|
|
+layer0, layer1 = get_remote_module(dht, ['bloom6b3.0', 'bloom6b3.1'])
|
|
|
|
|
|
-# test forward/backward, one block
|
|
|
|
-outputs = m(torch.randn(1, 128, 4096))
|
|
|
|
|
|
+# test forward/backward, two blocks
|
|
|
|
+outputs, = layer1(*layer0(torch.randn(1, 64, 4096)))
|
|
loss = (outputs * torch.randn_like(outputs)).norm()
|
|
loss = (outputs * torch.randn_like(outputs)).norm()
|
|
loss.backward()
|
|
loss.backward()
|
|
|
|
|
|
-with m.begin_inference_session() as sess:
|
|
|
|
|
|
+# test inference, one block
|
|
|
|
+with layer0.begin_inference_session() as sess:
|
|
for i in range(10):
|
|
for i in range(10):
|
|
res = sess.step(torch.ones(1, 1, 4096))
|
|
res = sess.step(torch.ones(1, 1, 4096))
|
|
```
|
|
```
|