"""Create or delete the Azure VM scale sets that run dalle-hivemind training workers."""
import os
import shlex
from argparse import ArgumentParser
from base64 import b64encode

from azure.identity import DefaultAzureCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
# Heads-up banner emitted as soon as the module is loaded.
print(
    "=======================WARNING=======================",
    "= The code may fail to import 'gi' but that is okay =",
    "===================END OF WARNING====================",
    sep="\n",
)

# Account identifier and VM admin password come from the environment;
# a missing variable raises KeyError at import time.
SUBSCRIPTION_ID = os.environ["SUBSCRIPTION_ID"]
ADMIN_PASS = os.environ["AZURE_PASS"]

# Where the resources live and what they are called.
LOCATION = "westus2"
GROUP_NAME = "dalle_west2"
NETWORK_NAME = "vnet"
SUBNET_NAME = "subnet"

# One scale set is managed per name; SWARM_SIZE is used as each set's capacity.
SCALE_SETS = ("worker",)
SWARM_SIZE = 4
# cloud-config template handed to every worker VM as custom data.
#
# SECURITY: no credential is stored in this file. The GitHub token and the
# W&B API key are injected from the environment when the template is rendered.
# Any credential that was ever committed to version control must be treated
# as leaked and rotated.
#
# Placeholders (substituted by render_worker_cloud_init):
#   @@DALLE_REPO_URL@@  clone URL of the dalle-hivemind repository
#   @@WANDB_ENV@@       "WANDB_API_KEY=<key> " prefix, or nothing when no key is set
#
# The backslash-newline pairs at the end of the run_trainer.py lines are
# consumed by Python (this is not a raw string), so the trainer command
# reaches the VM as a single long line.
_WORKER_CLOUD_INIT_TEMPLATE = """#cloud-config
package_update: true
packages:
  - build-essential
  - wget
  - git
  - vim
write_files:
  - path: /home/hivemind/init_worker.sh
    permissions: '0766'
    owner: root:root
    content: |
      #!/usr/bin/env bash
      set -e
      wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh
      bash install_miniconda.sh -b -p /opt/conda
      export PATH="/opt/conda/bin:${PATH}"
      conda install python~=3.8.0 pip
      conda install pytorch cudatoolkit=11.1 -c pytorch -c nvidia
      conda clean --all
      pip install https://github.com/learning-at-home/hivemind/archive/scaling_tweaks.zip
      systemctl enable testserv
      systemctl start testserv
  - path: /etc/systemd/system/testserv.service
    permissions: '0777'
    owner: root:root
    content: |
      [Unit]
      Description=One Shot
      [Service]
      ExecStart=/etc/createfile
      Type=oneshot
      RemainAfterExit=yes
      [Install]
      WantedBy=multi-user.target
  - path: /etc/createfile
    permissions: '0777'
    owner: root:root
    content: |
      #!/bin/bash
      export PATH="/opt/conda/bin:${PATH}"
      cd /home/hivemind
      ulimit -n 8192

      git clone @@DALLE_REPO_URL@@ -b azure
      cd dalle-hivemind
      pip install -r requirements.txt
      pip install -U transformers==4.10.2 datasets==1.11.0

      @@WANDB_ENV@@python run_trainer.py --run_name $(hostname) \
        --experiment_prefix dalle_large_5groups \
        --initial_peers /ip4/52.232.13.142/tcp/31334/p2p/QmZLrSPKAcP4puJ8gUGvQ155thk5Q6J7oE5exMUSq1oD5i \
        --per_device_train_batch_size 1 --gradient_accumulation_steps 1
runcmd:
  - bash /home/hivemind/init_worker.sh
"""


def render_worker_cloud_init(github_token: str = "", wandb_api_key: str = "") -> str:
    """Return the worker cloud-config with credentials substituted in.

    Args:
        github_token: Token embedded in the clone URL of the dalle-hivemind
            repository. When empty, the repository is cloned without
            credentials.
        wandb_api_key: Value exported as ``WANDB_API_KEY`` for the trainer
            process. When empty, the variable is not set at all.

    Returns:
        The complete ``#cloud-config`` document as text.
    """
    repo_path = "github.com/learning-at-home/dalle-hivemind.git"
    if github_token:
        clone_url = f"https://{github_token}@{repo_path}"
    else:
        clone_url = f"https://{repo_path}"

    # Both values end up inside a bash script, so quote them for the shell.
    if wandb_api_key:
        wandb_env = f"WANDB_API_KEY={shlex.quote(wandb_api_key)} "
    else:
        wandb_env = ""

    return (
        _WORKER_CLOUD_INIT_TEMPLATE
        .replace("@@DALLE_REPO_URL@@", shlex.quote(clone_url))
        .replace("@@WANDB_ENV@@", wandb_env)
    )


