@@ -32,7 +32,7 @@ def main():
 @torch.inference_mode()
 def benchmark_forward(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForCausalLM.from_pretrained(args.model)#, initial_peers=args.initial_peers)
+    model = DistributedBloomForCausalLM.from_pretrained(args.model, initial_peers=args.initial_peers)
     logger.info(f"Created model: {process_idx=} {model.device=}")

     torch.manual_seed(42)
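
Note: the changed call assumes `args.initial_peers` is collected by the script's argument parser. A minimal sketch of how that flag might be defined (the flag name, defaults, and help text are assumptions for illustration, not part of this diff):

```python
import argparse

# Hypothetical argument parsing for the benchmark script; only args.model and
# args.initial_peers are referenced in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Model name passed to from_pretrained")
parser.add_argument(
    "--initial_peers",
    type=str,
    nargs="*",
    default=[],
    help="Multiaddresses of bootstrap peers for the distributed swarm (assumed flag)",
)
args = parser.parse_args()
# args.initial_peers is then forwarded to DistributedBloomForCausalLM.from_pretrained(...)
```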