Skip to content

Commit 1598f2f

Browse files
committed
address comments
1 parent c83b879 commit 1598f2f

File tree

8 files changed

+46
-13
lines changed

8 files changed

+46
-13
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525

26+
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
27+
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
28+
29+
# Approx. highest number you can pass to the EXP function and get a valid FLOAT64 result
30+
# FLOAT64 has 11 exponent bits, so the max value is about 2**(2**10)
31+
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
32+
_FLOAT64_EXP_BOUND = sge.convert(709.78)
33+
2634
UNARY_OP_REGISTRATION = OpRegistration()
2735

2836

@@ -36,7 +44,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
3644
ifs=[
3745
sge.If(
3846
this=sge.func("ABS", expr.expr) > sge.convert(1),
39-
true=sge.func("IEEE_DIVIDE", sge.convert(0), sge.convert(0)),
47+
true=_NAN,
4048
)
4149
],
4250
default=sge.func("ACOS", expr.expr),
@@ -49,7 +57,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
4957
ifs=[
5058
sge.If(
5159
this=sge.func("ABS", expr.expr) > sge.convert(1),
52-
true=sge.func("IEEE_DIVIDE", sge.convert(0), sge.convert(0)),
60+
true=_NAN,
5361
)
5462
],
5563
default=sge.func("ASIN", expr.expr),
@@ -105,7 +113,7 @@ def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression:
105113

106114
@UNARY_OP_REGISTRATION.register(ops.cos_op)
107115
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
108-
return sge.func("cos", expr.expr)
116+
return sge.func("COS", expr.expr)
109117

110118

111119
@UNARY_OP_REGISTRATION.register(ops.hash_op)
@@ -125,17 +133,16 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
125133

126134
@UNARY_OP_REGISTRATION.register(ops.sin_op)
127135
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
128-
return sge.func("sin", expr.expr)
136+
return sge.func("SIN", expr.expr)
129137

130138

