-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
42aff5b
commit 86926cc
Showing
10 changed files
with
153 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# CI workflow: run the pyspark package's test suite on every push.
name: pyspark
on: push
jobs:
  gcp:
    runs-on: ubuntu-latest
    # Runs in the "gcp" GitHub environment so environment-scoped secrets resolve.
    environment: gcp
    steps:
      - uses: actions/checkout@main
      # Buildpack action that installs the poetry project and runs its tests.
      # NOTE(review): both actions are pinned to @main — consider pinning to a
      # tagged release for reproducible builds.
      - uses: davidkhala/poetry-buildpack@main
        with:
          tests: py/tests
          test-entry-point: pytest
          working-directory: py
        env:
          # Service-account private key consumed by the GCP-backed tests.
          PRIVATE_KEY: ${{ secrets.PRIVATE_KEY }}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,3 @@ | ||
from pyspark.sql import SparkSession | ||
|
||
class Regular:
    """Factory for a plain (non-Connect) Spark session.

    Visit https://spark.apache.org/docs/latest/sql-getting-started.html#starting-point-sparksession for creating regular Spark Session
    """

    @staticmethod
    def session():
        """Return the active SparkSession, creating one if none exists yet."""
        builder = SparkSession.builder
        return builder.getOrCreate()
|
||
|
||
class SessionDecorator:
    """Thin convenience wrapper around an existing SparkSession."""

    spark: SparkSession

    def __init__(self, spark):
        # Keep a reference to the session being wrapped.
        self.spark: SparkSession = spark

    def disconnect(self):
        """Stop the wrapped SparkSession."""
        self.spark.stop()

    @property
    def schema(self) -> str:
        """
        :return: current schema full name
        """
        catalog = self.spark.catalog
        return f"{catalog.currentCatalog()}.{catalog.currentDatabase()}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from pyspark import SparkContext, SparkConf | ||
from datetime import datetime | ||
|
||
class Wrapper(SparkContext):
    """Delegating wrapper around an already-created SparkContext.

    NOTE(review): this subclasses SparkContext but never calls
    super().__init__(); inherited attribute access is instead served by
    __getattr__ delegation to the wrapped instance — confirm intentional.
    """
    # The wrapped, already-initialized SparkContext.
    sc:SparkContext
    def __init__(self, sc:SparkContext):
        self.sc = sc
    def __getattr__(self, name):
        # Delegate unknown attributes/methods to the wrapped instance
        return getattr(self.sc, name)
    def disconnect(self):
        """Stop the wrapped SparkContext."""
        self.sc.stop()

    @property
    def startTime(self):
        # sc.startTime is an epoch timestamp in milliseconds (hence /1000
        # before datetime.fromtimestamp, which expects seconds).
        return datetime.fromtimestamp(self.sc.startTime/1000)
    @property
    def appTime(self)-> int:
        """
        assume local spark app, not YARN
        :return: nanoseconds since unix epoch
        """
        # Local-mode application ids look like 'local-<timestamp>'.
        assert self.sc.applicationId.startswith('local-')
        # NOTE(review): named epoch_nano, but confirm the unit — the id
        # timestamp may well be milliseconds, like sc.startTime.
        epoch_nano = int(self.sc.applicationId[6:])
        # Sanity check: the app id timestamp should postdate context start.
        assert epoch_nano > self.sc.startTime
        return epoch_nano
|
||
|
||
def getOrCreate(conf: SparkConf = None)->Wrapper:
    """
    Return the singleton SparkContext wrapped in a Wrapper.

    :param conf: optional SparkConf. A fresh config is created per call when
        omitted; the previous signature default ``conf=SparkConf()`` was
        evaluated once at import time and shared (mutably) across all calls.
    :return: the (possibly pre-existing) SparkContext wrapped as Wrapper.
    """
    if conf is None:
        conf = SparkConf()
    return Wrapper(SparkContext.getOrCreate(conf))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,20 @@ | ||
from _typeshed import DataclassInstance | ||
from dataclasses import dataclass, asdict | ||
from davidkhala.gcp.auth.options import ServiceAccountInfo | ||
from typing import TypedDict | ||
|
||
from davidkhala.spark import SessionDecorator | ||
|
||
|
||
@dataclass | ||
class AuthOptions(DataclassInstance): | ||
# TODO migrate to https://github.com/davidkhala/gcp-collections
# Camel-cased GCP service-account credential bundle; the keys mirror the
# snake_case fields of a service-account JSON (see from_service_account).
AuthOptions = TypedDict('AuthOptions', {
    'clientId': str,
    'clientEmail': str,
    'privateKey': str,
    'privateKeyId': str,
    'projectId': str,
})
|
||
def to_dict(self): | ||
return asdict(self) | ||
|
||
class Session(SessionDecorator): | ||
projectId: str | ||
def from_service_account(info: ServiceAccountInfo) -> AuthOptions:
    """Map snake_case service-account fields onto the camelCase AuthOptions keys."""
    field_map = {
        'clientId': 'client_id',
        'clientEmail': 'client_email',
        'privateKey': 'private_key',
        'privateKeyId': 'private_key_id',
        'projectId': 'project_id',
    }
    return AuthOptions(**{camel: info.get(snake) for camel, snake in field_map.items()})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from pyspark import SparkConf | ||
