|
| 1 | +import json |
| 2 | +import re |
| 3 | +import time |
| 4 | +from datetime import datetime |
| 5 | + |
| 6 | +import click |
| 7 | +from lightning_cloud.openapi import ( |
| 8 | + V1AWSClusterDriverSpec, |
| 9 | + V1ClusterDriver, |
| 10 | + V1ClusterPerformanceProfile, |
| 11 | + V1ClusterSpec, |
| 12 | + V1CreateClusterRequest, |
| 13 | + V1InstanceSpec, |
| 14 | + V1KubernetesClusterDriver, |
| 15 | +) |
| 16 | +from lightning_cloud.openapi.models import Externalv1Cluster, V1ClusterState, V1ClusterType |
| 17 | +from rich.console import Console |
| 18 | +from rich.table import Table |
| 19 | +from rich.text import Text |
| 20 | + |
| 21 | +from lightning_app.cli.core import Formatable |
| 22 | +from lightning_app.utilities.network import LightningClient |
| 23 | +from lightning_app.utilities.openapi import create_openapi_object, string2dict |
| 24 | + |
# Seconds to sleep between two consecutive polls of the cluster state
# (default for ``check_timeout`` in ``_wait_for_cluster_state``).
CLUSTER_STATE_CHECKING_TIMEOUT = 60
# Upper bound, in seconds, on how long to keep polling before giving up
# (default for ``max_wait_time`` in ``_wait_for_cluster_state``).
MAX_CLUSTER_WAIT_TIME = 5400
| 27 | + |
| 28 | + |
class AWSClusterManager:
    """Implement API calls specific to Lightning AI BYOC compute clusters when the AWS provider is selected as
    the backend compute."""

    def __init__(self):
        self.api_client = LightningClient()

    def create(
        self,
        cost_savings: bool = False,
        cluster_name: str = None,
        role_arn: str = None,
        region: str = "us-east-1",
        external_id: str = None,
        instance_types: list = None,
        edit_before_creation: bool = False,
        wait: bool = False,
    ):
        """Request Lightning AI BYOC compute cluster creation.

        Args:
            cost_savings: Specifies if the cluster uses cost savings mode
            cluster_name: The name of the cluster to be created
            role_arn: AWS IAM Role ARN used to provision resources
            region: AWS region containing compute resources
            external_id: AWS IAM Role external ID
            instance_types: AWS instance type names supported by the cluster. ``None`` (the default) and an
                empty list both result in an empty instance type list in the request.
            edit_before_creation: Enables interactive editing of requests before submitting it to Lightning AI.
            wait: Waits for the cluster to be in a RUNNING state. Only use this for debugging.
        """
        # In cost saving mode the number of compute nodes is reduced to one, reducing the cost for clusters
        # with low utilization.
        if cost_savings:
            performance_profile = V1ClusterPerformanceProfile.COST_SAVING
        else:
            performance_profile = V1ClusterPerformanceProfile.DEFAULT

        # `None` replaces the former mutable `[]` default; treat both the same way.
        instance_specs = [V1InstanceSpec(name=x) for x in (instance_types or [])]

        body = V1CreateClusterRequest(
            name=cluster_name,
            spec=V1ClusterSpec(
                cluster_type=V1ClusterType.BYOC,
                performance_profile=performance_profile,
                driver=V1ClusterDriver(
                    kubernetes=V1KubernetesClusterDriver(
                        aws=V1AWSClusterDriverSpec(
                            region=region,
                            role_arn=role_arn,
                            external_id=external_id,
                            instance_types=instance_specs,
                        )
                    )
                ),
            ),
        )
        new_body = body
        if edit_before_creation:
            # Let the user tweak the JSON form of the request in their editor before it is submitted.
            after = click.edit(json.dumps(body.to_dict(), indent=4))
            if after is not None:
                new_body = create_openapi_object(string2dict(after), body)
            if new_body == body:
                click.echo("cluster unchanged")

        resp = self.api_client.cluster_service_create_cluster(body=new_body)
        if wait:
            _wait_for_cluster_state(self.api_client, resp.id, V1ClusterState.RUNNING)

        # NOTE: `resp` is the response of the create call, so the phase printed here is the one reported at
        # creation time, even when `wait` is set.
        click.echo(f"{resp.id} cluster is {resp.status.phase}")

    def list(self):
        """Print a table of clusters, asking the API to leave out those in the DELETED phase."""
        resp = self.api_client.cluster_service_list_clusters(phase_not_in=[V1ClusterState.DELETED])
        console = Console()
        console.print(ClusterList(resp.clusters).as_table())

    def delete(self, cluster_id: str = None, force: bool = False, wait: bool = False):
        """Request deletion of a cluster.

        Args:
            cluster_id: ID of the cluster to delete
            force: Forwarded to the delete API call. When set, a warning is printed and the user has to confirm
                interactively before the request is sent; declining aborts the command.
            wait: Waits for the cluster to be in a DELETED state.
        """
        if force:
            click.echo(
                """
            Deletes a BYOC cluster. Lightning AI removes cluster artifacts and any resources running on the cluster.\n
            WARNING: Deleting a cluster does not clean up any resources managed by Lightning AI.\n
            Check your cloud provider to verify that existing cloud resources are deleted.
            """
            )
            click.confirm("Do you want to continue?", abort=True)

        self.api_client.cluster_service_delete_cluster(id=cluster_id, force=force)
        click.echo("Cluster deletion triggered successfully")

        if wait:
            _wait_for_cluster_state(self.api_client, cluster_id, V1ClusterState.DELETED)
