diff --git a/crates/pyrefly_python/src/dunder.rs b/crates/pyrefly_python/src/dunder.rs index 8fd888e3b..7d39ca46f 100644 --- a/crates/pyrefly_python/src/dunder.rs +++ b/crates/pyrefly_python/src/dunder.rs @@ -26,6 +26,7 @@ pub const GET: Name = Name::new_static("__get__"); pub const GETATTR: Name = Name::new_static("__getattr__"); pub const GETATTRIBUTE: Name = Name::new_static("__getattribute__"); pub const GETITEM: Name = Name::new_static("__getitem__"); +pub const CLASS_GETITEM: Name = Name::new_static("__class_getitem__"); pub const GT: Name = Name::new_static("__gt__"); pub const HASH: Name = Name::new_static("__hash__"); pub const INIT: Name = Name::new_static("__init__"); diff --git a/pyrefly/lib/alt/expr.rs b/pyrefly/lib/alt/expr.rs index 3413de837..6fed5e8f6 100644 --- a/pyrefly/lib/alt/expr.rs +++ b/pyrefly/lib/alt/expr.rs @@ -1855,12 +1855,42 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } } - Type::ClassDef(cls) => Type::type_form(self.specialize( - &cls, - xs.map(|x| self.expr_untype(x, TypeFormContext::TypeArgument, errors)), - range, - errors, - )), + Type::ClassDef(cls) => { + let metadata = self.get_metadata_for_class(&cls); + let class_getitem_result = if self.get_class_tparams(&cls).is_empty() + && !metadata.has_base_any() + && !metadata.is_new_type() + { + let class_ty = Type::ClassDef(cls.dupe()); + if self.has_attr(&class_ty, &dunder::CLASS_GETITEM) { + let cls_value = self.promote_silently(&cls); + let call_args = [CallArg::ty(&cls_value, range), CallArg::expr(slice)]; + Some(self.call_method_or_error( + &class_ty, + &dunder::CLASS_GETITEM, + range, + &call_args, + &[], + errors, + Some(&|| ErrorContext::Index(self.for_display(class_ty.clone()))), + )) + } else { + None + } + } else { + None + }; + if let Some(result) = class_getitem_result { + result + } else { + Type::type_form(self.specialize( + &cls, + xs.map(|x| self.expr_untype(x, TypeFormContext::TypeArgument, errors)), + range, + errors, + )) + } + } Type::Type(box Type::SpecialForm(special)) => { self.apply_special_form(special, slice, range, errors) } diff --git a/pyrefly/lib/test/simple.rs b/pyrefly/lib/test/simple.rs index 63f8542a2..802f7285d 100644 --- a/pyrefly/lib/test/simple.rs +++ b/pyrefly/lib/test/simple.rs @@ -1737,6 +1737,19 @@ def f(condition: bool): "#, ); +testcase!( + test_class_getitem_magic_dunder, + r#" +from typing import assert_type + +class Foo: + def __class_getitem__(cls, item: int) -> str: + return str(item) + +assert_type(Foo[0], str) +"#, +); + testcase!(test_panic_docstring, "\"\"\" F\n\u{85}\"\"\"",); testcase!(