manage_scaleset.py

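"""
Create or delete the Azure VM scale set that runs hivemind DALL-E training workers.

Usage (expects SUBSCRIPTION_ID and AZURE_PASS in the environment):
    python manage_scaleset.py create
    python manage_scaleset.py delete
"""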
import os
from argparse import ArgumentParser
from base64 import b64encode

from azure.identity import DefaultAzureCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient

print("=======================WARNING=======================")
print("= The code may fail to import 'gi' but that is okay =")
print("===================END OF WARNING====================")

SUBSCRIPTION_ID = os.environ["SUBSCRIPTION_ID"]
GROUP_NAME = "dalle_west2"
NETWORK_NAME = "vnet"
SUBNET_NAME = "subnet"
LOCATION = "westus2"
ADMIN_PASS = os.environ['AZURE_PASS']
SCALE_SETS = ('worker',)
SWARM_SIZE = 4
WORKER_CLOUD_INIT = """#cloud-config
package_update: true
packages:
  - build-essential
  - wget
  - git
  - vim
write_files:
  - path: /home/hivemind/init_worker.sh
    permissions: '0766'
    owner: root:root
    content: |
      #!/usr/bin/env bash
      set -e
      wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh
      bash install_miniconda.sh -b -p /opt/conda
      export PATH="/opt/conda/bin:${PATH}"
      conda install python~=3.8.0 pip
      conda install pytorch cudatoolkit=11.1 -c pytorch -c nvidia
      conda clean --all
      pip install https://github.com/learning-at-home/hivemind/archive/scaling_tweaks.zip
      systemctl enable testserv
      systemctl start testserv
  - path: /etc/systemd/system/testserv.service
    permissions: '0777'
    owner: root:root
    content: |
      [Unit]
      Description=One Shot
      [Service]
      ExecStart=/etc/createfile
      Type=oneshot
      RemainAfterExit=yes
      [Install]
      WantedBy=multi-user.target
  - path: /etc/createfile
    permissions: '0777'
    owner: root:root
    content: |
      #!/bin/bash
      export PATH="/opt/conda/bin:${PATH}"
      cd /home/hivemind
      ulimit -n 8192
      git clone https://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@github.com/learning-at-home/dalle-hivemind.git -b azure
      cd dalle-hivemind
      pip install -r requirements.txt
      pip install -U transformers==4.10.2 datasets==1.11.0
      WANDB_API_KEY=7cc938e45e63ef7d2f88f811be240ba0395c02dd python run_trainer.py --run_name $(hostname) \
        --experiment_prefix dalle_large_5groups \
        --initial_peers /ip4/52.232.13.142/tcp/31334/p2p/QmZLrSPKAcP4puJ8gUGvQ155thk5Q6J7oE5exMUSq1oD5i \
        --per_device_train_batch_size 1 --gradient_accumulation_steps 1
runcmd:
  - bash /home/hivemind/init_worker.sh
"""


def main():
    parser = ArgumentParser()
    parser.add_argument('command', choices=('create', 'delete'))
    args = parser.parse_args()

    resource_client = ResourceManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID
    )
    network_client = NetworkManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID
    )
    compute_client = ComputeManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=SUBSCRIPTION_ID
    )

    # Create resource group
    resource_client.resource_groups.create_or_update(
        GROUP_NAME,
        {"location": LOCATION}
    )

    # Create virtual network
    network_client.virtual_networks.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        {
            'location': LOCATION,
            'address_space': {
                'address_prefixes': ['10.0.0.0/16']
            }
        }
    ).result()

    # Create subnet (the scale set NICs attach to it by resource ID below)
    subnet = network_client.subnets.begin_create_or_update(
        GROUP_NAME,
        NETWORK_NAME,
        SUBNET_NAME,
        {'address_prefix': '10.0.0.0/16'}
    ).result()

    if args.command == 'create':
        scalesets = []
        for scaleset_name in SCALE_SETS:
            cloud_init_cmd = WORKER_CLOUD_INIT
            # NVIDIA NGC base image from the Azure Marketplace; marketplace images
            # also need the matching purchase "plan" block in the scale set config
            vm_image = {
                "exactVersion": "21.06.0",
                "offer": "ngc_base_image_version_b",
                "publisher": "nvidia",
                "sku": "gen2_21-06-0",
                "version": "latest",
            }
            vm_config = {
                "sku": {
                    "tier": "Standard",
                    "capacity": SWARM_SIZE,
                    "name": "Standard_NC4as_T4_v3"
                },
                "plan": {
                    "name": "gen2_21-06-0",
                    "publisher": "nvidia",
                    "product": "ngc_base_image_version_b"
                },
                "location": LOCATION,
                "virtual_machine_profile": {
                    "storage_profile": {
                        "image_reference": vm_image,
                        "os_disk": {
                            "caching": "ReadWrite",
                            "managed_disk": {"storage_account_type": "Standard_LRS"},
                            "create_option": "FromImage",
                            "disk_size_gb": "32",
                        },
                    },
                    "os_profile": {
                        "computer_name_prefix": scaleset_name,
                        "admin_username": "hivemind",
                        "admin_password": ADMIN_PASS,
                        "linux_configuration": {
                            "disable_password_authentication": True,
                            "ssh": {
                                "public_keys": [
                                    {
                                        "key_data": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDPFugAsrqEsqxj+hKDTfgrtkY26jqCjRubT5vhnJLhtkDAqe5vJ1donWfUVhtBfnqGr92LPmJewPUd9hRa1i33FLVVdkFAs5/Cg8/YbzR8B8e1Y+Nl5HeT7Dq1i+cPEbA1EZAm9tqK4VWYeCMd3CDkoJVuweTwyja08mxtnVNwKCeY4oBKQCE5QlliAKaQnGpJE6MRnbudWM9Ly1wM6OaJVdGwsfPfEG/sSDip4q/8x/KGAzKbhE6ax15Yu/Bu12ahcIdScQsYK9Y6Sm57MHQQLWQO1G+3h3oCTXQ0BGaSMWKXsjmHsB7f9kLZ1j8yMoGlgbpWbjB0ZVsK/4Zh8Ho3h9gDXADzt1j69qT1aERWCt7fxp9+WOLsCTw1W/W9FY2Ia4niVh2/wEwT9AcOBcAqBl7kXQAoUpP8b2Xb+KNXyTEtVB562EdFn+LmG1gZAy8J3piy2/zoo16QJP5PjpKW5GFxL6BRYLtG+uxgx1Glya617T0dtJF/X2vxjT45QK3FaFH1Zd+vhpcLg94fOPNPEhNU7EeBVp8CGYNd+aXVIPsb0I7EIVu9wWi3/a7y86cUedal61fEigfmAQkC7AHYiAiiT94eARj0N+KgjEy2UOITSCJJTHuamYWO8jZc/n7yAqr6mxOKn5ZjBTfAR9bNB/D+HpL6yepI1UDGBVk4DQ== justHeuristic@gmail.com\n",
                                        "path": "/home/hivemind/.ssh/authorized_keys"
                                    }
                                ]
                            }
                        },
                        # cloud-init payload must be passed as base64-encoded custom_data
                        "custom_data": b64encode(cloud_init_cmd.encode('utf-8')).decode('latin-1'),
                    },
                    "network_profile": {
                        "network_interface_configurations": [
                            {
                                "name": "test",
                                "primary": True,
                                "enable_accelerated_networking": True,
                                "ip_configurations": [
                                    {
                                        "name": "test",
                                        "subnet": {
                                            "id": f"/subscriptions/{SUBSCRIPTION_ID}/resourceGroups/{GROUP_NAME}/providers/Microsoft.Network/virtualNetworks/{NETWORK_NAME}/subnets/{SUBNET_NAME}"
                                        },
                                        "public_ip_address_configuration": {
                                            "name": "pub1",
                                            "idle_timeout_in_minutes": 15
                                        }
                                    }
                                ]
                            }
                        ]
                    },
                    "diagnostics_profile": {"boot_diagnostics": {"enabled": True}},
                    # run workers as spot VMs: evicted instances are deallocated and
                    # brought back by the spot restore policy when capacity returns
                    "priority": "spot",
                    "eviction_policy": "deallocate",
                },
                "upgrade_policy": {
                    "mode": "Manual"
                },
                "upgrade_mode": "Manual",
                "spot_restore_policy": {"enabled": True}
            }

            # Create virtual machine scale set
            vmss = compute_client.virtual_machine_scale_sets.begin_create_or_update(
                GROUP_NAME,
                scaleset_name,
                vm_config,
            )
            print(f"{scaleset_name} {vmss.status()}")
            scalesets.append(vmss)

        for scaleset_name, vmss in zip(SCALE_SETS, scalesets):
            print(f"Created scale set {scaleset_name}:\n{vmss.result()}")
    else:
        delete_results = []
        for scaleset_name in SCALE_SETS:
            delete_results.append(compute_client.virtual_machine_scale_sets.begin_delete(GROUP_NAME, scaleset_name))
        for scaleset_name, delete_result in zip(SCALE_SETS, delete_results):
            delete_result.result()
            print(f"Deleted scale set {scaleset_name}")


if __name__ == "__main__":
    main()