Pulling Redshift data from SageMaker
A set of high-level functions allowing you to pull data from Amazon Redshift into an Amazon SageMaker notebook instance, via Amazon S3.
Both the botocore and boto3 libraries are required to access AWS services from Python:
- botocore: https://github.com/boto/botocore
- boto3: https://github.com/boto/boto3
import botocore as btc
import boto3 as b3
The CustomWaiter class is here to report the status of an asynchronous AWS operation. Let's have a look at the waiters.py file.
import botocore as btc
import botocore.waiter  # ensure the botocore.waiter submodule is loaded
import boto3 as b3
from enum import Enum
import logging

logger = logging.getLogger(__name__)
class WaitState(Enum):
    SUCCESS = 'success'
    FAILURE = 'failure'

class CustomWaiter:
    """
    Base class for a custom waiter that leverages botocore's waiter code. Waiters
    poll an operation, with a specified delay between each polling attempt, until
    either an accepted result is returned or the maximum number of attempts is reached.

    To use, implement a subclass that passes the specific operation, arguments,
    and acceptors to the superclass.

    For example, to implement a custom waiter for the transcription client that
    waits for both success and failure outcomes of the get_transcription_job function,
    create a class like the following:

        class TranscribeCompleteWaiter(CustomWaiter):
            def __init__(self, client):
                super().__init__(
                    'TranscribeComplete', 'GetTranscriptionJob',
                    'TranscriptionJob.TranscriptionJobStatus',
                    {'COMPLETED': WaitState.SUCCESS, 'FAILED': WaitState.FAILURE},
                    client)

            def wait(self, job_name):
                self._wait(TranscriptionJobName=job_name)
    """
    def __init__(
            self, name, operation, argument, acceptors, client, delay=10, max_tries=60,
            matcher='path'):
        """
        Subclasses should pass specific operations, arguments, and acceptors to
        their superclass.

        :param name: The name of the waiter. This can be any descriptive string.
        :param operation: The operation to wait for. This must match the casing of
                          the underlying operation model, which is typically in
                          CamelCase.
        :param argument: The dict keys used to access the result of the operation, in
                         dot notation. For example, 'Job.Status' will access
                         result['Job']['Status'].
        :param acceptors: The list of acceptors that indicate the wait is over. These
                          can indicate either success or failure. The acceptor values
                          are compared to the result of the operation after the
                          argument keys are applied.
        :param client: The Boto3 client.
        :param delay: The number of seconds to wait between each call to the operation.
        :param max_tries: The maximum number of tries before exiting.
        :param matcher: The kind of matcher to use.
        """
        self.name = name
        self.operation = operation
        self.argument = argument
        self.client = client
        self.waiter_model = btc.waiter.WaiterModel({
            'version': 2,
            'waiters': {
                name: {
                    "delay": delay,
                    "operation": operation,
                    "maxAttempts": max_tries,
                    "acceptors": [{
                        "state": state.value,
                        "matcher": matcher,
                        "argument": argument,
                        "expected": expected
                    } for expected, state in acceptors.items()]
                }}})
        self.waiter = btc.waiter.create_waiter_with_client(
            self.name, self.waiter_model, self.client)

    def __call__(self, parsed, **kwargs):
        """
        Handles the after-call event by logging information about the operation and its
        result.

        :param parsed: The parsed response from polling the operation.
        :param kwargs: Not used, but expected by the caller.
        """
        status = parsed
        for key in self.argument.split('.'):
            if key.endswith('[]'):
                status = status.get(key[:-2])[0]
            else:
                status = status.get(key)
        logger.info(
            "Waiter %s called %s, got %s.", self.name, self.operation, status)

    def _wait(self, **kwargs):
        """
        Registers for the after-call event and starts the botocore wait loop.

        :param kwargs: Keyword arguments that are passed to the operation being polled.
        """
        event_name = f'after-call.{self.client.meta.service_model.service_name}'
        self.client.meta.events.register(event_name, self)
        self.waiter.wait(**kwargs)
        self.client.meta.events.unregister(event_name, self)
class ExecuteStatementWaiter(CustomWaiter):
    def __init__(self, client):
        super().__init__(
            name='ExecuteStatementComplete',
            operation='DescribeStatement',
            argument='Status',
            acceptors={
                'COMPLETED': WaitState.SUCCESS,
                'FAILED': WaitState.FAILURE,
                'FINISHED': WaitState.SUCCESS,
                'ABORTED': WaitState.FAILURE
            },
            client=client)

    def wait(self, query_id):
        self._wait(Id=query_id)
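To make the polling flow concrete, here is a minimal sketch of driving this waiter by hand, assuming a hypothetical statement submitted through the redshift-data client (the region, cluster, database, and user names below are all placeholders):

import boto3 as b3
import waiters as wt

client = b3.client('redshift-data', region_name='eu-west-1')  # placeholder region
res = client.execute_statement(
    ClusterIdentifier='my-cluster',  # placeholder cluster
    Database='dev',                  # placeholder database
    DbUser='awsuser',                # placeholder database user
    Sql='select 1;')
# Polls DescribeStatement every 10 seconds (the default delay) until the
# Status field reaches FINISHED (success) or FAILED/ABORTED (failure).
wt.ExecuteStatementWaiter(client=client).wait(query_id=res['Id'])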
def get_waiter_config(name, delay, max_attempts):
    # Standalone waiter definition. Note that the polled operation must be
    # DescribeStatement (not ExecuteStatement): only its response carries
    # the Status field that the acceptors inspect.
    return {
        'version': 2,
        'waiters': {
            name: {
                'operation': 'DescribeStatement',
                'delay': delay,
                'maxAttempts': max_attempts,
                'acceptors': [
                    {
                        'matcher': 'path',
                        'expected': 'FINISHED',
                        'argument': 'Status',
                        'state': 'success'
                    },
                    {
                        'matcher': 'path',
                        'expected': 'PICKED',
                        'argument': 'Status',
                        'state': 'retry'
                    },
                    {
                        'matcher': 'path',
                        'expected': 'STARTED',
                        'argument': 'Status',
                        'state': 'retry'
                    },
                    {
                        'matcher': 'path',
                        'expected': 'SUBMITTED',
                        'argument': 'Status',
                        'state': 'retry'
                    },
                    {
                        'matcher': 'path',
                        'expected': 'FAILED',
                        'argument': 'Status',
                        'state': 'failure'
                    },
                    {
                        'matcher': 'path',
                        'expected': 'ABORTED',
                        'argument': 'Status',
                        'state': 'failure'
                    }
                ],
            },
        },
    }
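As an alternative to subclassing CustomWaiter, the configuration returned by get_waiter_config can be handed straight to botocore's waiter machinery. A minimal sketch, assuming a redshift-data client and a statement id already obtained from execute_statement:

import botocore.waiter as btw

config = get_waiter_config('ExecuteStatementComplete', delay=5, max_attempts=60)
model = btw.WaiterModel(config)
waiter = btw.create_waiter_with_client('ExecuteStatementComplete', model, client)
# Blocks until an acceptor matches, or raises WaiterError after max_attempts.
waiter.wait(Id=query_id)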
The function to be used is named redshift_query_to_pandas_via_s3 and will perform the following steps:
- Execute an unload query against a Redshift database. The unload query will instruct Redshift to dump the data (into a specific S3 location) as a set of gzipped csv files.
- Load the previously created set of gzipped csv files into a pandas dataframe.
The connection to Redshift is secured by the ability to access a secret in AWS Secrets Manager that grants access to the database (instead of passing credentials explicitly). The permissions required to copy to and from S3 are driven by IAM role-based permissions.
import botocore as btc
import boto3 as b3
import base64
import json
import operator as op
import pandas as pd
import numpy as np
import datetime as dt
# "waiters.py" is the module shown above,
# containing the CustomWaiter class definition
import waiters as wt
def get_secret(secret_name, region_name):
    secret = None
    secret_arn = None

    # Create a Secrets Manager client
    session = b3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name)

    # In this sample we only handle the specific exceptions for the 'GetSecretValue' API.
    # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
    # We rethrow the exception by default.
    try:
        get_secret_value_response = client.get_secret_value(SecretId=secret_name)
        secret_arn = get_secret_value_response['ARN']
    except btc.exceptions.ClientError as e:
        if e.response['Error']['Code'] == 'DecryptionFailureException':
            # Secrets Manager can't decrypt the protected secret text using the provided KMS key.
            # Deal with the exception here, and/or rethrow at your discretion.
            raise e
        elif e.response['Error']['Code'] == 'InternalServiceErrorException':
            # An error occurred on the server side.
            # Deal with the exception here, and/or rethrow at your discretion.
            raise e
        elif e.response['Error']['Code'] == 'InvalidParameterException':
            # You provided an invalid value for a parameter.
            # Deal with the exception here, and/or rethrow at your discretion.
            raise e
        elif e.response['Error']['Code'] == 'InvalidRequestException':
            # You provided a parameter value that is not valid for the current state of the resource.
            # Deal with the exception here, and/or rethrow at your discretion.
            raise e
        elif e.response['Error']['Code'] == 'ResourceNotFoundException':
            # We can't find the resource that you asked for.
            # Deal with the exception here, and/or rethrow at your discretion.
            raise e
    else:
        # Decrypts secret using the associated KMS CMK.
        # Depending on whether the secret is a string or binary, one of these fields will be populated.
        if 'SecretString' in get_secret_value_response:
            secret = get_secret_value_response['SecretString']
        else:
            secret = base64.b64decode(get_secret_value_response['SecretBinary'])

    return secret, secret_arn
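For illustration, a Redshift secret retrieved this way is typically a JSON string. The layout sketched in the comment below is an assumption based on the standard Secrets Manager format for Redshift clusters (all names and values are placeholders); only the dbClusterIdentifier key is actually used by the code in this post:

secret, secret_arn = get_secret('my-redshift-secret', 'eu-west-1')  # placeholder names
creds = json.loads(secret)
# creds is expected to look roughly like:
# {'username': 'awsuser', 'password': '<redacted>', 'engine': 'redshift',
#  'host': 'my-cluster.xxxx.eu-west-1.redshift.amazonaws.com',
#  'port': 5439, 'dbClusterIdentifier': 'my-cluster'}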
def unpack_redshift_secret(secret):
    return json.loads(secret)['dbClusterIdentifier']
def get_redshift_client(secret_name, region_name, database_name):
    secret, secret_arn = get_secret(secret_name, region_name)
    cluster_id = unpack_redshift_secret(secret)
    return {
        'client': b3.Session(
            botocore_session=btc.session.get_session(),
            region_name=region_name).client('redshift-data'),
        'secret': secret,
        'secret_arn': secret_arn,
        'cluster_id': cluster_id,
        'secret_name': secret_name,
        'region_name': region_name,
        'database_name': database_name
    }
def get_unload_query(query, location, role):
    return "unload('{0}') to '{1}' iam_role '{2}' format as CSV header ALLOWOVERWRITE GZIP;".format(
        query.replace("'", "\\'"),
        location,
        role)
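For example, a simple select wrapped by this helper (query, bucket, and role ARN below are placeholders) produces the following statement, shown wrapped across lines here for readability:

print(get_unload_query(
    query="select * from my_schema.my_table where id = 'abc'",
    location='s3://my-bucket/my-folder/my-prefix',
    role='arn:aws:iam::123456789012:role/my-redshift-role'))
# unload('select * from my_schema.my_table where id = \'abc\'')
#     to 's3://my-bucket/my-folder/my-prefix'
#     iam_role 'arn:aws:iam::123456789012:role/my-redshift-role'
#     format as CSV header ALLOWOVERWRITE GZIP;

Note how single quotes inside the original query are escaped, since the whole query is itself wrapped in single quotes inside the unload statement.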
def execute_non_query(client, query, timeout):
    # Note: `timeout` is currently unused; the waiter's default delay and
    # max_tries bound the total wait time instead.
    custom_waiter = wt.ExecuteStatementWaiter(
        client=client['client'])
    res = client['client'].execute_statement(
        Database=client['database_name'],
        SecretArn=client['secret_arn'],
        ClusterIdentifier=client['cluster_id'],
        Sql=query)
    print('query started')
    query_id = res['Id']
    try:
        custom_waiter.wait(query_id=query_id)
        print('query complete')
    except btc.waiter.WaiterError as e:
        print(e)
        return None
    except Exception as e:
        print(e)
        return None
    return res
def create_folder_in_bucket(bucket_name, folder_name):
    s3 = b3.client('s3')
    s3.put_object(
        Bucket=bucket_name,
        Key=(folder_name + '/'))
def get_dataframe_from_bucket(bucket_name, folder_name, prefix):
    s3 = b3.resource('s3')
    bucket = s3.Bucket(bucket_name)
    path = '{0}/{1}'.format(
        folder_name,
        prefix)
    df_main = None
    for file in bucket.objects.filter(Prefix=path):
        df = pd.read_csv('s3://{0}/{1}'.format(
            file.bucket_name,
            file.key))
        if df_main is None:
            df_main = df
        else:
            # ignore_index avoids duplicated row indices across the unloaded files
            df_main = pd.concat([df_main, df], ignore_index=True)
    return df_main
def redshift_query_to_pandas_via_s3(
        secret_name,
        region_name,
        database_name,
        query,
        bucket_name,
        folder_name,
        prefix,
        role):
    # Create a Redshift Data API client
    redshift = get_redshift_client(
        secret_name=secret_name,
        region_name=region_name,
        database_name=database_name)
    # Compute the UNLOAD query
    unload_query = get_unload_query(
        query=query,
        location='s3://{0}/{1}/{2}'.format(
            bucket_name,
            folder_name,
            prefix),
        role=role)
    # Execute the query.
    # Note that the `timeout` value could well be parameterized.
    res = execute_non_query(
        client=redshift,
        query=unload_query,
        timeout=1200)
    # Load the freshly unloaded csv files back into a single dataframe
    return get_dataframe_from_bucket(
        bucket_name=bucket_name,
        folder_name=folder_name,
        prefix=prefix)
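Putting it all together, a call could look like the following sketch; every identifier below (secret, region, database, query, bucket, folder, prefix and role ARN) is a placeholder to adapt to your environment:

df = redshift_query_to_pandas_via_s3(
    secret_name='my-redshift-secret',                        # placeholder secret
    region_name='eu-west-1',                                 # placeholder region
    database_name='dev',                                     # placeholder database
    query='select * from my_schema.my_table',                # placeholder query
    bucket_name='my-bucket',                                 # placeholder bucket
    folder_name='my-folder',                                 # placeholder folder
    prefix='my-prefix',                                      # placeholder prefix
    role='arn:aws:iam::123456789012:role/my-redshift-role')  # placeholder role ARN
df.head()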