diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 0a8283f67756..89adf57b3bdf 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -460,3 +460,59 @@ def split_braces(format_str: str) -> List[str]: prev = '' ret_list.append(tmp_str) return ret_list
+
+
+@specialize_function('join', str_rprimitive)
+def translate_fstring(
+        builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
+    """Specialize the ''.join([...]) call that mypy generates for an f-string.
+
+    In the mypy AST an f-string is a ''.join() call on a list whose items are
+    string literals and '{:{}}'.format(value, spec) calls. Only the simplest
+    shape is optimized here: every format call must pass exactly a value and
+    an empty literal spec. Each value is then converted with str_op and all
+    pieces are concatenated by a single str_build_op call.
+
+    Return the built string value, or None to fall back to the generic
+    str.join() call when the expression does not have that shape.
+    """
+    if not (isinstance(callee, MemberExpr)
+            and isinstance(callee.expr, StrExpr) and callee.expr.value == ''
+            and expr.arg_kinds == [ARG_POS] and isinstance(expr.args[0], ListExpr)):
+        return None
+
+    items = expr.args[0].items
+    for item in items:
+        if isinstance(item, StrExpr):
+            continue
+        if not isinstance(item, CallExpr):
+            return None
+        if (not isinstance(item.callee, MemberExpr)
+                or item.callee.name != 'format'
+                or not isinstance(item.callee.expr, StrExpr)
+                or item.callee.expr.value != '{:{}}'):
+            return None
+        # Require exactly (value, spec), both positional, before indexing
+        # item.args: '{:{}}'.format(x) would otherwise crash the compiler with
+        # an IndexError, and star or extra arguments would be mishandled.
+        if item.arg_kinds != [ARG_POS, ARG_POS]:
+            return None
+        if not isinstance(item.args[1], StrExpr) or item.args[1].value != '':
+            return None
+
+    # Empty literals contribute nothing to the result, so they are dropped.
+    pieces: List[Value] = []
+    for item in items:
+        if isinstance(item, StrExpr):
+            if item.value != '':
+                pieces.append(builder.accept(item))
+        elif isinstance(item, CallExpr):
+            value = builder.accept(item.args[0])
+            pieces.append(builder.call_c(str_op, [value], expr.line))
+
+    if not pieces:
+        return builder.load_str("")
+
+    # str_build_op takes the number of strings followed by the strings.
+    count: Value = Integer(len(pieces), c_pyssize_t_rprimitive)
+    return builder.call_c(str_build_op, [count] + pieces, expr.line)
diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 4bc8948b6408..e0f9f0592b2f 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -191,3 +191,56 @@ L0: s3 = r13
return 1 +[case testFStrings] +def f(var: str, num: int) -> None: + s1 = f"Hi! I'm {var}. I am {num} years old." + s2 = f'Hello {var:>{num}}' + s3 = f'' + s4 = f'abc' +[out] +def f(var, num): + var :: str + num :: int + r0, r1, r2 :: str + r3 :: object + r4, r5, r6, s1, r7, r8, r9, r10 :: str + r11 :: object + r12, r13, r14 :: str + r15 :: object + r16 :: str + r17 :: list + r18, r19, r20 :: ptr + r21, s2, r22, s3, r23, s4 :: str +L0: + r0 = "Hi! I'm " + r1 = PyObject_Str(var) + r2 = '. I am ' + r3 = box(int, num) + r4 = PyObject_Str(r3) + r5 = ' years old.' + r6 = CPyStr_Build(5, r0, r1, r2, r4, r5) + s1 = r6 + r7 = '' + r8 = 'Hello ' + r9 = '{:{}}' + r10 = '>' + r11 = box(int, num) + r12 = PyObject_Str(r11) + r13 = CPyStr_Build(2, r10, r12) + r14 = 'format' + r15 = CPyObject_CallMethodObjArgs(r9, r14, var, r13, 0) + r16 = cast(str, r15) + r17 = PyList_New(2) + r18 = get_element_ptr r17 ob_item :: PyListObject + r19 = load_mem r18 :: ptr* + set_mem r19, r8 :: builtins.object* + r20 = r19 + WORD_SIZE*1 + set_mem r20, r16 :: builtins.object* + keep_alive r17 + r21 = PyUnicode_Join(r7, r17) + s2 = r21 + r22 = '' + s3 = r22 + r23 = 'abc' + s4 = r23 + return 1 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index edcf682d3d3f..30c152961611 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -192,6 +192,26 @@ def test_fstring_basics() -> None: inf_num = float('inf') assert f'{nan_num}, {inf_num}' == 'nan, inf'
+# mypy represents an f-string as ''.join([string literals, '{:{}}'.format(value, spec), ...])
+# in its AST. Currently a str.join specializer is used to speed up f-strings. It might not
+# cover all cases, and the remaining ones should fall back to a normal str.join method call.
+# TODO: Once we have a new pipeline for f-strings, this test case can be moved to testStringOps.
+def test_str_join() -> None:
+    var = 'mypyc'
+    num = 10
+    # Plain '{}' is not the f-string shape, so this must take the fallback path.
+    assert ''.join(['a', 'b', '{}'.format(var), 'c']) == 'abmypycc'
+    # Empty spec: the shape the specializer optimizes.
+    assert ''.join(['a', 'b', '{:{}}'.format(var, ''), 'c']) == 'abmypycc'
+    # Non-empty specs right-align var to width 10, i.e. five leading spaces.
+    assert ''.join(['a', 'b', '{:{}}'.format(var, '>10'), 'c']) == 'ab     mypycc'
+    assert ''.join(['a', 'b', '{:{}}'.format(var, '>{}'.format(num)), 'c']) == 'ab     mypycc'
+    # Separators other than the '' literal are left to the regular str.join.
+    assert var.join(['a', '{:{}}'.format(var, ''), 'b']) == 'amypycmypycmypycb'
+    assert ','.join(['a', '{:{}}'.format(var, ''), 'b']) == 'a,mypyc,b'
+    # A list item that is neither a string literal nor a format call.
+    assert ''.join(['x', var]) == 'xmypyc'
+
 class A: def __init__(self, name, age): self.name = name @@ -356,6 +376,14 @@ def test_format_method_different_kind() -> None: assert "Test: {}{}".format(s3, s1) == "Test: 测试:Literal['😀']" assert "Test: {}{}".format(s3, s2) == "Test: 测试:Revealed type is"
+def test_format_method_nested() -> None:
+    # '{:{}}' takes its format spec from the second argument; '>10' pads to width 10.
+    var = 'mypyc'
+    num = 10
+    assert '{:{}}'.format(var, '') == 'mypyc'
+    assert '{:{}}'.format(var, '>10') == '     mypyc'
+    assert '{:{}}'.format(var, '>{}'.format(num)) == '     mypyc'
+
 class Point: def __init__(self, x, y): self.x, self.y = x, y