Create GKEStartKueueJobOperator operator (#37477)
MaksYermak authored Mar 18, 2024
1 parent c32d41d commit 80e60d7
Showing 4 changed files with 261 additions and 1 deletion.
33 changes: 33 additions & 0 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -1312,3 +1312,36 @@ def execute(self, context: Context):
            cluster_hook=self.cluster_hook,
        ).fetch_cluster_info()
        return super().execute(context)


class GKEStartKueueJobOperator(GKEStartJobOperator):
    """
    Executes a Kubernetes Job in Kueue in the specified Google Kubernetes Engine cluster.

    :param queue_name: The name of the Queue in the cluster
    """

    def __init__(
        self,
        *,
        queue_name: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.queue_name = queue_name

        # Kueue only admits Jobs that are created in a suspended state, so the
        # operator enforces `suspend=True` before the Job is submitted.
        if self.suspend is False:
            raise AirflowException(
                "The `suspend` parameter can't be False. If you want to use Kueue for running a Job"
                " in a Kubernetes cluster, set the `suspend` parameter to True.",
            )
        elif self.suspend is None:
            warnings.warn(
                f"You have not set parameter `suspend` in class {self.__class__.__name__}. "
                "For running a Job in Kueue the `suspend` parameter should be set to True.",
                UserWarning,
                stacklevel=2,
            )
            self.suspend = True
        self.labels.update({"kueue.x-k8s.io/queue-name": queue_name})
        self.annotations.update({"kueue.x-k8s.io/queue-name": queue_name})
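
A minimal usage sketch of the new operator, assuming an existing GKE cluster with Kueue installed and a LocalQueue named "local-queue"; all parameter values here are illustrative, and the tested, runnable example lives in the system test file further down.

    from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartKueueJobOperator

    run_pi = GKEStartKueueJobOperator(
        task_id="run_pi_on_kueue",  # hypothetical task id
        project_id="my-project",  # hypothetical GCP project
        location="us-central1",
        cluster_name="my-cluster",
        queue_name="local-queue",  # Kueue LocalQueue the Job is submitted to
        namespace="default",
        image="perl:5.34.0",
        cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
        name="test-pi",
    )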
docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
@@ -213,6 +213,14 @@ All Kubernetes parameters (except ``config_file``) are also valid for the ``GKEStartJobOperator``.
    :start-after: [START howto_operator_gke_start_job]
    :end-before: [END howto_operator_gke_start_job]

To run a Job on a GKE cluster with Kueue enabled, use ``GKEStartKueueJobOperator``. Kueue must already be installed in the cluster (for example with ``GKEStartKueueInsideClusterOperator``) and a ``LocalQueue`` must exist to accept the Job.

.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_kueue_start_job]
    :end-before: [END howto_operator_kueue_start_job]


.. _howto/operator:GKEDescribeJobOperator:

121 changes: 121 additions & 0 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -43,6 +43,7 @@
    GKEDescribeJobOperator,
    GKEStartJobOperator,
    GKEStartKueueInsideClusterOperator,
    GKEStartKueueJobOperator,
    GKEStartPodOperator,
)
from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger
@@ -51,6 +52,7 @@
PROJECT_LOCATION = "test-location"
PROJECT_TASK_ID = "test-task-id"
CLUSTER_NAME = "test-cluster-name"
QUEUE_NAME = "test-queue-name"

PROJECT_BODY = {"name": "test-name"}
PROJECT_BODY_CREATE_DICT = {"name": "test-name", "initial_node_count": 1}
@@ -1009,3 +1011,122 @@ def test_execute_with_impersonation_service_chain_one_element(
        self.gke_op.execute(context=mock.MagicMock())

        fetch_cluster_info_mock.assert_called_once()


class TestGKEStartKueueJobOperator:
    def setup_method(self):
        self.gke_op = GKEStartKueueJobOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            name=TASK_NAME,
            namespace=NAMESPACE,
            image=IMAGE,
            queue_name=QUEUE_NAME,
        )
        self.gke_op.job = mock.MagicMock(
            name=TASK_NAME,
            namespace=NAMESPACE,
        )

    def test_template_fields(self):
        assert set(GKEStartJobOperator.template_fields).issubset(GKEStartKueueJobOperator.template_fields)

    @mock.patch.dict(os.environ, {})
    @mock.patch(KUB_JOB_OPERATOR_EXEC)
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock):
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        self.gke_op.execute(context=mock.MagicMock())
        fetch_cluster_info_mock.assert_called_once()

    def test_config_file_throws_error(self):
        # `queue_name` is a required keyword argument; it is passed here so the
        # test exercises the `config_file` validation rather than a TypeError.
        with pytest.raises(AirflowException):
            GKEStartKueueJobOperator(
                project_id=TEST_GCP_PROJECT_ID,
                location=PROJECT_LOCATION,
                cluster_name=CLUSTER_NAME,
                task_id=PROJECT_TASK_ID,
                name=TASK_NAME,
                namespace=NAMESPACE,
                image=IMAGE,
                queue_name=QUEUE_NAME,
                config_file="/path/to/alternative/kubeconfig",
            )

    @mock.patch.dict(os.environ, {})
    @mock.patch(
        "airflow.hooks.base.BaseHook.get_connections",
        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
    )
    @mock.patch(KUB_JOB_OPERATOR_EXEC)
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    def test_execute_with_impersonation_service_account(
        self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
    ):
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        self.gke_op.impersonation_chain = "test_account@example.com"
        self.gke_op.execute(context=mock.MagicMock())
        fetch_cluster_info_mock.assert_called_once()

    @mock.patch.dict(os.environ, {})
    @mock.patch(
        "airflow.hooks.base.BaseHook.get_connections",
        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
    )
    @mock.patch(KUB_JOB_OPERATOR_EXEC)
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    def test_execute_with_impersonation_service_chain_one_element(
        self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
    ):
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        self.gke_op.impersonation_chain = ["test_account@example.com"]
        self.gke_op.execute(context=mock.MagicMock())

        fetch_cluster_info_mock.assert_called_once()

    @pytest.mark.db_test
    def test_default_gcp_conn_id(self):
        gke_op = GKEStartKueueJobOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            name=TASK_NAME,
            namespace=NAMESPACE,
            image=IMAGE,
            queue_name=QUEUE_NAME,
        )
        gke_op._cluster_url = CLUSTER_URL
        gke_op._ssl_ca_cert = SSL_CA_CERT
        hook = gke_op.hook

        assert hook.gcp_conn_id == "google_cloud_default"

    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
        return_value=Connection(conn_id="test_conn"),
    )
    def test_gcp_conn_id(self, get_con_mock):
        gke_op = GKEStartKueueJobOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            name=TASK_NAME,
            namespace=NAMESPACE,
            image=IMAGE,
            gcp_conn_id="test_conn",
            queue_name=QUEUE_NAME,
        )
        gke_op._cluster_url = CLUSTER_URL
        gke_op._ssl_ca_cert = SSL_CA_CERT
        hook = gke_op.hook

        assert hook.gcp_conn_id == "test_conn"
tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
@@ -23,11 +23,15 @@
import os
from datetime import datetime

from kubernetes.client import models as k8s

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.kubernetes_engine import (
    GKECreateClusterOperator,
    GKECreateCustomResourceOperator,
    GKEDeleteClusterOperator,
    GKEStartKueueInsideClusterOperator,
    GKEStartKueueJobOperator,
)

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
@@ -38,6 +42,44 @@
CLUSTER_NAME = f"cluster-name-test-kueue-{ENV_ID}".replace("_", "-")
CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1, "autopilot": {"enabled": True}}

flavor_conf = """
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
  name: default-flavor
"""
cluster_conf = """
apiVersion: kueue.x-k8s.io/v1beta1
kind: ClusterQueue
metadata:
  name: cluster-queue
spec:
  queueingStrategy: BestEffortFIFO
  resourceGroups:
    - coveredResources: ["cpu", "memory", "nvidia.com/gpu", "ephemeral-storage"]
      flavors:
        - name: "default-flavor"
          resources:
            - name: "cpu"
              nominalQuota: 10
            - name: "memory"
              nominalQuota: 10Gi
            - name: "nvidia.com/gpu"
              nominalQuota: 10
            - name: "ephemeral-storage"
              nominalQuota: 10Gi
"""
QUEUE_NAME = "local-queue"
local_conf = f"""
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
  namespace: default  # LocalQueue in the default namespace
  name: {QUEUE_NAME}
spec:
  clusterQueue: cluster-queue  # Points to the ClusterQueue defined above
"""

with DAG(
    DAG_ID,
    schedule="@once",  # Override to match your needs
@@ -62,14 +104,70 @@
    )
    # [END howto_operator_gke_install_kueue]

    create_resource_flavor = GKECreateCustomResourceOperator(
        task_id="create_resource_flavor",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        yaml_conf=flavor_conf,
        custom_resource_definition=True,
        namespaced=False,
    )
    create_cluster_queue = GKECreateCustomResourceOperator(
        task_id="create_cluster_queue",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        yaml_conf=cluster_conf,
        custom_resource_definition=True,
        namespaced=False,
    )
    create_local_queue = GKECreateCustomResourceOperator(
        task_id="create_local_queue",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        yaml_conf=local_conf,
        custom_resource_definition=True,
    )

    # [START howto_operator_kueue_start_job]
    kueue_job_task = GKEStartKueueJobOperator(
        task_id="kueue_job_task",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        queue_name=QUEUE_NAME,
        namespace="default",
        parallelism=3,
        image="perl:5.34.0",
        cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
        name="test-pi",
        container_resources=k8s.V1ResourceRequirements(
            requests={
                "cpu": 1,
                "memory": "200Mi",
            },
        ),
    )
    # [END howto_operator_kueue_start_job]

    delete_cluster = GKEDeleteClusterOperator(
        task_id="delete_cluster",
        name=CLUSTER_NAME,
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
    )

    (
        create_cluster
        >> add_kueue_cluster
        >> create_resource_flavor
        >> create_cluster_queue
        >> create_local_queue
        >> kueue_job_task
        >> delete_cluster
    )

    from tests.system.utils.watcher import watcher