| 117 | + |
| 118 | + |
class ClusterList(Formatable):
    """Present a collection of clusters either as a JSON document or as a rich table."""

    def __init__(self, clusters: [Externalv1Cluster]):
        self.clusters = clusters

    def as_json(self) -> str:
        """Return the clusters serialized as a JSON array.

        The cluster models are not JSON serializable themselves, so each one is converted with ``to_dict()``
        first; datetime values found in the result are emitted in ISO 8601 format.
        """

        def _encode(value):
            # Only datetimes get special treatment; anything else unexpected still fails loudly.
            if isinstance(value, datetime):
                return value.isoformat()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        return json.dumps([cluster.to_dict() for cluster in self.clusters], default=_encode)

    def as_table(self) -> Table:
        """Return a table with one row per cluster: id, name, type, status and creation date."""
        table = Table("id", "name", "type", "status", "created", show_header=True, header_style="bold green")
        phases = {
            V1ClusterState.QUEUED: Text("queued", style="bold yellow"),
            V1ClusterState.PENDING: Text("pending", style="bold yellow"),
            V1ClusterState.RUNNING: Text("running", style="bold green"),
            V1ClusterState.FAILED: Text("failed", style="bold red"),
            V1ClusterState.DELETED: Text("deleted", style="bold red"),
        }

        cluster_type_lookup = {
            V1ClusterType.BYOC: Text("byoc", style="bold yellow"),
            V1ClusterType.GLOBAL: Text("lightning-cloud", style="bold green"),
        }
        for cluster in self.clusters:
            cluster: Externalv1Cluster
            phase = cluster.status.phase
            if cluster.spec.desired_state == V1ClusterState.DELETED and phase != V1ClusterState.DELETED:
                # Deletion was requested but has not completed yet.
                status = Text("terminating", style="bold red")
            else:
                # Fall back instead of raising KeyError on a phase missing from the mapping above.
                status = phases.get(phase, Text("unknown", style="red"))

            # this guard is necessary only until 0.3.93 releases which includes the `created_at`
            # field to the external API
            created_at = getattr(cluster, "created_at", datetime.now())

            table.add_row(
                cluster.id,
                cluster.name,
                cluster_type_lookup.get(cluster.spec.cluster_type, Text("unknown", style="red")),
                status,
                created_at.strftime("%Y-%m-%d") if created_at else "",
            )
        return table
| 160 | + |
| 161 | + |
def _wait_for_cluster_state(
    api_client: LightningClient,
    cluster_id: str,
    target_state: V1ClusterState,
    max_wait_time: int = MAX_CLUSTER_WAIT_TIME,
    check_timeout: int = CLUSTER_STATE_CHECKING_TIMEOUT,
):
    """Poll the cluster list until the given cluster reaches ``target_state``.

    Args:
        api_client: LightningClient used for polling
        cluster_id: ID of the cluster to wait for
        target_state: Phase the cluster has to reach for the wait to end successfully
        max_wait_time: Maximum duration to wait (in seconds)
        check_timeout: Pause between two polls of the cluster state (in seconds)

    Raises:
        click.ClickException: If the cluster is found in the FAILED phase, or if ``max_wait_time`` elapses
            before the target state is observed.
    """
    started_at = time.time()
    waited = 0
    while waited < max_wait_time:
        listing = api_client.cluster_service_list_clusters()
        # First cluster with a matching ID, or None when it is not part of the listing.
        found = next((candidate for candidate in listing.clusters if candidate.id == cluster_id), None)
        if found is not None:
            phase = found.status.phase
            if phase == target_state:
                return
            if phase == V1ClusterState.FAILED:
                raise click.ClickException(f"Cluster {cluster_id} is in failed state.")
        time.sleep(check_timeout)
        waited = time.time() - started_at

    # Only reached when the loop condition ran out without the target state being seen.
    raise click.ClickException("Max wait time elapsed")
| 196 | + |
| 197 | + |
| 198 | +def _check_cluster_name_is_valid(_ctx, _param, value): |
| 199 | + pattern = r"^(?!-)[a-z0-9-]{1,63}(?<!-)$" |
| 200 | + if not re.match(pattern, value): |
| 201 | + raise click.ClickException( |
| 202 | + """The cluster name is invalid. |
| 203 | + Cluster names can only contain lowercase letters, numbers, and periodic hyphens ( - ). |
| 204 | + Provide a cluster name using valid characters and try again.""" |
| 205 | + ) |
| 206 | + return value |
0 commit comments