131139
@UNARY_OP_REGISTRATION.register(ops.sinh_op)
132140
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
133141
return sge.Case(
134142
ifs=[
135143
sge.If(
136-
this=sge.func("ABS", expr.expr) > sge.convert(709.78),
137-
true=sge.func("SIGN", expr.expr)
138-
* sge.func("IEEE_DIVIDE", sge.convert(1), sge.convert(0)),
144+
this=sge.func("ABS", expr.expr) > _FLOAT64_EXP_BOUND,
145+
true=sge.func("SIGN", expr.expr) * _INF,
139146
)
140147
],
141148
default=sge.func("SINH", expr.expr),
@@ -190,4 +197,4 @@ def _(op: ops.ParseJSON, expr: TypedExpr) -> sge.Expression:
190197

191198
@UNARY_OP_REGISTRATION.register(ops.ToJSONString)
192199
def _(op: ops.ToJSONString, expr: TypedExpr) -> sge.Expression:
193-
return sge.func("TO_JSON_STRING", expr.expr)
200+
return sge.func("TO_JSON_STRING", expr.expr)

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_arccos/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE WHEN ABS(`bfcol_0`) > 1 THEN IEEE_DIVIDE(0, 0) ELSE ACOS(`bfcol_0`) END AS `bfcol_1`
8+
CASE WHEN ABS(`bfcol_0`) > 1 THEN CAST('NaN' AS FLOAT64) ELSE ACOS(`bfcol_0`) END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_arcsin/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE WHEN ABS(`bfcol_0`) > 1 THEN IEEE_DIVIDE(0, 0) ELSE ASIN(`bfcol_0`) END AS `bfcol_1`
8+
CASE WHEN ABS(`bfcol_0`) > 1 THEN CAST('NaN' AS FLOAT64) ELSE ASIN(`bfcol_0`) END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_cos/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
cos(`bfcol_0`) AS `bfcol_1`
8+
COS(`bfcol_0`) AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_sin/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
sin(`bfcol_0`) AS `bfcol_1`
8+
SIN(`bfcol_0`) AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_sinh/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
*,
88
CASE
99
WHEN ABS(`bfcol_0`) > 709.78
10-
THEN SIGN(`bfcol_0`) * IEEE_DIVIDE(1, 0)
10+
THEN SIGN(`bfcol_0`) * CAST('Infinity' AS FLOAT64)
1111
ELSE SINH(`bfcol_0`)
1212
END AS `bfcol_1`
1313
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,21 @@ def _apply_binary_op(
4444
def test_add_numeric(scalar_types_df: bpd.DataFrame, snapshot):
4545
bf_df = scalar_types_df[["int64_col"]]
4646
sql = _apply_binary_op(bf_df, ops.add_op, "int64_col", "int64_col")
47+
4748
snapshot.assert_match(sql, "out.sql")
4849

4950

5051
def test_add_numeric_w_scalar(scalar_types_df: bpd.DataFrame, snapshot):
5152
bf_df = scalar_types_df[["int64_col"]]
5253
sql = _apply_binary_op(bf_df, ops.add_op, "int64_col", ex.const(1))
54+
5355
snapshot.assert_match(sql, "out.sql")
5456

5557

5658
def test_add_string(scalar_types_df: bpd.DataFrame, snapshot):
5759
bf_df = scalar_types_df[["string_col"]]
5860
sql = _apply_binary_op(bf_df, ops.add_op, "string_col", ex.const("a"))
61+
5962
snapshot.assert_match(sql, "out.sql")
6063

6164

@@ -64,4 +67,5 @@ def test_json_set(json_types_df: bpd.DataFrame, snapshot):
6467
sql = _apply_binary_op(
6568
bf_df, ops.JSONSet(json_path="$.a"), "json_col", ex.const(100)
6669
)
70+
6771
snapshot.assert_match(sql, "out.sql")

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,130 +37,152 @@ def _apply_unary_op(obj: bpd.DataFrame, op: ops.UnaryOp, arg: str) -> str:
3737
def test_arccos(scalar_types_df: bpd.DataFrame, snapshot):
3838
bf_df = scalar_types_df[["float64_col"]]
3939
sql = _apply_unary_op(bf_df, ops.arccos_op, "float64_col")
40+
4041
snapshot.assert_match(sql, "out.sql")
4142

4243

4344
def test_arcsin(scalar_types_df: bpd.DataFrame, snapshot):
4445
bf_df = scalar_types_df[["float64_col"]]
4546
sql = _apply_unary_op(bf_df, ops.arcsin_op, "float64_col")
47+
4648
snapshot.assert_match(sql, "out.sql")
4749

4850

4951
def test_arctan(scalar_types_df: bpd.DataFrame, snapshot):
5052
bf_df = scalar_types_df[["float64_col"]]
5153
sql = _apply_unary_op(bf_df, ops.arctan_op, "float64_col")
54+
5255
snapshot.assert_match(sql, "out.sql")
5356

5457

5558
def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot):
5659
bf_df = repeated_types_df[["string_list_col"]]
5760
sql = _apply_unary_op(bf_df, ops.ArrayToStringOp(delimiter="."), "string_list_col")
61+
5862
snapshot.assert_match(sql, "out.sql")
5963

6064

6165
def test_array_index(repeated_types_df: bpd.DataFrame, snapshot):
6266
bf_df = repeated_types_df[["string_list_col"]]
6367
sql = _apply_unary_op(bf_df, convert_index(1), "string_list_col")
68+
6469
snapshot.assert_match(sql, "out.sql")
6570

6671

6772
def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot):
6873
bf_df = repeated_types_df[["string_list_col"]]
6974
sql = _apply_unary_op(bf_df, convert_slice(slice(1, None)), "string_list_col")
75+
7076
snapshot.assert_match(sql, "out.sql")
7177

7278

7379
def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snapshot):
7480
bf_df = repeated_types_df[["string_list_col"]]
7581
sql = _apply_unary_op(bf_df, convert_slice(slice(1, 5)), "string_list_col")
82+
7683
snapshot.assert_match(sql, "out.sql")
7784

7885

7986
def test_cos(scalar_types_df: bpd.DataFrame, snapshot):
8087
bf_df = scalar_types_df[["float64_col"]]
8188
sql = _apply_unary_op(bf_df, ops.cos_op, "float64_col")
89+
8290
snapshot.assert_match(sql, "out.sql")
8391

8492

