manage_scaleset.py

import os
from argparse import ArgumentParser
from base64 import b64encode

from azure.identity import DefaultAzureCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient

SUBSCRIPTION_ID = os.environ["SUBSCRIPTION_ID"]
GROUP_NAME = "dalle_northeu"
NETWORK_NAME = "vnet"
SUBNET_NAME = "subnet"
LOCATION = "northeurope"
ADMIN_PASS = os.environ["AZURE_PASS"]
SCALE_SETS = ("worker",)
SWARM_SIZE = 16
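
# Required environment variables: SUBSCRIPTION_ID (the Azure subscription to
# deploy into) and AZURE_PASS (the admin password for the worker VMs). The
# resource group, region, scale set names, and swarm size are configured above.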

# Cloud-init config passed to every worker VM: installs Miniconda, PyTorch,
# and hivemind, then launches the trainer through a one-shot systemd service.
WORKER_CLOUD_INIT = """#cloud-config
package_update: true
packages:
  - build-essential
  - wget
  - git
  - vim
write_files:
- path: /home/hivemind/init_worker.sh
  permissions: '0766'
  owner: root:root
  content: |
    #!/usr/bin/env bash
    set -e
    wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh
    bash install_miniconda.sh -b -p /opt/conda
    export PATH="/opt/conda/bin:${PATH}"
    conda install -y python~=3.8.0 pip
    conda install -y pytorch cudatoolkit=11.1 -c pytorch -c nvidia
    conda clean -y --all
    pip install https://github.com/learning-at-home/hivemind/archive/c328698b8668e6c548a571c175d045ac7df08586.zip
    systemctl enable testserv
    systemctl start testserv
- path: /etc/systemd/system/testserv.service
  permissions: '0777'
  owner: root:root
  content: |
    [Unit]
    Description=One Shot

    [Service]
    ExecStart=/etc/createfile
    Type=oneshot
    RemainAfterExit=yes

    [Install]
    WantedBy=multi-user.target
- path: /etc/createfile
  permissions: '0777'
  owner: root:root
  content: |
    #!/bin/bash
    export PATH="/opt/conda/bin:${PATH}"
    cd /home/hivemind
    ulimit -n 8192
    git clone https://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@github.com/learning-at-home/dalle-hivemind.git -b main
    cd dalle-hivemind
    pip install -r requirements.txt
    pip install -U transformers==4.10.2 datasets==1.11.0
    WANDB_API_KEY=7cc938e45e63ef7d2f88f811be240ba0395c02dd python run_trainer.py --run_name $(hostname) \
      --experiment_prefix dalle_large_5groups \
      --initial_peers /ip4/172.16.1.66/tcp/31334/p2p/QmZLrSPKAcP4puJ8gUGvQ155thk5Q6J7oE5exMUSq1oD5i \
      --per_device_train_batch_size 4 --gradient_accumulation_steps 2 --fp16
runcmd:
- bash /home/hivemind/init_worker.sh
"""
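
# Note on ordering: cloud-init applies `write_files` before `runcmd`, so the
# testserv unit and /etc/createfile already exist by the time init_worker.sh
# runs and enables the service.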


def main():
    parser = ArgumentParser()
    parser.add_argument("command", choices=("create", "delete"))
    args = parser.parse_args()

    resource_client = ResourceManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID,
    )
    network_client = NetworkManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID,
    )
    compute_client = ComputeManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID,
    )
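    # DefaultAzureCredential tries several mechanisms in turn (environment
    # variables, managed identity, Azure CLI login, and so on), so the same
    # script authenticates both locally and from inside Azure.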

    # Create resource group
    resource_client.resource_groups.create_or_update(
        GROUP_NAME,
        {"location": LOCATION},
    )

    # Create virtual network
    network_client.virtual_networks.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        {
            "location": LOCATION,
            "address_space": {"address_prefixes": ["10.0.0.0/16"]},
        },
    ).result()

    # Create subnet (spans the whole address space of the virtual network)
    subnet = network_client.subnets.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        SUBNET_NAME,
        {"address_prefix": "10.0.0.0/16"},
    ).result()
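
    # All three calls above are create_or_update operations, i.e. idempotent:
    # re-running the script (with either command) converges the networking
    # resources to this spec instead of failing on "already exists".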

    if args.command == "create":
        scalesets = []
        for scaleset_name in SCALE_SETS:
            cloud_init_cmd = WORKER_CLOUD_INIT
            vm_image = {
                "exactVersion": "21.06.0",
                "offer": "ngc_base_image_version_b",
                "publisher": "nvidia",
                "sku": "gen2_21-06-0",
                "version": "latest",
            }
            vm_config = {
                "sku": {
                    "tier": "Standard",
                    "capacity": SWARM_SIZE,
                    "name": "Standard_NC4as_T4_v3",
                },
                # The NVIDIA NGC marketplace image requires a matching purchase plan
                "plan": {
                    "name": "gen2_21-06-0",
                    "publisher": "nvidia",
                    "product": "ngc_base_image_version_b",
                },
                "location": LOCATION,
                "virtual_machine_profile": {
                    "storage_profile": {
                        "image_reference": vm_image,
                        "os_disk": {
                            "caching": "ReadWrite",
                            "managed_disk": {"storage_account_type": "Standard_LRS"},
                            "create_option": "FromImage",
                            "disk_size_gb": "32",
                        },
                    },
                    "os_profile": {
                        "computer_name_prefix": scaleset_name,
                        "admin_username": "hivemind",
                        "admin_password": ADMIN_PASS,
                        "linux_configuration": {
                            "disable_password_authentication": True,
                            "ssh": {
                                "public_keys": [
                                    {
                                        "key_data": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDPFugAsrqEsqxj+hKDTfgrtkY26jqCjRubT5vhnJLhtkDAqe5vJ1donWfUVhtBfnqGr92LPmJewPUd9hRa1i33FLVVdkFAs5/Cg8/YbzR8B8e1Y+Nl5HeT7Dq1i+cPEbA1EZAm9tqK4VWYeCMd3CDkoJVuweTwyja08mxtnVNwKCeY4oBKQCE5QlliAKaQnGpJE6MRnbudWM9Ly1wM6OaJVdGwsfPfEG/sSDip4q/8x/KGAzKbhE6ax15Yu/Bu12ahcIdScQsYK9Y6Sm57MHQQLWQO1G+3h3oCTXQ0BGaSMWKXsjmHsB7f9kLZ1j8yMoGlgbpWbjB0ZVsK/4Zh8Ho3h9gDXADzt1j69qT1aERWCt7fxp9+WOLsCTw1W/W9FY2Ia4niVh2/wEwT9AcOBcAqBl7kXQAoUpP8b2Xb+KNXyTEtVB562EdFn+LmG1gZAy8J3piy2/zoo16QJP5PjpKW5GFxL6BRYLtG+uxgx1Glya617T0dtJF/X2vxjT45QK3FaFH1Zd+vhpcLg94fOPNPEhNU7EeBVp8CGYNd+aXVIPsb0I7EIVu9wWi3/a7y86cUedal61fEigfmAQkC7AHYiAiiT94eARj0N+KgjEy2UOITSCJJTHuamYWO8jZc/n7yAqr6mxOKn5ZjBTfAR9bNB/D+HpL6yepI1UDGBVk4DQ== justHeuristic@gmail.com\n",
                                        "path": "/home/hivemind/.ssh/authorized_keys",
                                    }
                                ]
                            },
                        },
                        # Cloud-init payload must be passed base64-encoded
                        "custom_data": b64encode(cloud_init_cmd.encode("utf-8")).decode("latin-1"),
                    },
                    "network_profile": {
                        "network_interface_configurations": [
                            {
                                "name": "test",
                                "primary": True,
                                "enable_accelerated_networking": True,
                                "ip_configurations": [
                                    {
                                        "name": "test",
                                        "subnet": {
                                            "id": f"/subscriptions/{SUBSCRIPTION_ID}/resourceGroups/{GROUP_NAME}/providers/Microsoft.Network/virtualNetworks/{NETWORK_NAME}/subnets/{SUBNET_NAME}"
                                        },
                                        "public_ip_address_configuration": {
                                            "name": "pub1",
                                            "idle_timeout_in_minutes": 15,
                                        },
                                    }
                                ],
                            }
                        ]
                    },
                    "diagnostics_profile": {"boot_diagnostics": {"enabled": True}},
                    "priority": "spot",
                    "eviction_policy": "deallocate",
                },
                "upgrade_policy": {"mode": "Manual"},
                "spot_restore_policy": {"enabled": True},
            }
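
            # Spot priority with a "deallocate" eviction policy plus
            # spot_restore_policy means evicted GPU nodes are stopped rather
            # than deleted, and Azure attempts to bring them back when spot
            # capacity becomes available again.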

            # Create virtual machine scale set (a long-running operation: kick
            # them all off first, then wait for the results below)
            vmss = compute_client.virtual_machine_scale_sets.begin_create_or_update(
                GROUP_NAME,
                scaleset_name,
                vm_config,
            )
            print(f"{scaleset_name} {vmss.status()}")
            scalesets.append(vmss)

        for scaleset_name, vmss in zip(SCALE_SETS, scalesets):
            print(f"Created scale set {scaleset_name}:\n{vmss.result()}")
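
        # A possible follow-up (a sketch, not part of the original flow):
        # inspect the instances that came up, e.g.
        #   for vm in compute_client.virtual_machine_scale_set_vms.list(GROUP_NAME, scaleset_name):
        #       print(vm.name, vm.provisioning_state)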
    else:
        delete_results = []
        for scaleset_name in SCALE_SETS:
            delete_results.append(
                compute_client.virtual_machine_scale_sets.begin_delete(GROUP_NAME, scaleset_name)
            )
        for scaleset_name, delete_result in zip(SCALE_SETS, delete_results):
            delete_result.result()
            print(f"Deleted scale set {scaleset_name}")


if __name__ == "__main__":
    main()
  207. main()