In [None]:
from pyspark.sql import functions as F

Load the `pmt_db.financial` table from the data catalog into a DataFrame

In [None]:
# Read every column of the catalog table pmt_db.financial into a Spark DataFrame.
# (`spark` is not defined in this notebook; presumably supplied by the runtime.)
df = spark.sql(
    "select * from pmt_db.financial"
)

In [None]:
# Print the column names and types, then preview the first 10 rows.
df.printSchema()
df.show(n=10)

Write a 10-row sample to the S3 bucket as CSV

In [None]:
# Save a 10-row sample of the financial data as CSV with a header row.
# NOTE(review): "overwrite" mode replaces existing data at the target path, and
# the table created later in this notebook is located under this same prefix
# (.../output/gov) — confirm that re-running this cell is safe.
(
    df.limit(10)
    .write
    .format("csv")
    .option("header", "true")
    .mode("overwrite")
    .save('s3://athena-spark-workshop/output/')
)


Analyze data for the Government segment

In [None]:
# Keep only Government-segment rows, lower-case the country names, and project
# the columns used by the analysis, sorted by country.
df_gov = (
    df.where(F.col("segment") == "Government")
    .withColumn("country", F.lower(F.col("country")))
    .select("country", "product", "sales", "month_number", "year")
    .sort("country")
)

In [None]:
# Preview up to 10 rows of the Government subset.
(
    df_gov
    .limit(10)
    .show()
)

In [None]:
# Sales statistics per (country, product, year); every aggregate is cast to
# decimal(15,2) so the column types match the table definition used below.
df_gov_agg = df_gov.groupBy("country", "product", "year").agg(
    F.sum(F.col("sales")).cast("decimal(15,2)").alias("total_sales"),
    F.avg(F.col("sales")).cast("decimal(15,2)").alias("avg_sales"),
    F.max(F.col("sales")).cast("decimal(15,2)").alias("max_sales"),
    F.min(F.col("sales")).cast("decimal(15,2)").alias("min_sales"),
)

In [None]:
# Preview up to 10 rows of the aggregated Government sales.
(
    df_gov_agg
    .limit(10)
    .show()
)

# Create a table in the AWS Glue Data Catalog so the data can also be queried from the Athena Query Editor.

# NOTE: Remember to load the table's partitions in the Athena Query Editor before querying it (for example, with `MSCK REPAIR TABLE default.gov`).


In [None]:
%%sql
-- Register default.gov (only if it does not exist yet): aggregated sales columns,
-- partitioned by year, read with the Parquet Hive SerDe from the S3 location below.
CREATE TABLE IF NOT EXISTS default.gov (
    country     STRING,
    product     STRING,
    total_sales DECIMAL(15,2),
    avg_sales   DECIMAL(15,2),
    max_sales   DECIMAL(15,2),
    min_sales   DECIMAL(15,2)
)
PARTITIONED BY (year BIGINT)
ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
LOCATION 's3://athena-spark-workshop/output/gov'


# Get started building visualizations with Amazon Athena for Apache Spark

#### Use Seaborn to build a visualization of this data


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

pd_canada = (
    df_gov_agg.filter(df_gov_agg.country == "canada")
    .filter(df_gov_agg.year == "2014")
    .toPandas()
)
print(pd_canada)
res = sns.relplot(x='product', y='total_sales', data=pd_canada, kind="line")


%matplot plt

#### Build a visualization using Matplotlib

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.clf()
name = pd_canada['product'].head(12)
price = pd_canada['total_sales'].head(12)
 
# Figure Size
fig = plt.figure(figsize =(10, 7))
 
# Horizontal Bar Plot
plt.bar(name[0:10], price[0:10])
 
# Show Plot
%matplot plt

#### Add the piglatin Python library we created earlier to the PySpark session

In [None]:
# Register the zipped piglatin package with the Spark context so it can be
# imported below; the import must therefore come after addPyFile.
sc.addPyFile('s3://athena-spark-workshop-227972251644/library.zip') #use your path

import piglatin

from pyspark.sql.functions import udf
from pyspark.sql.functions import col

# Sanity-check the library on the driver. The result is printed explicitly
# because a bare expression in the middle of a cell is not displayed.
print(piglatin.translate('hello'))

# Wrap the translator as a Spark UDF (udf() defaults to a string return type).
hi_udf = udf(piglatin.translate)

# Use a dedicated name here: rebinding `df` would clobber the financial
# DataFrame loaded at the top of the notebook and break re-running the
# earlier cells that depend on it.
df_words = spark.createDataFrame([(1, "hello"), (2, "world")])

# No column names were supplied, so the columns are the auto-generated _1, _2.
df_words.withColumn("col", hi_udf(col('_2'))).show()