Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Fix error when /tmp is a symlink #221

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 108 additions & 75 deletions greenplumpython/experimental/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import inspect
import io
import pathlib
import sys
import tarfile
import tempfile
import uuid
from typing import get_type_hints
from typing import Any, get_type_hints

import psycopg2
import psycopg2.extensions

import greenplumpython as gp
from greenplumpython.func import NormalFunction
Expand All @@ -15,20 +18,30 @@


@gp.create_function
def _dump_file_chunk(tmp_dir_handle: str, chunk_base64: str) -> str:
    """UDF run on the server: append one base64-encoded chunk to the archive
    ``<tmp_dir_handle>.tar.gz`` inside a server-side temporary directory.

    The ``TemporaryDirectory`` object is pinned in PL/Python's GD so that it
    (and hence the directory itself) survives across UDF calls in the same
    session and can be reused/cleaned up later.

    Returns:
        The path of the server-side temporary directory holding the archive.
    """
    # Under plpython, GD is injected into the UDF's globals; otherwise fall
    # back to the `_GD` attribute that greenplumpython's serializer attaches
    # to the `plpy` module (see func.py `_serialize`).
    try:
        _gd = globals()["GD"]  # type: ignore reportUnknownVariableType
    except KeyError:
        _gd = sys.modules["plpy"]._GD
    if tmp_dir_handle not in _gd:
        # Using tempfile (instead of a hard-coded /tmp path) avoids failures
        # when /tmp is a symlink and gives a unique, auto-cleanable directory.
        server_tmp_dir = tempfile.TemporaryDirectory(prefix="pygp.srv.")
        _gd[tmp_dir_handle] = server_tmp_dir  # Pin to GD for later UDFs
    else:
        server_tmp_dir = _gd[tmp_dir_handle]  # type: ignore reportUnknownVariableType

    server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir.name)  # type: ignore reportUnknownVariableType
    server_tmp_dir_path.mkdir(parents=True, exist_ok=True)
    tmp_archive_path = server_tmp_dir_path / f"{tmp_dir_handle}.tar.gz"
    # Append mode: chunks arrive in order (caller SELECTs ORDER BY id) and are
    # concatenated into one archive file.
    with open(tmp_archive_path, "ab") as tmp_archive:
        tmp_archive.write(base64.b64decode(chunk_base64))
    return server_tmp_dir.name


@gp.create_function
def _extract_files(tmp_archive_name: str, returning: str) -> list[str]:
tmp_archive_base = pathlib.Path("/") / "tmp" / tmp_archive_name
tmp_archive_path = tmp_archive_base / f"{tmp_archive_name}.tar.gz"
extracted_root = tmp_archive_base / "extracted"
def _extract_files(server_tmp_dir: str, tmp_dir_handle: str, returning: str) -> list[str]:
server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir)
tmp_archive_path = server_tmp_dir_path / f"{tmp_dir_handle}.tar.gz"
extracted_root = server_tmp_dir_path / "extracted"
if not extracted_root.exists():
with tarfile.open(tmp_archive_path, "r:gz") as tmp_archive:
extracted_root.mkdir()
Expand All @@ -43,63 +56,82 @@ def _extract_files(tmp_archive_name: str, returning: str) -> list[str]:
yield str(path.resolve())


def _archive_and_upload(tmp_archive_name: str, files: list[str], db: gp.Database):
tmp_archive_base = pathlib.Path("/") / "tmp" / tmp_archive_name
tmp_archive_base.mkdir(exist_ok=True)
tmp_archive_path = tmp_archive_base / f"{tmp_archive_name}.tar.gz"
with tarfile.open(tmp_archive_path, "w:gz") as tmp_archive:
for file_path in files:
tmp_archive.add(pathlib.Path(file_path))
server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None
with psycopg2.connect(db._dsn, options=server_options) as util_conn: # type: ignore reportUnknownVariableType
with util_conn.cursor() as cursor: # type: ignore reportUnknownVariableType
cursor.execute(f"CREATE TEMP TABLE {tmp_archive_name} (id serial, text_base64 text);")
def _remove_tmp_dir(conn: psycopg2.extensions.connection, db: gp.Database, tmp_dir_handle: str):
    """Clean up the server-side temporary directory identified by *tmp_dir_handle*.

    Defines and runs a one-off UDF on the server that looks up the pinned
    ``TemporaryDirectory`` in PL/Python's GD and calls ``cleanup()`` on it.

    Args:
        conn: open utility connection used to execute the cleanup UDF.
        db: database object, used only to serialize the UDF definition.
        tmp_dir_handle: key under which the directory was pinned in GD.
    """

    @gp.create_function
    def udf(tmp_dir_handle: str) -> None:
        # Same GD lookup strategy as _dump_file_chunk: prefer the injected
        # GD, fall back to the `_GD` attribute on the plpy module.
        try:
            _gd = globals()["GD"]  # type: ignore reportUnknownVariableType
        except KeyError:
            _gd = sys.modules["plpy"]._GD
        _gd[tmp_dir_handle].cleanup()

    with conn.cursor() as cursor:
        cursor.execute(udf._serialize(db))
        cursor.execute(f"SELECT {udf._qualified_name_str}('{tmp_dir_handle}');")


