Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1888076: built-in function support wave1 #2928

Merged
merged 12 commits into from
Jan 28, 2025
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,22 @@
- `regr_sxy`
- `regr_syy`
- `try_to_binary`
- `base64`
- `base64_decode_string`
- `base64_encode`
- `editdistance`
- `hex`
- `hex_encode`
- `instr`
- `levenshtein`
- `log1p`
- `log2`
- `log10`
- `percentile_approx`
- `unbase64`
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
- Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode.
- Added support for multiple columns in the functions `map_cat` and `map_concat`.

#### Experimental Features

Expand Down
13 changes: 13 additions & 0 deletions docs/source/snowpark/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ Functions
atanh
atan2
avg
base64
base64_decode_string
base64_encode
bit_length
bitmap_bit_position
bitmap_bucket_number
Expand Down Expand Up @@ -157,6 +160,7 @@ Functions
desc_nulls_last
div0
divnull
editdistance
endswith
equal_nan
equal_null
Expand All @@ -178,12 +182,15 @@ Functions
grouping
grouping_id
hash
hex
hex_encode
hour
iff
ifnull
in_
initcap
insert
instr
is_array
is_binary
is_boolean
Expand Down Expand Up @@ -211,12 +218,16 @@ Functions
least
left
length
levenshtein
listagg
lit
ln
locate
localtimestamp
log
log1p
log2
log10
lower
lpad
ltrim
Expand Down Expand Up @@ -257,6 +268,7 @@ Functions
parse_json
parse_xml
percent_rank
percentile_approx
percentile_cont
position
pow
Expand Down Expand Up @@ -350,6 +362,7 @@ Functions
udaf
udf
udtf
unbase64
uniform
unix_timestamp
upper
Expand Down
225 changes: 222 additions & 3 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
import functools
import sys
import typing
from functools import reduce
from random import randint
from types import ModuleType
from typing import Callable, Dict, List, Optional, Tuple, Union, overload
Expand Down Expand Up @@ -1345,6 +1346,9 @@ def approx_percentile(
)


percentile_approx = approx_percentile


@publicapi
def approx_percentile_accumulate(col: ColumnOrName, _emit_ast: bool = True) -> Column:
"""Returns the internal representation of the t-Digest state (as a JSON object) at the end of aggregation.
Expand Down Expand Up @@ -2956,6 +2960,63 @@ def log(
return builtin("log", _ast=ast, _emit_ast=False)(b, arg)


@publicapi
def log1p(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the natural logarithm of (1 + x).

    Args:
        x: A column (or column name), or a numeric literal.
        _emit_ast: Whether to emit AST for this expression.

    Example::

        >>> df = session.create_dataframe([0, 1], schema=["a"])
        >>> df.select(log1p(df["a"]).alias("log1p")).collect()
        [Row(LOG1P=0.0), Row(LOG1P=0.6931471805599453)]
    """
    # Normalize the input to a Column exactly once. The original converted the
    # input twice and passed "log" (not "log1p") as the error-attribution name
    # to _to_col_if_str, mislabeling type errors; both are fixed here.
    c = (
        lit(x, _emit_ast=False)
        if isinstance(x, (int, float))
        else _to_col_if_str(x, "log1p")
    )
    # log1p(x) == ln(1 + x)
    return ln(c + lit(1, _emit_ast=_emit_ast), _emit_ast=_emit_ast)


