diff --git a/src/sql_join_2/__init__.py b/src/sql_join_2/__init__.py index 093063e..2d7d780 100644 --- a/src/sql_join_2/__init__.py +++ b/src/sql_join_2/__init__.py @@ -15,20 +15,6 @@ # 是否自动缓存结果 cacheable = True - -LEARNING_ALGORITHMS = { - "rank": "lambdarank", - "lambdarank": "lambdarank", - "regression": "regression", - "binaryclassification": "binaryclassification", - "排序": "lambdarank", - "回归": "regression", - "二分类": "binaryclassification", - "logloss": "logloss", -} -FAI_CLUSTERS = {"不加速": None} - - SQL_JOIN = """WITH sql1 AS ( {sql1} @@ -49,14 +35,25 @@ def run( """DAI 合并SQL。""" import dai + if not sql1 or not sql2: + raise ValueError("sql1 和 sql2 不能为空") + + sql1 = sql1.read() + if isinstance(sql1, dict): + sql1 = sql1["sql"] + + sql2 = sql2.read() + if isinstance(sql2, dict): + sql2 = sql2["sql"] + # 拆分features_sql和label - sql_statements_1 = dai._functions.parse_query(sql1.read_text()) - sql_statements_2 = dai._functions.parse_query(sql2.read_text()) + sql_statements_1 = dai._functions.parse_query(sql1) + sql_statements_2 = dai._functions.parse_query(sql2) join_sql = sql_join.format(sql1=sql_statements_1.pop(), sql2=sql_statements_2.pop()) join_sql = ";\n".join(list(set(sql_statements_1 + sql_statements_2))) + ";\n" + join_sql - data = dai.DataSource.write_text(join_sql) + data = dai.DataSource.write_json({"sql": join_sql}) return I.Outputs(data=data)