8593
def test_hash(scalar_types_df: bpd.DataFrame, snapshot):
8694
bf_df = scalar_types_df[["string_col"]]
8795
sql = _apply_unary_op(bf_df, ops.hash_op, "string_col")
96+
8897
snapshot.assert_match(sql, "out.sql")
8998

9099

91100
def test_isnull(scalar_types_df: bpd.DataFrame, snapshot):
92101
bf_df = scalar_types_df[["float64_col"]]
93102
sql = _apply_unary_op(bf_df, ops.isnull_op, "float64_col")
103+
94104
snapshot.assert_match(sql, "out.sql")
95105

96106

97107
def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
98108
bf_df = scalar_types_df[["float64_col"]]
99109
sql = _apply_unary_op(bf_df, ops.notnull_op, "float64_col")
110+
100111
snapshot.assert_match(sql, "out.sql")
101112

102113

103114
def test_sin(scalar_types_df: bpd.DataFrame, snapshot):
104115
bf_df = scalar_types_df[["float64_col"]]
105116
sql = _apply_unary_op(bf_df, ops.sin_op, "float64_col")
117+
106118
snapshot.assert_match(sql, "out.sql")
107119

108120

109121
def test_sinh(scalar_types_df: bpd.DataFrame, snapshot):
110122
bf_df = scalar_types_df[["float64_col"]]
111123
sql = _apply_unary_op(bf_df, ops.sinh_op, "float64_col")
124+
112125
snapshot.assert_match(sql, "out.sql")
113126

114127

115128
def test_tan(scalar_types_df: bpd.DataFrame, snapshot):
116129
bf_df = scalar_types_df[["float64_col"]]
117130
sql = _apply_unary_op(bf_df, ops.tan_op, "float64_col")
131+
118132
snapshot.assert_match(sql, "out.sql")
119133

120134

121135
def test_json_extract(json_types_df: bpd.DataFrame, snapshot):
122136
bf_df = json_types_df[["json_col"]]
123137
sql = _apply_unary_op(bf_df, ops.JSONExtract(json_path="$"), "json_col")
138+
124139
snapshot.assert_match(sql, "out.sql")
125140

126141

127142
def test_json_extract_array(json_types_df: bpd.DataFrame, snapshot):
128143
bf_df = json_types_df[["json_col"]]
129144
sql = _apply_unary_op(bf_df, ops.JSONExtractArray(json_path="$"), "json_col")
145+
130146
snapshot.assert_match(sql, "out.sql")
131147

132148

133149
def test_json_extract_string_array(json_types_df: bpd.DataFrame, snapshot):
134150
bf_df = json_types_df[["json_col"]]
135151
sql = _apply_unary_op(bf_df, ops.JSONExtractStringArray(json_path="$"), "json_col")
152+
136153
snapshot.assert_match(sql, "out.sql")
137154

138155

139156
def test_json_query(json_types_df: bpd.DataFrame, snapshot):
140157
bf_df = json_types_df[["json_col"]]
141158
sql = _apply_unary_op(bf_df, ops.JSONQuery(json_path="$"), "json_col")
159+
142160
snapshot.assert_match(sql, "out.sql")
143161

144162

145163
def test_json_query_array(json_types_df: bpd.DataFrame, snapshot):
146164
bf_df = json_types_df[["json_col"]]
147165
sql = _apply_unary_op(bf_df, ops.JSONQueryArray(json_path="$"), "json_col")
166+
148167
snapshot.assert_match(sql, "out.sql")
149168

150169

151170
def test_json_value(json_types_df: bpd.DataFrame, snapshot):
152171
bf_df = json_types_df[["json_col"]]
153172
sql = _apply_unary_op(bf_df, ops.JSONValue(json_path="$"), "json_col")
173+
154174
snapshot.assert_match(sql, "out.sql")
155175

156176

157177
def test_parse_json(scalar_types_df: bpd.DataFrame, snapshot):
158178
bf_df = scalar_types_df[["string_col"]]
159179
sql = _apply_unary_op(bf_df, ops.ParseJSON(), "string_col")
180+
160181
snapshot.assert_match(sql, "out.sql")
161182

162183

163184
def test_to_json_string(json_types_df: bpd.DataFrame, snapshot):
164185
bf_df = json_types_df[["json_col"]]
165186
sql = _apply_unary_op(bf_df, ops.ToJSONString(), "json_col")
187+
166188
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)