@publicapi
def log10(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the base-10 logarithm of x.

    Args:
        x: A column (or column name), or a numeric literal.
        _emit_ast: Whether to emit AST for this expression.

    Example::

        >>> df = session.create_dataframe([1, 10], schema=["a"])
        >>> df.select(log10(df["a"]).alias("log10")).collect()
        [Row(LOG10=0.0), Row(LOG10=1.0)]
    """
    # Delegates to the module-private _log10 wrapper (shared with the Modin
    # log10 function; presumably log(10, x) like _log2 — confirm).
    return _log10(x, _emit_ast=_emit_ast)


@publicapi
def log2(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the base-2 logarithm of x.

    Args:
        x: A column (or column name), or a numeric literal.
        _emit_ast: Whether to emit AST for this expression.

    Example::

        >>> df = session.create_dataframe([1, 2, 8], schema=["a"])
        >>> df.select(log2(df["a"]).alias("log2")).collect()
        [Row(LOG2=0.0), Row(LOG2=1.0), Row(LOG2=3.0)]
    """
    # Delegates to the module-private _log2 wrapper (shared with the Modin
    # log2 function), which computes log(2, x).
    return _log2(x, _emit_ast=_emit_ast)


# Create base 2 and base 10 wrappers for use with the Modin log2 and log10 functions
def _log2(x: Union[ColumnOrName, int, float], _emit_ast: bool = True) -> Column:
    """Module-private helper: base-2 logarithm, implemented as ``log(2, x)``."""
    return log(2, x, _emit_ast=_emit_ast)
Expand Down Expand Up @@ -7112,12 +7173,15 @@ def array_unique_agg(col: ColumnOrName, _emit_ast: bool = True) -> Column:


@publicapi
def map_cat(col1: ColumnOrName, col2: ColumnOrName, _emit_ast: bool = True):
"""Returns the concatenatation of two MAPs.
def map_cat(
col1: ColumnOrName, col2: ColumnOrName, *cols: ColumnOrName, _emit_ast: bool = True
):
"""Returns the concatenatation of two or more MAPs.

Args:
col1: The source map
col2: The map to be appended to col1
cols: More maps to be appended

Example::
>>> df = session.sql("select {'k1': 'v1'} :: MAP(STRING,STRING) as A, {'k2': 'v2'} :: MAP(STRING,STRING) as B")
Expand All @@ -7131,10 +7195,30 @@ def map_cat(col1: ColumnOrName, col2: ColumnOrName, _emit_ast: bool = True):
|} |
---------------------------
<BLANKLINE>
>>> df = session.sql("select {'k1': 'v1'} :: MAP(STRING,STRING) as A, {'k2': 'v2'} :: MAP(STRING,STRING) as B, {'k3': 'v3'} :: MAP(STRING,STRING) as C")
>>> df.select(map_cat("A", "B", "C")).show()
-------------------------------------------
|"MAP_CAT(MAP_CAT(""A"", ""B""), ""C"")" |
-------------------------------------------
|{ |
| "k1": "v1", |
| "k2": "v2", |
| "k3": "v3" |
|} |
-------------------------------------------
<BLANKLINE>
"""
m1 = _to_col_if_str(col1, "map_cat")
m2 = _to_col_if_str(col2, "map_cat")
return builtin("map_cat", _emit_ast=_emit_ast)(m1, m2)

def map_cat_two_maps(first, second):
return builtin("map_cat", _emit_ast=_emit_ast)(first, second)

cols_to_concat = [m1, m2]
for c in cols:
cols_to_concat.append(_to_col_if_str(c, "map_cat"))

return reduce(map_cat_two_maps, cols_to_concat)


@publicapi
Expand Down Expand Up @@ -11064,3 +11148,138 @@ def try_to_binary(
if fmt
else builtin("try_to_binary", _emit_ast=_emit_ast)(c)
)


@publicapi
def base64_encode(
    e: ColumnOrName,
    max_line_length: Optional[int] = 0,
    alphabet: Optional[str] = None,
    _emit_ast: bool = True,
) -> Column:
    """
    Encodes the input (string or binary) using Base64 encoding.

    Args:
        e: The column (or column name) to encode.
        max_line_length: Optional line length; the default of 0 is omitted
            from the generated SQL.
        alphabet: Optional alternative Base64 alphabet; omitted when falsy.
        _emit_ast: Whether to emit AST for this expression.

    Example:
        >>> df = session.create_dataframe(["Snowflake", "Data"], schema=["input"])
        >>> df.select(base64_encode(col("input")).alias("encoded")).collect()
        [Row(ENCODED='U25vd2ZsYWtl'), Row(ENCODED='RGF0YQ==')]
    """
    # Normalize the input expression to a Column.
    source = _to_col_if_str(e, "base64_encode")

    # Optional arguments are appended only when truthy, in SQL argument order.
    extra_args = [lit(arg) for arg in (max_line_length, alphabet) if arg]

    # Invoke the built-in BASE64_ENCODE function.
    return builtin("base64_encode", _emit_ast=_emit_ast)(source, *extra_args)


base64 = base64_encode


@publicapi
def base64_decode_string(
    e: ColumnOrName, alphabet: Optional[str] = None, _emit_ast: bool = True
) -> Column:
    """
    Decodes a Base64-encoded string to a string.

    Args:
        e: The column (or column name) holding the Base64-encoded input.
        alphabet: Optional alternative Base64 alphabet; omitted when falsy.
        _emit_ast: Whether to emit AST for this expression.

    Example:
        >>> df = session.create_dataframe(["U25vd2ZsYWtl", "SEVMTE8="], schema=["input"])
        >>> df.select(base64_decode_string(col("input")).alias("decoded")).collect()
        [Row(DECODED='Snowflake'), Row(DECODED='HELLO')]
    """
    # Normalize the input expression to a Column.
    source = _to_col_if_str(e, "base64_decode_string")

    # The optional alphabet is only passed through when truthy.
    extra_args = [lit(alphabet)] if alphabet else []

    # Invoke the built-in BASE64_DECODE_STRING function.
    return builtin("base64_decode_string", _emit_ast=_emit_ast)(source, *extra_args)


unbase64 = base64_decode_string


@publicapi
def hex_encode(e: ColumnOrName, case: int = 1, _emit_ast: bool = True) -> Column:
    """
    Encodes the input using hexadecimal (also ‘hex’ or ‘base16’) encoding.

    Args:
        e: The column (or column name) to encode.
        case: Controls the letter case of the output hex digits; the default
            of 1 yields uppercase digits (as in the example below).
        _emit_ast: Whether to emit AST for this expression.

    Example:
        >>> df = session.create_dataframe(["Snowflake", "Hello"], schema=["input"])
        >>> df.select(hex_encode(col("input")).alias("hex_encoded")).collect()
        [Row(HEX_ENCODED='536E6F77666C616B65'), Row(HEX_ENCODED='48656C6C6F')]
    """
    # Fixed copy-paste bug: the error-attribution name passed to
    # _to_col_if_str was "base64_decode_string", so type errors were
    # reported against the wrong function.
    col_input = _to_col_if_str(e, "hex_encode")
    return builtin("hex_encode", _emit_ast=_emit_ast)(col_input, lit(case))


hex = hex_encode


@publicapi
def editdistance(
    string_expr1: ColumnOrName,
    string_expr2: ColumnOrName,
    max_distance: Optional[int] = None,
    _emit_ast: bool = True,
) -> Column:
    """Computes the Levenshtein distance between two input strings.

    Optionally, a maximum distance can be specified. If the distance exceeds this value,
    the computation halts and returns the maximum distance.

    Example::

        >>> df = session.create_dataframe(
        ...     [["abc", "def"], ["abcdef", "abc"], ["snow", "flake"]],
        ...     schema=["s1", "s2"]
        ... )
        >>> df.select(
        ...     editdistance(col("s1"), col("s2")).alias("distance"),
        ...     editdistance(col("s1"), col("s2"), 2).alias("max_2_distance")
        ... ).collect()
        [Row(DISTANCE=3, MAX_2_DISTANCE=2), Row(DISTANCE=3, MAX_2_DISTANCE=2), Row(DISTANCE=5, MAX_2_DISTANCE=2)]
    """
    lhs = _to_col_if_str(string_expr1, "editdistance")
    rhs = _to_col_if_str(string_expr2, "editdistance")

    # NOTE(review): the builtin call uses _emit_ast=False rather than
    # forwarding the _emit_ast parameter — confirm this is intentional.
    if max_distance is None:
        return builtin("editdistance", _emit_ast=False)(lhs, rhs)

    # A Column is passed through unchanged; plain values become literals.
    cap = (
        max_distance
        if isinstance(max_distance, Column)
        else lit(max_distance, _emit_ast=_emit_ast)
    )
    return builtin("editdistance", _emit_ast=False)(lhs, rhs, cap)


levenshtein = editdistance


@publicapi
def instr(str: ColumnOrName, substr: str) -> Column:
    """
    Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null.

    Args:
        str: The column (or column name) to search in. NOTE(review): this
            parameter name shadows the ``str`` builtin; kept as-is since
            renaming would break keyword callers.
        substr: The literal substring to search for.

    Example::
        >>> df = session.create_dataframe([["hello world"], ["world hello"]], schema=["text"])
        >>> df.select(instr(col("text"), "world").alias("position")).collect()
        [Row(POSITION=7), Row(POSITION=1)]
    """
    # Delegates to position(lit(substr), str): 1-based index of the first
    # match (see the doctest above: "world" in "hello world" -> 7).
    s1 = _to_col_if_str(str, "instr")
    return position(lit(substr), s1)
Loading