# Rendered once at import time from the optional GITHUB_TOKEN / WANDB_API_KEY
# environment variables.
WORKER_CLOUD_INIT = render_worker_cloud_init(
    github_token=os.environ.get("GITHUB_TOKEN", ""),
    wandb_api_key=os.environ.get("WANDB_API_KEY", ""),
)
# SSH public key written to the admin user's authorized_keys on every instance.
_ADMIN_SSH_PUBLIC_KEY = (
    "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDPFugAsrqEsqxj+hKDTfgrtkY26jqCjRubT5vhnJLhtkDAqe5vJ1donWfUVhtBfnqGr92LPmJewPUd9hRa1i33FLVVdkFAs5/Cg8/YbzR8B8e1Y+Nl5HeT7Dq1i+cPEbA1EZAm9tqK4VWYeCMd3CDkoJVuweTwyja08mxtnVNwKCeY4oBKQCE5QlliAKaQnGpJE6MRnbudWM9Ly1wM6OaJVdGwsfPfEG/sSDip4q/8x/KGAzKbhE6ax15Yu/Bu12ahcIdScQsYK9Y6Sm57MHQQLWQO1G+3h3oCTXQ0BGaSMWKXsjmHsB7f9kLZ1j8yMoGlgbpWbjB0ZVsK/4Zh8Ho3h9gDXADzt1j69qT1aERWCt7fxp9+WOLsCTw1W/W9FY2Ia4niVh2/wEwT9AcOBcAqBl7kXQAoUpP8b2Xb+KNXyTEtVB562EdFn+LmG1gZAy8J3piy2/zoo16QJP5PjpKW5GFxL6BRYLtG+uxgx1Glya617T0dtJF/X2vxjT45QK3FaFH1Zd+vhpcLg94fOPNPEhNU7EeBVp8CGYNd+aXVIPsb0I7EIVu9wWi3/a7y86cUedal61fEigfmAQkC7AHYiAiiT94eARj0N+KgjEy2UOITSCJJTHuamYWO8jZc/n7yAqr6mxOKn5ZjBTfAR9bNB/D+HpL6yepI1UDGBVk4DQ== justHeuristic@gmail.com\n"
)


def _make_clients():
    """Build the resource, network and compute clients on one shared credential."""
    credential = DefaultAzureCredential()
    resource_client = ResourceManagementClient(
        credential=credential,
        subscription_id=SUBSCRIPTION_ID,
    )
    network_client = NetworkManagementClient(
        credential=credential,
        subscription_id=SUBSCRIPTION_ID,
    )
    compute_client = ComputeManagementClient(
        credential=credential,
        subscription_id=SUBSCRIPTION_ID,
    )
    return resource_client, network_client, compute_client


def _ensure_network(resource_client, network_client):
    """Create or update the resource group, virtual network and subnet, waiting for each."""
    resource_client.resource_groups.create_or_update(
        GROUP_NAME,
        {"location": LOCATION},
    )
    network_client.virtual_networks.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        {
            'location': LOCATION,
            'address_space': {
                'address_prefixes': ['10.0.0.0/16'],
            },
        },
    ).result()
    # The subnet spans the whole address space of the virtual network.
    network_client.subnets.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        SUBNET_NAME,
        {'address_prefix': '10.0.0.0/16'},
    ).result()


