|
@@ -23,6 +23,7 @@ if __name__ == "__main__":
|
|
|
parser.add_argument('--port', type=int, default=None, required=False)
|
|
|
parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
|
|
|
parser.add_argument('--no_network', action='store_true')
|
|
|
+ parser.add_argument('--no_optimizer', action='store_true')
|
|
|
parser.add_argument('--initial_peers', type=str, default="[]", required=False)
|
|
|
parser.add_argument('--network_port', type=int, default=None, required=False)
|
|
|
parser.add_argument('--lifetime_seconds', type=int, default=None, required=False)
|
|
@@ -55,9 +56,9 @@ if __name__ == "__main__":
|
|
|
experts = {}
|
|
|
for i in range(args.num_experts):
|
|
|
expert = torch.jit.script(layers.name_to_block[args.expert_cls](args.hidden_dim))
|
|
|
+ opt = torch.optim.SGD(expert.parameters(), 0.0) if args.no_optimizer else torch.optim.Adam(expert.parameters())
|
|
|
expert_uid = f'{args.expert_prefix}.{i + args.expert_offset}'
|
|
|
- experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid,
|
|
|
- expert=expert, opt=torch.optim.Adam(expert.parameters()),
|
|
|
+ experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
|
|
|
args_schema=(tesseract.BatchTensorProto(args.hidden_dim),),
|
|
|
outputs_schema=tesseract.BatchTensorProto(args.hidden_dim),
|
|
|
max_batch_size=args.max_batch_size,
|