def _archive_and_upload(
    util_conn: psycopg2.extensions.connection,
    tmp_dir_handle: str,
    files: list[str],
    db: gp.Database,
) -> str:
    """Pack *files* into a gzipped tar archive locally and upload it to the server.

    The archive is staged in a local ``TemporaryDirectory`` (removed on exit),
    streamed to a server temp table in base64 chunks of ``_CHUNK_SIZE`` bytes,
    and reassembled on the server by ``_dump_file_chunk``.

    Args:
        util_conn: open (utility-mode) connection used for the upload.
        tmp_dir_handle: unique handle; doubles as the temp-table name and the
            GD key for the server-side directory.
        files: paths of the local files to archive.
        db: database object, used to serialize ``_dump_file_chunk``.

    Returns:
        The server-side temporary directory path reported by the last
        ``_dump_file_chunk`` call.
    """
    with tempfile.TemporaryDirectory(prefix="pygp.cln.") as local_tmp_dir:
        local_tmp_dir_path: pathlib.Path = pathlib.Path(local_tmp_dir)
        tmp_archive_path = local_tmp_dir_path / f"{tmp_dir_handle}.tar.gz"
        with tarfile.open(tmp_archive_path, "w:gz") as tmp_archive:
            for file_path in files:
                tmp_archive.add(pathlib.Path(file_path))
        with util_conn.cursor() as cursor:
            # id preserves chunk order so the server can reassemble the archive.
            cursor.execute(f"CREATE TEMP TABLE {tmp_dir_handle} (id serial, text_base64 text);")
            with open(tmp_archive_path, "rb") as tmp_archive:
                while True:
                    chunk = tmp_archive.read(_CHUNK_SIZE)
                    if len(chunk) == 0:
                        break
                    chunk_base64 = base64.b64encode(chunk)
                    cursor.copy_expert(
                        f"COPY {tmp_dir_handle} (text_base64) FROM STDIN",
                        io.BytesIO(chunk_base64),
                    )
            util_conn.commit()
            cursor.execute(_dump_file_chunk._serialize(db))
            # ORDER BY id: chunks must be appended in upload order.
            cursor.execute(
                f"""
                SELECT {_dump_file_chunk._qualified_name_str}('{tmp_dir_handle}', text_base64)
                FROM "{tmp_dir_handle}"
                ORDER BY id;
                """
            )
            # Every row returns the same directory path; take the first.
            return cursor.fetchall()[0][0]


@classmethod
def _from_files(_, files: list[str], parser: NormalFunction, db: gp.Database) -> gp.DataFrame:
    """Build a DataFrame by uploading local *files* and parsing them on the server.

    Archives and uploads the files over a utility connection, then applies
    *parser* to the extracted server-side paths.

    Args:
        _: the class (bound as a classmethod on ``gp.DataFrame``); unused.
        files: local file paths to upload.
        parser: UDF applied to the extracted files on the server.
        db: target database.

    Returns:
        The DataFrame produced by applying *parser*.
    """
    tmp_dir_handle = f"__pygp_tar_{uuid.uuid4().hex}"
    # Greenplum needs a utility-mode session for the direct upload connection.
    server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None
    with psycopg2.connect(db._dsn, options=server_options) as util_conn:  # type: ignore reportUnknownVariableType
        server_tmp_dir = _archive_and_upload(util_conn, tmp_dir_handle, files, db)  # type: ignore reportUnknownArgumentType
        func_sig = inspect.signature(parser.unwrap())
        result_members = get_type_hints(func_sig.return_annotation)
        # Expand the result only when the return annotation declares no members.
        df = db.apply(
            lambda: parser(_extract_files(server_tmp_dir, tmp_dir_handle, "files")),
            expand=len(result_members) == 0,
        )
        # _remove_tmp_dir(util_conn, db, tmp_dir_handle) # Cannot remove now since the returned DataFrame depends on it.
        return df