def _scale_set_config(scaleset_name):
    """Return the request body describing the scale set called ``scaleset_name``."""
    subnet_id = (
        f"/subscriptions/{SUBSCRIPTION_ID}/resourceGroups/{GROUP_NAME}"
        f"/providers/Microsoft.Network/virtualNetworks/{NETWORK_NAME}/subnets/{SUBNET_NAME}"
    )
    vm_image = {
        "exactVersion": "21.06.0",
        "offer": "ngc_base_image_version_b",
        "publisher": "nvidia",
        "sku": "gen2_21-06-0",
        "version": "latest",
    }
    return {
        "sku": {
            "tier": "Standard",
            "capacity": SWARM_SIZE,
            "name": "Standard_NC4as_T4_v3",
        },
        # The plan entries mirror the publisher/offer/sku of the image above.
        "plan": {
            "name": "gen2_21-06-0",
            "publisher": "nvidia",
            "product": "ngc_base_image_version_b",
        },
        "location": LOCATION,
        "virtual_machine_profile": {
            "storage_profile": {
                "image_reference": vm_image,
                "os_disk": {
                    "caching": "ReadWrite",
                    "managed_disk": {"storage_account_type": "Standard_LRS"},
                    "create_option": "FromImage",
                    # Disk size is an integer number of gigabytes, not a string.
                    "disk_size_gb": 32,
                },
            },
            "os_profile": {
                "computer_name_prefix": scaleset_name,
                "admin_username": "hivemind",
                "admin_password": ADMIN_PASS,
                "linux_configuration": {
                    # NOTE(review): password authentication is disabled here while an
                    # admin password is still supplied above - confirm both are wanted.
                    "disable_password_authentication": True,
                    "ssh": {
                        "public_keys": [
                            {
                                "key_data": _ADMIN_SSH_PUBLIC_KEY,
                                "path": "/home/hivemind/.ssh/authorized_keys",
                            }
                        ]
                    },
                },
                # Custom data is sent as base64 text; base64 output is pure ASCII.
                "custom_data": b64encode(WORKER_CLOUD_INIT.encode('utf-8')).decode('ascii'),
            },
            "network_profile": {
                "network_interface_configurations": [
                    {
                        "name": "test",
                        "primary": True,
                        "enable_accelerated_networking": True,
                        "ip_configurations": [
                            {
                                "name": "test",
                                "subnet": {"id": subnet_id},
                                "public_ip_address_configuration": {
                                    "name": "pub1",
                                    "idle_timeout_in_minutes": 15,
                                },
                            }
                        ],
                    }
                ]
            },
            "diagnostics_profile": {"boot_diagnostics": {"enabled": True}},
            "priority": "spot",
            "eviction_policy": "deallocate",
        },
        "upgrade_policy": {
            "mode": "Manual",
        },
        # NOTE(review): repeats upgrade_policy.mode at the top level; kept as in the
        # original request body - confirm whether the service reads this key.
        "upgrade_mode": "Manual",
        "spot_restore_policy": {"enabled": True},
    }


def _create_scale_sets(compute_client):
    """Start creation of every scale set in SCALE_SETS, then wait for all of them."""
    pollers = []
    for scaleset_name in SCALE_SETS:
        poller = compute_client.virtual_machine_scale_sets.begin_create_or_update(
            GROUP_NAME,
            scaleset_name,
            _scale_set_config(scaleset_name),
        )
        print(f"{scaleset_name} {poller.status()}")
        pollers.append(poller)

    # All requests are in flight before the first wait, so the sets are built concurrently.
    for scaleset_name, poller in zip(SCALE_SETS, pollers):
        print(f"Created scale set {scaleset_name}:\n{poller.result()}")


def _delete_scale_sets(compute_client):
    """Start deletion of every scale set in SCALE_SETS, then wait for all of them."""
    pollers = [
        compute_client.virtual_machine_scale_sets.begin_delete(GROUP_NAME, scaleset_name)
        for scaleset_name in SCALE_SETS
    ]
    for scaleset_name, poller in zip(SCALE_SETS, pollers):
        poller.result()
        print(f"Deleted scale set {scaleset_name}")


def main():
    """Command-line entry point: ``create`` or ``delete`` the worker scale sets.

    Parses a single positional ``command`` argument, makes sure the resource
    group, virtual network and subnet exist, and then either creates or
    deletes every scale set named in SCALE_SETS.

    NOTE(review): the resource group and network are created/updated for the
    ``delete`` command as well, exactly as before - confirm that is intended.
    """
    parser = ArgumentParser()
    parser.add_argument('command', choices=('create', 'delete'))
    args = parser.parse_args()

    resource_client, network_client, compute_client = _make_clients()
    _ensure_network(resource_client, network_client)

    if args.command == 'create':
        _create_scale_sets(compute_client)
    else:
        _delete_scale_sets(compute_client)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()