Fix Google Mypy Dataproc errors (#20570)
Part of #19891
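The errors are fixed with two recurring patterns: builder methods that were being called with possibly-`None` values now declare `Optional` parameters with `None` defaults and keep a runtime guard, and the operators work through a local, non-`Optional` builder variable instead of the `Optional[DataProcJobBuilder]` attribute. A minimal sketch of the first pattern (the class and field names below are illustrative, not the actual `DataProcJobBuilder`):

```python
from typing import Dict, List, Optional


class JobBuilderSketch:
    """Illustrative stand-in for DataProcJobBuilder."""

    def __init__(self, job_type: str = "pyspark_job") -> None:
        self.job: Dict[str, dict] = {"job": {job_type: {}, "labels": {}}}
        self.job_type = job_type

    # Before: `jars: List[str]` -- mypy rejects call sites that may pass None.
    # After: Optional with a None default, plus the existing runtime guard.
    def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None:
        if jars is not None:
            self.job["job"][self.job_type]["jar_file_uris"] = jars
```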
potiuk authored Dec 30, 2021
1 parent a6e60ce commit bd9e8ce
Showing 2 changed files with 99 additions and 82 deletions.
26 changes: 13 additions & 13 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -69,7 +69,7 @@ def __init__(
if properties is not None:
self.job["job"][job_type]["properties"] = properties

def add_labels(self, labels: dict) -> None:
def add_labels(self, labels: Optional[dict] = None) -> None:
"""
Set labels for Dataproc job.
@@ -79,17 +79,17 @@ def add_labels(self, labels: dict) -> None:
if labels:
self.job["job"]["labels"].update(labels)

def add_variables(self, variables: List[str]) -> None:
def add_variables(self, variables: Optional[Dict] = None) -> None:
"""
Set variables for Dataproc job.
:param variables: Variables for the job query.
:type variables: List[str]
:type variables: Dict
"""
if variables is not None:
self.job["job"][self.job_type]["script_variables"] = variables

def add_args(self, args: List[str]) -> None:
def add_args(self, args: Optional[List[str]] = None) -> None:
"""
Set args for Dataproc job.
@@ -99,12 +99,12 @@ def add_args(self, args: List[str]) -> None:
if args is not None:
self.job["job"][self.job_type]["args"] = args

def add_query(self, query: List[str]) -> None:
def add_query(self, query: str) -> None:
"""
Set query uris for Dataproc job.
Set query for Dataproc job.
:param query: URIs for the job queries.
:type query: List[str]
:param query: query for the job.
:type query: str
"""
self.job["job"][self.job_type]["query_list"] = {'queries': [query]}

@@ -117,7 +117,7 @@ def add_query_uri(self, query_uri: str) -> None:
"""
self.job["job"][self.job_type]["query_file_uri"] = query_uri

def add_jar_file_uris(self, jars: List[str]) -> None:
def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None:
"""
Set jars uris for Dataproc job.
@@ -127,7 +127,7 @@ def add_jar_file_uris(self, jars: List[str]) -> None:
if jars is not None:
self.job["job"][self.job_type]["jar_file_uris"] = jars

def add_archive_uris(self, archives: List[str]) -> None:
def add_archive_uris(self, archives: Optional[List[str]] = None) -> None:
"""
Set archives uris for Dataproc job.
@@ -137,7 +137,7 @@ def add_archive_uris(self, archives: List[str]) -> None:
if archives is not None:
self.job["job"][self.job_type]["archive_uris"] = archives

def add_file_uris(self, files: List[str]) -> None:
def add_file_uris(self, files: Optional[List[str]] = None) -> None:
"""
Set file uris for Dataproc job.
@@ -147,7 +147,7 @@ def add_file_uris(self, files: List[str]) -> None:
if files is not None:
self.job["job"][self.job_type]["file_uris"] = files

def add_python_file_uris(self, pyfiles: List[str]) -> None:
def add_python_file_uris(self, pyfiles: Optional[List[str]] = None) -> None:
"""
Set python file uris for Dataproc job.
@@ -157,7 +157,7 @@ def add_python_file_uris(self, pyfiles: List[str]) -> None:
if pyfiles is not None:
self.job["job"][self.job_type]["python_file_uris"] = pyfiles

def set_main(self, main_jar: Optional[str], main_class: Optional[str]) -> None:
def set_main(self, main_jar: Optional[str] = None, main_class: Optional[str] = None) -> None:
"""
Set Dataproc main class.
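Among the hook changes, `add_query` is a semantic correction as much as a typing one: the old annotation claimed `List[str]`, but the method has always wrapped a single query string in a one-element list, so `str` is the honest parameter type. A standalone sketch of that behavior (the class below is illustrative, not the real builder):

```python
from typing import Dict


class QuerySketch:
    """Illustrative only -- not the actual DataProcJobBuilder."""

    def __init__(self, job_type: str = "hive_job") -> None:
        self.job: Dict[str, dict] = {"job": {job_type: {}}}
        self.job_type = job_type

    def add_query(self, query: str) -> None:
        # One query string in, a one-element query list out, matching
        # Dataproc's QueryList message with its repeated `queries` field.
        self.job["job"][self.job_type]["query_list"] = {"queries": [query]}


sketch = QuerySketch()
sketch.add_query("SELECT 1")
assert sketch.job["job"]["hive_job"]["query_list"] == {"queries": ["SELECT 1"]}
```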
155 changes: 86 additions & 69 deletions airflow/providers/google/cloud/operators/dataproc.py
Expand Up @@ -1036,23 +1036,30 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
self.project_id = self.hook.project_id if project_id is None else project_id
self.job_template = None
self.job = None
self.job_template: Optional[DataProcJobBuilder] = None
self.job: Optional[dict] = None
self.dataproc_job_id = None
self.asynchronous = asynchronous

def create_job_template(self):
def create_job_template(self) -> DataProcJobBuilder:
"""Initialize `self.job_template` with default values"""
self.job_template = DataProcJobBuilder(
if self.project_id is None:
raise AirflowException(
"project id should either be set via project_id "
"parameter or retrieved from the connection,"
)
job_template = DataProcJobBuilder(
project_id=self.project_id,
task_id=self.task_id,
cluster_name=self.cluster_name,
job_type=self.job_type,
properties=self.dataproc_properties,
)
self.job_template.set_job_name(self.job_name)
self.job_template.add_jar_file_uris(self.dataproc_jars)
self.job_template.add_labels(self.labels)
job_template.set_job_name(self.job_name)
job_template.add_jar_file_uris(self.dataproc_jars)
job_template.add_labels(self.labels)
self.job_template = job_template
return job_template

def _generate_job_template(self) -> str:
if self.job_template:
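Returning the builder from `create_job_template` and binding it to a local is what makes the later calls type-check: the attribute is declared `Optional[DataProcJobBuilder]`, and mypy will not narrow an attribute across method boundaries, whereas a local variable is unconditionally a builder. A generic standalone illustration of the difference (names are placeholders, not Airflow's):

```python
from typing import Optional


class Builder:
    def set_name(self, name: str) -> None:
        self.name = name


class Operator:
    def __init__(self) -> None:
        self.template: Optional[Builder] = None

    def create_template(self) -> Builder:
        template = Builder()
        self.template = template
        return template

    def execute(self) -> None:
        # Non-Optional local: every later call type-checks without casts.
        template = self.create_template()
        template.set_name("ok")

        # By contrast, mypy cannot see that create_template() assigned the
        # attribute, so calling through it would be rejected:
        #   self.template.set_name("broken")
        #   error: Item "None" of "Optional[Builder]" has no attribute "set_name"
```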
@@ -1180,23 +1187,26 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
job_template = self.create_job_template()

if self.query is None:
self.job_template.add_query_uri(self.query_uri)
if self.query_uri is None:
raise AirflowException('One of query or query_uri should be set here')
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)
job_template.add_query(self.query)
job_template.add_variables(self.variables)
return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()

job_template = self.create_job_template()
if self.query is None:
self.job_template.add_query_uri(self.query_uri)
if self.query_uri is None:
raise AirflowException('One of query or query_uri should be set here')
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)
job_template.add_query(self.query)
job_template.add_variables(self.variables)

super().execute(context)
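The added `raise` before `add_query_uri` does double duty: it gives a clear error when neither `query` nor `query_uri` was supplied, and it narrows `self.query_uri` from `Optional[str]` to `str` for mypy, since `add_query_uri` expects a plain string. A condensed sketch of that control flow (exception message as in the diff; the helper itself is illustrative):

```python
from typing import Optional


class AirflowException(Exception):
    """Stand-in for airflow.exceptions.AirflowException."""


def resolve_query_source(query: Optional[str], query_uri: Optional[str]) -> str:
    # Mirrors the operator logic: prefer the inline query, fall back to the
    # URI, and fail loudly if the user supplied neither.
    if query is None:
        if query_uri is None:
            raise AirflowException('One of query or query_uri should be set here')
        return query_uri  # narrowed from Optional[str] to str
    return query
```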

@@ -1256,22 +1266,25 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
job_template = self.create_job_template()
if self.query is None:
self.job_template.add_query_uri(self.query_uri)
if self.query_uri is None:
raise AirflowException('One of query or query_uri should be set here')
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)
job_template.add_query(self.query)
job_template.add_variables(self.variables)
return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()
job_template = self.create_job_template()
if self.query is None:
self.job_template.add_query_uri(self.query_uri)
if self.query_uri is None:
raise AirflowException('One of query or query_uri should be set here')
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)

job_template.add_query(self.query)
job_template.add_variables(self.variables)
super().execute(context)


@@ -1330,22 +1343,23 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
job_template = self.create_job_template()
if self.query is None:
self.job_template.add_query_uri(self.query_uri)
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)
job_template.add_query(self.query)
job_template.add_variables(self.variables)
return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()
job_template = self.create_job_template()
if self.query is None:
self.job_template.add_query_uri(self.query_uri)
if self.query_uri is None:
raise AirflowException('One of query or query_uri should be set here')
job_template.add_query_uri(self.query_uri)
else:
self.job_template.add_query(self.query)
self.job_template.add_variables(self.variables)

job_template.add_query(self.query)
job_template.add_variables(self.variables)
super().execute(context)


@@ -1411,20 +1425,19 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
self.job_template.set_main(self.main_jar, self.main_class)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)
job_template = self.create_job_template()
job_template.set_main(self.main_jar, self.main_class)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()
self.job_template.set_main(self.main_jar, self.main_class)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)

job_template = self.create_job_template()
job_template.set_main(self.main_jar, self.main_class)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
super().execute(context)


@@ -1490,20 +1503,19 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
self.job_template.set_main(self.main_jar, self.main_class)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)
job_template = self.create_job_template()
job_template.set_main(self.main_jar, self.main_class)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()
self.job_template.set_main(self.main_jar, self.main_class)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)

job_template = self.create_job_template()
job_template.set_main(self.main_jar, self.main_class)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
super().execute(context)


@@ -1594,24 +1606,24 @@ def generate_job(self):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
self.create_job_template()
job_template = self.create_job_template()
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
bucket = cluster_info['config']['config_bucket']
self.main = f"gs://{bucket}/{self.main}"
self.job_template.set_python_main(self.main)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)
self.job_template.add_python_file_uris(self.pyfiles)
job_template.set_python_main(self.main)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
job_template.add_python_file_uris(self.pyfiles)

return self._generate_job_template()

def execute(self, context: 'Context'):
self.create_job_template()
job_template = self.create_job_template()
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
@@ -1620,12 +1632,11 @@ def execute(self, context: 'Context'):
bucket = cluster_info['config']['config_bucket']
self.main = self._upload_file_temp(bucket, self.main)

self.job_template.set_python_main(self.main)
self.job_template.add_args(self.arguments)
self.job_template.add_archive_uris(self.archives)
self.job_template.add_file_uris(self.files)
self.job_template.add_python_file_uris(self.pyfiles)

job_template.set_python_main(self.main)
job_template.add_args(self.arguments)
job_template.add_archive_uris(self.archives)
job_template.add_file_uris(self.files)
job_template.add_python_file_uris(self.pyfiles)
super().execute(context)
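In the PySpark operator, `generate_job` rewrites a local `main` path into a `gs://` URI under the cluster's staging bucket, while `execute` actually uploads the file via `_upload_file_temp` first. A rough sketch of the path-rewriting step alone (the helper name is illustrative, and the bucket is assumed to come from the cluster's `config_bucket`):

```python
import os


def rewrite_local_main(main: str, config_bucket: str) -> str:
    # A local file is referenced by the GCS path it will occupy in the
    # cluster's staging bucket; remote gs:// URIs pass through untouched.
    if os.path.isfile(main):
        return f"gs://{config_bucket}/{main}"
    return main
```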


@@ -2243,6 +2254,8 @@ def __init__(
def execute(self, context: 'Context'):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Creating batch")
if self.region is None:
raise AirflowException('Region should be set here')
try:
self.operation = hook.create_batch(
region=self.region,
Expand All @@ -2254,10 +2267,14 @@ def execute(self, context: 'Context'):
timeout=self.timeout,
metadata=self.metadata,
)
if self.timeout is None:
raise AirflowException('Timeout should be set here')
result = hook.wait_for_operation(self.timeout, self.operation)
self.log.info("Batch %s created", self.batch_id)
except AlreadyExists:
self.log.info("Batch with given id already exists")
if self.batch_id is None:
raise AirflowException('Batch Id should be set here')
result = hook.get_batch(
batch_id=self.batch_id,
region=self.region,
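The batch operator gains the same style of guard three times over: `region`, `timeout`, and `batch_id` are all `Optional` on the operator, but the hook calls that consume them need concrete values, so each is checked immediately before use. The idea generalizes to a small helper like the following (a sketch, not part of the Airflow codebase):

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def require(value: Optional[T], name: str) -> T:
    # Fail early with a readable message instead of letting None leak into
    # an API call; the return type tells mypy the value is now concrete.
    if value is None:
        raise ValueError(f"{name} should be set here")
    return value


# Hypothetical usage inside execute():
#   result = hook.wait_for_operation(require(self.timeout, "Timeout"), self.operation)
```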
