|
@@ -26,9 +26,6 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
|
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
|
assert isinstance(remote_block, RemoteTransformerBlock)
|
|
|
|
|
|
- _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info
|
|
|
- remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
|
|
|
-
|
|
|
ref_blocks = [
|
|
|
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
|
|
|
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
|
|
@@ -59,9 +56,6 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
|
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
|
assert isinstance(remote_block, RemoteTransformerBlock)
|
|
|
|
|
|
- _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info
|
|
|
- remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
|
|
|
-
|
|
|
inputs = torch.randn(1, 8, config.hidden_size)
|
|
|
|
|
|
outputs_inference = []
|