# Register the helper as the public `DataFrame.from_files` classmethod.
setattr(gp.DataFrame, "from_files", _from_files)


import subprocess as sp
import sys
import subprocess


@gp.create_function
def _install_on_server(pkg_dir: str, requirements: str) -> str:
import subprocess as sp
import sys

def _install_on_server(server_tmp_dir: str, local_tmp_dir: str, requirements: str) -> str:
assert sys.executable, "Python executable is required to install packages."
server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir)
local_tmp_dir_path = pathlib.Path(local_tmp_dir)
cmd = [
sys.executable,
"-m",
Expand All @@ -109,48 +141,49 @@ def _install_on_server(pkg_dir: str, requirements: str) -> str:
"--requirement",
"/dev/stdin",
"--find-links",
pkg_dir,
str(
server_tmp_dir_path
/ "extracted"
/ local_tmp_dir_path.relative_to(local_tmp_dir_path.root)
),
]
try:
output = sp.check_output(cmd, text=True, stderr=sp.STDOUT, input=requirements)
output = subprocess.check_output(
cmd, text=True, stderr=subprocess.STDOUT, input=requirements
)
return output
except sp.CalledProcessError as e:
except subprocess.CalledProcessError as e:
raise Exception(e.stdout)


def _install_packages(db: gp.Database, requirements: str):
    """Download the packages in *requirements* locally and install them on the server.

    Uses ``pip download`` into a local temporary directory, uploads the
    directory as an archive, extracts it on the server, installs from it with
    ``_install_on_server``, and finally removes the server-side directory.

    Args:
        db: target database.
        requirements: requirements-file content, fed to pip via stdin.

    Raises:
        subprocess.CalledProcessError: if ``pip download`` fails; the pip
            output is attached as the exception's ``__cause__``.
    """
    tmp_dir_handle = f"__pygp_tar_{uuid.uuid4().hex}"
    with tempfile.TemporaryDirectory(prefix="pygp.cln.") as local_pkg_dir:
        local_tmp_dir_path = pathlib.Path(local_pkg_dir)
        cmd = [
            sys.executable,
            "-m",
            "pip",
            "download",
            "--requirement",
            "/dev/stdin",  # NOTE(review): requirements via stdin; not Windows-compatible
            "--dest",
            str(local_tmp_dir_path),
        ]
        try:
            subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT, input=requirements)
        except subprocess.CalledProcessError as e:
            # Re-raise the pip error with its captured output chained as the
            # cause, so the output is visible in the traceback.
            raise e from Exception(e.stdout)
        server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None
        with psycopg2.connect(db._dsn, options=server_options) as util_conn:  # type: ignore reportUnknownVariableType
            server_tmp_dir = _archive_and_upload(util_conn, tmp_dir_handle, [local_pkg_dir], db)  # type: ignore reportUnknownArgumentType
            extracted = db.apply(lambda: _extract_files(server_tmp_dir, tmp_dir_handle, "root"))
            assert len(list(extracted)) == 1
            installed = extracted.apply(
                lambda _: _install_on_server(server_tmp_dir, local_pkg_dir, requirements)
            )
            assert len(list(installed)) == 1
            # Safe to clean up here: nothing returned to the caller depends on
            # the server-side directory.
            _remove_tmp_dir(util_conn, db, tmp_dir_handle)  # type: ignore reportUnknownArgumentType


# Register the helper as the public `Database.install_packages` method.
setattr(gp.Database, "install_packages", _install_packages)
1 change: 1 addition & 0 deletions greenplumpython/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def _serialize(self, db: Database) -> str:
f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n"
f" raise ModuleNotFoundError\n"
f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n"
f" setattr({sys_lib_name}.modules['plpy'], '_GD', GD)\n"
f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n"
f" except ModuleNotFoundError:\n"
f" exec({json.dumps(ast.unparse(func_ast))}, globals())\n"
Expand Down
6 changes: 4 additions & 2 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def test_csv_multi_chunks(db: gp.Database):
default_chunk_size = greenplumpython.experimental.file._CHUNK_SIZE
greenplumpython.experimental.file._CHUNK_SIZE = 3
assert greenplumpython.experimental.file._CHUNK_SIZE == 3
test_csv_single_chunk(db)
greenplumpython.experimental.file._CHUNK_SIZE = default_chunk_size
try:
test_csv_single_chunk(db)
finally:
greenplumpython.experimental.file._CHUNK_SIZE = default_chunk_size


import subprocess
Expand Down