from pyspark.errors import IllegalArgumentException | ||
from pyspark.sql import SparkSession | ||
|
||
class Wrapper:
    """Convenience wrapper that delegates to an underlying SparkSession."""

    spark: SparkSession

    def __init__(self, spark):
        # Hold on to the session being wrapped.
        self.spark: SparkSession = spark

    def disconnect(self):
        """Stop the wrapped SparkSession."""
        self.spark.stop()

    @property
    def schema(self) -> str:
        """
        :return: current schema full name
        """
        catalog = self.spark.catalog
        return f"{catalog.currentCatalog()}.{catalog.currentDatabase()}"

    def __getattr__(self, name):
        # Anything not defined on the wrapper is served by the wrapped session.
        return getattr(self.spark, name)
class Regular(Wrapper, SparkSession):
    @property
    def appName(self):
        """
        :return: the configured application name, or None when
            'spark.app.name' is unset.
        """
        try:
            return self.spark.conf.get("spark.app.name")
        except IllegalArgumentException as e:
            # Spark raises when the property has no value; recognise that
            # exact message (first line only) and translate it into None.
            # NOTE(review): brittle — the message text may change across
            # Spark versions; confirm against the targeted release.
            if str(e).splitlines()[0] == "The value of property spark.app.name must not be null":
                return
            else:
                raise e
def regular(*, name: str = None, conf: SparkConf = None)->Regular:
    """
    Visit https://spark.apache.org/docs/latest/sql-getting-started.html#starting-point-sparksession for creating regular Spark Session

    :param name: optional application name to set on the builder.
    :param conf: optional SparkConf. A fresh config is created per call when
        omitted; the previous default ``conf=SparkConf()`` was built once at
        import time and shared (mutably) across all calls.
    :return: the (singleton) SparkSession wrapped as Regular.
    """
    if conf is None:
        conf = SparkConf()
    builder = SparkSession.builder.config(conf=conf)
    if name:
        builder.appName(name)
    return Regular(builder.getOrCreate())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
import unittest | ||
|
||
from davidkhala.gcp.auth import OptionsInterface | ||
from davidkhala.gcp.auth.options import from_service_account | ||
|
||
from davidkhala.spark.gcp import AuthOptions | ||
|
||
|
||
class PubsubTestCase(unittest.TestCase):
    """Smoke-tests GCP authentication built from CI-provided credentials."""

    # Credentials are injected via environment variables (see the workflow).
    auth = AuthOptions(
        clientId=os.environ.get('CLIENT_ID'),
        privateKey=os.environ.get('PRIVATE_KEY'),
        clientEmail=os.environ.get('CLIENT_EMAIL'),
        privateKeyId=os.environ.get('PRIVATE_KEY_ID'),
        projectId=os.environ.get('PROJECT_ID'),
    )

    def test_auth(self):
        """Credentials assembled from the env should yield a fetchable token."""
        credentials = from_service_account(
            client_email=self.auth.get('clientEmail'),
            private_key=self.auth.get('privateKey'),
            project_id=self.auth.get('projectId'),
        )
        # Invoke the token property's getter directly on the options object.
        OptionsInterface.token.fget(credentials)
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,29 @@ | ||
import datetime | ||
import importlib | ||
import unittest | ||
|
||
from pyspark import SparkContext | ||
|
||
class CommonTestCase(unittest.TestCase): | ||
from spark.session import regular, Wrapper | ||
from spark.context import Wrapper as SCWrapper, getOrCreate | ||
from datetime import datetime | ||
class SyntaxTestCase(unittest.TestCase): | ||
def test_import(self): | ||
common = importlib.import_module('davidkhala.spark') | ||
from davidkhala.spark import Regular | ||
self.assertTrue(isinstance(common.Regular(), Regular)) | ||
|
||
|
||
self.assertTrue(isinstance(common.Decorator(), Wrapper)) | ||
|
||
def test_session(self):
    """regular() should yield a live session with Spark's default app name."""
    session = regular()
    # The wrapper delegates sparkContext to a real SparkContext.
    self.assertTrue(isinstance(session.sparkContext, SparkContext))
    # "pyspark-shell" is the default app name when none is configured.
    self.assertEqual("pyspark-shell", session.appName)
def test_context(self):
    """getOrCreate() should return a wrapper exposing sane local-mode defaults."""
    sc = getOrCreate()
    # Wrapper.startTime converts the epoch-ms value to a datetime, so it
    # must compare as earlier than "now".
    self.assertLess(sc.startTime, datetime.now())
    # Default master for a locally created context.
    self.assertEqual('local[*]', sc.master)
    print(sc.defaultParallelism)
    self.assertEqual(2, sc.defaultMinPartitions)
    print(sc.appTime)
|
||
if __name__ == '__main__': | ||
unittest.main() |