mirror of
https://github.com/charliermarsh/ruff
synced 2025-10-05 23:52:47 +02:00
[ty] Use annotated parameters as type context (#20635)
## Summary Use the type annotation of function parameters as bidirectional type context when inferring the argument expression. For example, the following example now type-checks: ```py class TD(TypedDict): x: int def f(_: TD): ... f({ "x": 1 }) ``` Part of https://github.com/astral-sh/ty/issues/168.
This commit is contained in:
@@ -117,7 +117,7 @@ static COLOUR_SCIENCE: std::sync::LazyLock<Benchmark<'static>> = std::sync::Lazy
|
|||||||
max_dep_date: "2025-06-17",
|
max_dep_date: "2025-06-17",
|
||||||
python_version: PythonVersion::PY310,
|
python_version: PythonVersion::PY310,
|
||||||
},
|
},
|
||||||
500,
|
600,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -1662,3 +1662,67 @@ def _(arg: tuple[A | B, Any]):
|
|||||||
reveal_type(f(arg)) # revealed: Unknown
|
reveal_type(f(arg)) # revealed: Unknown
|
||||||
reveal_type(f(*(arg,))) # revealed: Unknown
|
reveal_type(f(*(arg,))) # revealed: Unknown
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Bidirectional Type Inference
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
Type inference accounts for parameter type annotations across all overloads.
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import TypedDict, overload
|
||||||
|
|
||||||
|
class T(TypedDict):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def f(a: list[T], b: int) -> int: ...
|
||||||
|
@overload
|
||||||
|
def f(a: list[dict[str, int]], b: str) -> str: ...
|
||||||
|
def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def int_or_str() -> int | str:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
x = f([{"x": 1}], int_or_str())
|
||||||
|
reveal_type(x) # revealed: int | str
|
||||||
|
|
||||||
|
# TODO: error: [no-matching-overload] "No overload of function `f` matches arguments"
|
||||||
|
# we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]`
|
||||||
|
f([{"y": 1}], int_or_str())
|
||||||
|
```
|
||||||
|
|
||||||
|
Non-matching overloads do not produce diagnostics:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import TypedDict, overload
|
||||||
|
|
||||||
|
class T(TypedDict):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def f(a: T, b: int) -> int: ...
|
||||||
|
@overload
|
||||||
|
def f(a: dict[str, int], b: str) -> str: ...
|
||||||
|
def f(a: T | dict[str, int], b: int | str) -> int | str:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
x = f({"y": 1}, "a")
|
||||||
|
reveal_type(x) # revealed: str
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import SupportsRound, overload
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def takes_str_or_float(x: str): ...
|
||||||
|
@overload
|
||||||
|
def takes_str_or_float(x: float): ...
|
||||||
|
def takes_str_or_float(x: float | str): ...
|
||||||
|
|
||||||
|
takes_str_or_float(round(1.0))
|
||||||
|
```
|
||||||
|
@@ -251,3 +251,59 @@ from ty_extensions import Intersection, Not
|
|||||||
def _(x: Union[Intersection[Any, Not[int]], Intersection[Any, Not[int]]]):
|
def _(x: Union[Intersection[Any, Not[int]], Intersection[Any, Not[int]]]):
|
||||||
reveal_type(x) # revealed: Any & ~int
|
reveal_type(x) # revealed: Any & ~int
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Bidirectional Type Inference
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
Type inference accounts for parameter type annotations across all signatures in a union.
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import TypedDict, overload
|
||||||
|
|
||||||
|
class T(TypedDict):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
def _(flag: bool):
|
||||||
|
if flag:
|
||||||
|
def f(x: T) -> int:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
def f(x: dict[str, int]) -> int:
|
||||||
|
return 1
|
||||||
|
x = f({"x": 1})
|
||||||
|
reveal_type(x) # revealed: int
|
||||||
|
|
||||||
|
# TODO: error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `T`, found `dict[str, int]`"
|
||||||
|
# we currently consider `TypedDict` instances to be subtypes of `dict`
|
||||||
|
f({"y": 1})
|
||||||
|
```
|
||||||
|
|
||||||
|
Diagnostics unrelated to the type-context are only reported once:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def f[T](x: T) -> list[T]:
|
||||||
|
return [x]
|
||||||
|
|
||||||
|
def a(x: list[bool], y: list[bool]): ...
|
||||||
|
def b(x: list[int], y: list[int]): ...
|
||||||
|
def c(x: list[int], y: list[int]): ...
|
||||||
|
def _(x: int):
|
||||||
|
if x == 0:
|
||||||
|
y = a
|
||||||
|
elif x == 1:
|
||||||
|
y = b
|
||||||
|
else:
|
||||||
|
y = c
|
||||||
|
|
||||||
|
if x == 0:
|
||||||
|
z = True
|
||||||
|
|
||||||
|
y(f(True), [True])
|
||||||
|
|
||||||
|
# error: [possibly-unresolved-reference] "Name `z` used when possibly not defined"
|
||||||
|
y(f(True), [z])
|
||||||
|
```
|
||||||
|
@@ -10,6 +10,7 @@ from typing_extensions import assert_type
|
|||||||
def _(x: int):
|
def _(x: int):
|
||||||
assert_type(x, int) # fine
|
assert_type(x, int) # fine
|
||||||
assert_type(x, str) # error: [type-assertion-failure]
|
assert_type(x, str) # error: [type-assertion-failure]
|
||||||
|
assert_type(assert_type(x, int), int)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Narrowing
|
## Narrowing
|
||||||
|
@@ -17,6 +17,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/directives/assert_type.m
|
|||||||
3 | def _(x: int):
|
3 | def _(x: int):
|
||||||
4 | assert_type(x, int) # fine
|
4 | assert_type(x, int) # fine
|
||||||
5 | assert_type(x, str) # error: [type-assertion-failure]
|
5 | assert_type(x, str) # error: [type-assertion-failure]
|
||||||
|
6 | assert_type(assert_type(x, int), int)
|
||||||
```
|
```
|
||||||
|
|
||||||
# Diagnostics
|
# Diagnostics
|
||||||
@@ -31,6 +32,7 @@ error[type-assertion-failure]: Argument does not have asserted type `str`
|
|||||||
| ^^^^^^^^^^^^-^^^^^^
|
| ^^^^^^^^^^^^-^^^^^^
|
||||||
| |
|
| |
|
||||||
| Inferred type of argument is `int`
|
| Inferred type of argument is `int`
|
||||||
|
6 | assert_type(assert_type(x, int), int)
|
||||||
|
|
|
|
||||||
info: `str` and `int` are not equivalent types
|
info: `str` and `int` are not equivalent types
|
||||||
info: rule `type-assertion-failure` is enabled by default
|
info: rule `type-assertion-failure` is enabled by default
|
||||||
|
@@ -152,7 +152,7 @@ Person(name="Alice")
|
|||||||
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
|
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
|
||||||
Person({"name": "Alice"})
|
Person({"name": "Alice"})
|
||||||
|
|
||||||
# TODO: this should be an error, similar to the above
|
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
|
||||||
accepts_person({"name": "Alice"})
|
accepts_person({"name": "Alice"})
|
||||||
# TODO: this should be an error, similar to the above
|
# TODO: this should be an error, similar to the above
|
||||||
house.owner = {"name": "Alice"}
|
house.owner = {"name": "Alice"}
|
||||||
@@ -171,7 +171,7 @@ Person(name=None, age=30)
|
|||||||
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
|
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
|
||||||
Person({"name": None, "age": 30})
|
Person({"name": None, "age": 30})
|
||||||
|
|
||||||
# TODO: this should be an error, similar to the above
|
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
|
||||||
accepts_person({"name": None, "age": 30})
|
accepts_person({"name": None, "age": 30})
|
||||||
# TODO: this should be an error, similar to the above
|
# TODO: this should be an error, similar to the above
|
||||||
house.owner = {"name": None, "age": 30}
|
house.owner = {"name": None, "age": 30}
|
||||||
@@ -190,7 +190,7 @@ Person(name="Alice", age=30, extra=True)
|
|||||||
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
|
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
|
||||||
Person({"name": "Alice", "age": 30, "extra": True})
|
Person({"name": "Alice", "age": 30, "extra": True})
|
||||||
|
|
||||||
# TODO: this should be an error
|
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
|
||||||
accepts_person({"name": "Alice", "age": 30, "extra": True})
|
accepts_person({"name": "Alice", "age": 30, "extra": True})
|
||||||
# TODO: this should be an error
|
# TODO: this should be an error
|
||||||
house.owner = {"name": "Alice", "age": 30, "extra": True}
|
house.owner = {"name": "Alice", "age": 30, "extra": True}
|
||||||
|
@@ -4194,20 +4194,26 @@ impl<'db> Type<'db> {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(KnownFunction::AssertType) => Binding::single(
|
Some(KnownFunction::AssertType) => {
|
||||||
self,
|
let val_ty =
|
||||||
Signature::new(
|
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Invariant);
|
||||||
Parameters::new([
|
|
||||||
Parameter::positional_only(Some(Name::new_static("value")))
|
Binding::single(
|
||||||
.with_annotated_type(Type::any()),
|
self,
|
||||||
Parameter::positional_only(Some(Name::new_static("type")))
|
Signature::new_generic(
|
||||||
.type_form()
|
Some(GenericContext::from_typevar_instances(db, [val_ty])),
|
||||||
.with_annotated_type(Type::any()),
|
Parameters::new([
|
||||||
]),
|
Parameter::positional_only(Some(Name::new_static("value")))
|
||||||
Some(Type::none(db)),
|
.with_annotated_type(Type::TypeVar(val_ty)),
|
||||||
),
|
Parameter::positional_only(Some(Name::new_static("type")))
|
||||||
)
|
.type_form()
|
||||||
.into(),
|
.with_annotated_type(Type::any()),
|
||||||
|
]),
|
||||||
|
Some(Type::TypeVar(val_ty)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
Some(KnownFunction::AssertNever) => {
|
Some(KnownFunction::AssertNever) => {
|
||||||
Binding::single(
|
Binding::single(
|
||||||
|
@@ -1077,7 +1077,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
|
|||||||
// don't need to worry about finding any particular constraint more than once.
|
// don't need to worry about finding any particular constraint more than once.
|
||||||
let constraints = constraints.elements(db);
|
let constraints = constraints.elements(db);
|
||||||
let mut positive_constraint_count = 0;
|
let mut positive_constraint_count = 0;
|
||||||
for positive in &self.positive {
|
for (i, positive) in self.positive.iter().enumerate() {
|
||||||
|
if i == typevar_index {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// This linear search should be fine as long as we don't encounter typevars with
|
// This linear search should be fine as long as we don't encounter typevars with
|
||||||
// thousands of constraints.
|
// thousands of constraints.
|
||||||
positive_constraint_count += constraints
|
positive_constraint_count += constraints
|
||||||
|
@@ -33,10 +33,10 @@ use crate::types::{
|
|||||||
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
||||||
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
||||||
TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
|
TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
|
||||||
WrapperDescriptorKind, enums, ide_support, todo_type,
|
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
|
||||||
};
|
};
|
||||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||||
use ruff_python_ast::{self as ast, PythonVersion};
|
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
|
||||||
|
|
||||||
/// Binding information for a possible union of callables. At a call site, the arguments must be
|
/// Binding information for a possible union of callables. At a call site, the arguments must be
|
||||||
/// compatible with _all_ of the types in the union for the call to be valid.
|
/// compatible with _all_ of the types in the union for the call to be valid.
|
||||||
@@ -1776,7 +1776,7 @@ impl<'db> CallableBinding<'db> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
|
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
|
||||||
fn matching_overload_index(&self) -> MatchingOverloadIndex {
|
pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex {
|
||||||
let mut matching_overloads = self.matching_overloads();
|
let mut matching_overloads = self.matching_overloads();
|
||||||
match matching_overloads.next() {
|
match matching_overloads.next() {
|
||||||
None => MatchingOverloadIndex::None,
|
None => MatchingOverloadIndex::None,
|
||||||
@@ -1794,8 +1794,15 @@ impl<'db> CallableBinding<'db> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns all overloads for this call binding, including overloads that did not match.
|
||||||
|
pub(crate) fn overloads(&self) -> &[Binding<'db>] {
|
||||||
|
self.overloads.as_slice()
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all the overloads that matched for this call binding.
|
/// Returns an iterator over all the overloads that matched for this call binding.
|
||||||
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
|
pub(crate) fn matching_overloads(
|
||||||
|
&self,
|
||||||
|
) -> impl Iterator<Item = (usize, &Binding<'db>)> + Clone {
|
||||||
self.overloads
|
self.overloads
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
@@ -2026,7 +2033,7 @@ enum OverloadCallReturnType<'db> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum MatchingOverloadIndex {
|
pub(crate) enum MatchingOverloadIndex {
|
||||||
/// No matching overloads found.
|
/// No matching overloads found.
|
||||||
None,
|
None,
|
||||||
|
|
||||||
@@ -2504,9 +2511,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||||||
if let Some(return_ty) = self.signature.return_ty
|
if let Some(return_ty) = self.signature.return_ty
|
||||||
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
|
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
|
||||||
{
|
{
|
||||||
// Ignore any specialization errors here, because the type context is only used to
|
match call_expression_tcx {
|
||||||
// optionally widen the return type.
|
// A type variable is not a useful type-context for expression inference, and applying it
|
||||||
let _ = builder.infer(return_ty, call_expression_tcx);
|
// to the return type can lead to confusing unions in nested generic calls.
|
||||||
|
Type::TypeVar(_) => {}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
// Ignore any specialization errors here, because the type context is only used as a hint
|
||||||
|
// to infer a more assignable return type.
|
||||||
|
let _ = builder.infer(return_ty, call_expression_tcx);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let parameters = self.signature.parameters();
|
let parameters = self.signature.parameters();
|
||||||
@@ -3289,6 +3304,23 @@ impl<'db> BindingError<'db> {
|
|||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Re-infer the argument type of call expressions, ignoring the type context for more
|
||||||
|
// precise error messages.
|
||||||
|
let provided_ty = match Self::get_argument_node(node, *argument_index) {
|
||||||
|
None => *provided_ty,
|
||||||
|
|
||||||
|
// Ignore starred arguments, as those are difficult to re-infer.
|
||||||
|
Some(
|
||||||
|
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||||
|
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }),
|
||||||
|
) => *provided_ty,
|
||||||
|
|
||||||
|
Some(
|
||||||
|
ast::ArgOrKeyword::Arg(value)
|
||||||
|
| ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }),
|
||||||
|
) => infer_isolated_expression(context.db(), context.scope(), value),
|
||||||
|
};
|
||||||
|
|
||||||
let provided_ty_display = provided_ty.display(context.db());
|
let provided_ty_display = provided_ty.display(context.db());
|
||||||
let expected_ty_display = expected_ty.display(context.db());
|
let expected_ty_display = expected_ty.display(context.db());
|
||||||
|
|
||||||
@@ -3624,22 +3656,29 @@ impl<'db> BindingError<'db> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_node(node: ast::AnyNodeRef, argument_index: Option<usize>) -> ast::AnyNodeRef {
|
fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option<usize>) -> ast::AnyNodeRef<'_> {
|
||||||
// If we have a Call node and an argument index, report the diagnostic on the correct
|
// If we have a Call node and an argument index, report the diagnostic on the correct
|
||||||
// argument node; otherwise, report it on the entire provided node.
|
// argument node; otherwise, report it on the entire provided node.
|
||||||
|
match Self::get_argument_node(node, argument_index) {
|
||||||
|
Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(),
|
||||||
|
Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(),
|
||||||
|
None => node,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_argument_node(
|
||||||
|
node: ast::AnyNodeRef<'_>,
|
||||||
|
argument_index: Option<usize>,
|
||||||
|
) -> Option<ArgOrKeyword<'_>> {
|
||||||
match (node, argument_index) {
|
match (node, argument_index) {
|
||||||
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => {
|
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some(
|
||||||
match call_node
|
call_node
|
||||||
.arguments
|
.arguments
|
||||||
.arguments_source_order()
|
.arguments_source_order()
|
||||||
.nth(argument_index)
|
.nth(argument_index)
|
||||||
.expect("argument index should not be out of range")
|
.expect("argument index should not be out of range"),
|
||||||
{
|
),
|
||||||
ast::ArgOrKeyword::Arg(expr) => expr.into(),
|
_ => None,
|
||||||
ast::ArgOrKeyword::Keyword(keyword) => keyword.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => node,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -40,6 +40,7 @@ pub(crate) struct InferContext<'db, 'ast> {
|
|||||||
module: &'ast ParsedModuleRef,
|
module: &'ast ParsedModuleRef,
|
||||||
diagnostics: std::cell::RefCell<TypeCheckDiagnostics>,
|
diagnostics: std::cell::RefCell<TypeCheckDiagnostics>,
|
||||||
no_type_check: InNoTypeCheck,
|
no_type_check: InNoTypeCheck,
|
||||||
|
multi_inference: bool,
|
||||||
bomb: DebugDropBomb,
|
bomb: DebugDropBomb,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
|
|||||||
scope,
|
scope,
|
||||||
module,
|
module,
|
||||||
file: scope.file(db),
|
file: scope.file(db),
|
||||||
|
multi_inference: false,
|
||||||
diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()),
|
diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()),
|
||||||
no_type_check: InNoTypeCheck::default(),
|
no_type_check: InNoTypeCheck::default(),
|
||||||
bomb: DebugDropBomb::new(
|
bomb: DebugDropBomb::new(
|
||||||
@@ -156,6 +158,18 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
|
|||||||
DiagnosticGuardBuilder::new(self, id, severity)
|
DiagnosticGuardBuilder::new(self, id, severity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the current expression is being inferred for a second
|
||||||
|
/// (or subsequent) time, with a potentially different bidirectional type
|
||||||
|
/// context.
|
||||||
|
pub(super) fn is_in_multi_inference(&self) -> bool {
|
||||||
|
self.multi_inference
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the multi-inference state, returning the previous value.
|
||||||
|
pub(super) fn set_multi_inference(&mut self, multi_inference: bool) -> bool {
|
||||||
|
std::mem::replace(&mut self.multi_inference, multi_inference)
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) {
|
pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) {
|
||||||
self.no_type_check = no_type_check;
|
self.no_type_check = no_type_check;
|
||||||
}
|
}
|
||||||
@@ -410,6 +424,11 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
|
|||||||
if ctx.is_in_no_type_check() {
|
if ctx.is_in_no_type_check() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
// If this lint is being reported as part of multi-inference of a given expression,
|
||||||
|
// silence it to avoid duplicated diagnostics.
|
||||||
|
if ctx.is_in_multi_inference() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let id = DiagnosticId::Lint(lint.name());
|
let id = DiagnosticId::Lint(lint.name());
|
||||||
|
|
||||||
let suppressions = suppressions(ctx.db(), ctx.file());
|
let suppressions = suppressions(ctx.db(), ctx.file());
|
||||||
@@ -575,6 +594,11 @@ impl<'db, 'ctx> DiagnosticGuardBuilder<'db, 'ctx> {
|
|||||||
if !ctx.db.should_check_file(ctx.file) {
|
if !ctx.db.should_check_file(ctx.file) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
// If this lint is being reported as part of multi-inference of a given expression,
|
||||||
|
// silence it to avoid duplicated diagnostics.
|
||||||
|
if ctx.is_in_multi_inference() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
Some(DiagnosticGuardBuilder { ctx, id, severity })
|
Some(DiagnosticGuardBuilder { ctx, id, severity })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1975,7 +1975,7 @@ pub(super) fn report_invalid_assignment<'db>(
|
|||||||
if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db())
|
if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db())
|
||||||
&& let Some(value) = annotated_assignment.value(context.module())
|
&& let Some(value) = annotated_assignment.value(context.module())
|
||||||
{
|
{
|
||||||
// Re-infer the RHS of the annotated assignment, ignoring the type context, for more precise
|
// Re-infer the RHS of the annotated assignment, ignoring the type context for more precise
|
||||||
// error messages.
|
// error messages.
|
||||||
source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value);
|
source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value);
|
||||||
}
|
}
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
use std::iter;
|
use std::{iter, mem};
|
||||||
|
|
||||||
use itertools::{Either, Itertools};
|
use itertools::{Either, Itertools};
|
||||||
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
|
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
|
||||||
@@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
|
|||||||
use crate::semantic_index::{
|
use crate::semantic_index::{
|
||||||
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
|
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
|
||||||
};
|
};
|
||||||
|
use crate::types::call::bind::MatchingOverloadIndex;
|
||||||
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
|
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
|
||||||
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
|
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
|
||||||
use crate::types::context::{InNoTypeCheck, InferContext};
|
use crate::types::context::{InNoTypeCheck, InferContext};
|
||||||
@@ -88,12 +89,13 @@ use crate::types::typed_dict::{
|
|||||||
};
|
};
|
||||||
use crate::types::visitor::any_over_type;
|
use crate::types::visitor::any_over_type;
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType,
|
CallDunderError, CallableBinding, CallableType, ClassLiteral, ClassType, DataclassParams,
|
||||||
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy,
|
DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType,
|
||||||
MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType,
|
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
|
||||||
SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers,
|
Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type,
|
||||||
TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation,
|
TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers,
|
||||||
TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
|
TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind,
|
||||||
|
TypedDictType, UnionBuilder, UnionType, binding_type, todo_type,
|
||||||
};
|
};
|
||||||
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
|
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
|
||||||
use crate::unpack::{EvaluationMode, UnpackPosition};
|
use crate::unpack::{EvaluationMode, UnpackPosition};
|
||||||
@@ -257,6 +259,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
|
|||||||
/// is a stub file but we're still in a non-deferred region.
|
/// is a stub file but we're still in a non-deferred region.
|
||||||
deferred_state: DeferredExpressionState,
|
deferred_state: DeferredExpressionState,
|
||||||
|
|
||||||
|
multi_inference_state: MultiInferenceState,
|
||||||
|
|
||||||
/// For function definitions, the undecorated type of the function.
|
/// For function definitions, the undecorated type of the function.
|
||||||
undecorated_type: Option<Type<'db>>,
|
undecorated_type: Option<Type<'db>>,
|
||||||
|
|
||||||
@@ -287,10 +291,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
context: InferContext::new(db, scope, module),
|
context: InferContext::new(db, scope, module),
|
||||||
index,
|
index,
|
||||||
region,
|
region,
|
||||||
|
scope,
|
||||||
return_types_and_ranges: vec![],
|
return_types_and_ranges: vec![],
|
||||||
called_functions: FxHashSet::default(),
|
called_functions: FxHashSet::default(),
|
||||||
deferred_state: DeferredExpressionState::None,
|
deferred_state: DeferredExpressionState::None,
|
||||||
scope,
|
multi_inference_state: MultiInferenceState::Panic,
|
||||||
expressions: FxHashMap::default(),
|
expressions: FxHashMap::default(),
|
||||||
bindings: VecMap::default(),
|
bindings: VecMap::default(),
|
||||||
declarations: VecMap::default(),
|
declarations: VecMap::default(),
|
||||||
@@ -4911,6 +4916,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
self.infer_expression(expression, TypeContext::default())
|
self.infer_expression(expression, TypeContext::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Infer the argument types for a single binding.
|
||||||
fn infer_argument_types<'a>(
|
fn infer_argument_types<'a>(
|
||||||
&mut self,
|
&mut self,
|
||||||
ast_arguments: &ast::Arguments,
|
ast_arguments: &ast::Arguments,
|
||||||
@@ -4920,22 +4926,155 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
debug_assert!(
|
debug_assert!(
|
||||||
ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len()
|
ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len()
|
||||||
);
|
);
|
||||||
let iter = (arguments.iter_mut())
|
|
||||||
.zip(argument_forms.iter().copied())
|
let iter = itertools::izip!(
|
||||||
.zip(ast_arguments.arguments_source_order());
|
arguments.iter_mut(),
|
||||||
for (((_, argument_type), form), arg_or_keyword) in iter {
|
argument_forms.iter().copied(),
|
||||||
let argument = match arg_or_keyword {
|
ast_arguments.arguments_source_order()
|
||||||
// We already inferred the type of splatted arguments.
|
);
|
||||||
|
|
||||||
|
for ((_, argument_type), argument_form, ast_argument) in iter {
|
||||||
|
let argument = match ast_argument {
|
||||||
|
// Splatted arguments are inferred before parameter matching to
|
||||||
|
// determine their length.
|
||||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
||||||
|
|
||||||
ast::ArgOrKeyword::Arg(arg) => arg,
|
ast::ArgOrKeyword::Arg(arg) => arg,
|
||||||
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
||||||
};
|
};
|
||||||
let ty = self.infer_argument_type(argument, form, TypeContext::default());
|
|
||||||
|
let ty = self.infer_argument_type(argument, argument_form, TypeContext::default());
|
||||||
*argument_type = Some(ty);
|
*argument_type = Some(ty);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Infer the argument types for multiple potential bindings and overloads.
|
||||||
|
fn infer_all_argument_types<'a>(
|
||||||
|
&mut self,
|
||||||
|
ast_arguments: &ast::Arguments,
|
||||||
|
arguments: &mut CallArguments<'a, 'db>,
|
||||||
|
bindings: &Bindings<'db>,
|
||||||
|
) {
|
||||||
|
debug_assert!(
|
||||||
|
ast_arguments.len() == arguments.len()
|
||||||
|
&& arguments.len() == bindings.argument_forms().len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let iter = itertools::izip!(
|
||||||
|
0..,
|
||||||
|
arguments.iter_mut(),
|
||||||
|
bindings.argument_forms().iter().copied(),
|
||||||
|
ast_arguments.arguments_source_order()
|
||||||
|
);
|
||||||
|
|
||||||
|
let overloads_with_binding = bindings
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|binding| {
|
||||||
|
match binding.matching_overload_index() {
|
||||||
|
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
|
||||||
|
let overloads = binding
|
||||||
|
.matching_overloads()
|
||||||
|
.map(move |(_, overload)| (overload, binding));
|
||||||
|
|
||||||
|
Some(Either::Right(overloads))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is a single overload that does not match, we still infer the argument
|
||||||
|
// types for better diagnostics.
|
||||||
|
MatchingOverloadIndex::None => match binding.overloads() {
|
||||||
|
[overload] => Some(Either::Left(std::iter::once((overload, binding)))),
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.flatten();
|
||||||
|
|
||||||
|
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
|
||||||
|
let ast_argument = match ast_argument {
|
||||||
|
// Splatted arguments are inferred before parameter matching to
|
||||||
|
// determine their length.
|
||||||
|
//
|
||||||
|
// TODO: Re-infer splatted arguments with their type context.
|
||||||
|
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||||
|
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
||||||
|
|
||||||
|
ast::ArgOrKeyword::Arg(arg) => arg,
|
||||||
|
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Type-form arguments are inferred without type context, so we can infer the argument type directly.
|
||||||
|
if let Some(ParameterForm::Type) = argument_form {
|
||||||
|
*argument_type = Some(self.infer_type_expression(ast_argument));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve the parameter type for the current argument in a given overload and its binding.
|
||||||
|
let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| {
|
||||||
|
let argument_index = if binding.bound_type.is_some() {
|
||||||
|
argument_index + 1
|
||||||
|
} else {
|
||||||
|
argument_index
|
||||||
|
};
|
||||||
|
|
||||||
|
let argument_matches = &overload.argument_matches()[argument_index];
|
||||||
|
let [parameter_index] = argument_matches.parameters.as_slice() else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
overload.signature.parameters()[*parameter_index].annotated_type()
|
||||||
|
};
|
||||||
|
|
||||||
|
// If there is only a single binding and overload, we can infer the argument directly with
|
||||||
|
// the unique parameter type annotation.
|
||||||
|
if let Ok((overload, binding)) = overloads_with_binding.clone().exactly_one() {
|
||||||
|
self.infer_expression_impl(
|
||||||
|
ast_argument,
|
||||||
|
TypeContext::new(parameter_type(overload, binding)),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Otherwise, each type is a valid independent inference of the given argument, and we may
|
||||||
|
// require different permutations of argument types to correctly perform argument expansion
|
||||||
|
// during overload evaluation, so we take the intersection of all the types we inferred for
|
||||||
|
// each argument.
|
||||||
|
//
|
||||||
|
// Note that this applies to all nested expressions within each argument.
|
||||||
|
let old_multi_inference_state = mem::replace(
|
||||||
|
&mut self.multi_inference_state,
|
||||||
|
MultiInferenceState::Intersect,
|
||||||
|
);
|
||||||
|
|
||||||
|
// We perform inference once without any type context, emitting any diagnostics that are unrelated
|
||||||
|
// to bidirectional type inference.
|
||||||
|
self.infer_expression_impl(ast_argument, TypeContext::default());
|
||||||
|
|
||||||
|
// We then silence any diagnostics emitted during multi-inference, as the type context is only
|
||||||
|
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
|
||||||
|
// for non-matching overloads.
|
||||||
|
let was_in_multi_inference = self.context.set_multi_inference(true);
|
||||||
|
|
||||||
|
// Infer the type of each argument once with each distinct parameter type as type context.
|
||||||
|
let parameter_types = overloads_with_binding
|
||||||
|
.clone()
|
||||||
|
.filter_map(|(overload, binding)| parameter_type(overload, binding))
|
||||||
|
.collect::<FxHashSet<_>>();
|
||||||
|
|
||||||
|
for parameter_type in parameter_types {
|
||||||
|
self.infer_expression_impl(
|
||||||
|
ast_argument,
|
||||||
|
TypeContext::new(Some(parameter_type)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore the multi-inference state.
|
||||||
|
self.multi_inference_state = old_multi_inference_state;
|
||||||
|
self.context.set_multi_inference(was_in_multi_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
*argument_type = self.try_expression_type(ast_argument);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_argument_type(
|
fn infer_argument_type(
|
||||||
&mut self,
|
&mut self,
|
||||||
ast_argument: &ast::Expr,
|
ast_argument: &ast::Expr,
|
||||||
@@ -4956,6 +5095,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
expression.map(|expr| self.infer_expression(expr, tcx))
|
expression.map(|expr| self.infer_expression(expr, tcx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_or_infer_expression(
|
||||||
|
&mut self,
|
||||||
|
expression: &ast::Expr,
|
||||||
|
tcx: TypeContext<'db>,
|
||||||
|
) -> Type<'db> {
|
||||||
|
self.try_expression_type(expression)
|
||||||
|
.unwrap_or_else(|| self.infer_expression(expression, tcx))
|
||||||
|
}
|
||||||
|
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> {
|
fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> {
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
@@ -5016,6 +5164,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
types.expression_type(expression)
|
types.expression_type(expression)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Infer the type of an expression.
|
||||||
fn infer_expression_impl(
|
fn infer_expression_impl(
|
||||||
&mut self,
|
&mut self,
|
||||||
expression: &ast::Expr,
|
expression: &ast::Expr,
|
||||||
@@ -5051,7 +5200,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
ast::Expr::Compare(compare) => self.infer_compare_expression(compare),
|
ast::Expr::Compare(compare) => self.infer_compare_expression(compare),
|
||||||
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
|
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
|
||||||
ast::Expr::Slice(slice) => self.infer_slice_expression(slice),
|
ast::Expr::Slice(slice) => self.infer_slice_expression(slice),
|
||||||
ast::Expr::Named(named) => self.infer_named_expression(named),
|
|
||||||
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
|
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
|
||||||
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
|
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
|
||||||
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
|
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
|
||||||
@@ -5059,6 +5207,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
|
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
|
||||||
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
|
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
|
||||||
ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),
|
ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),
|
||||||
|
ast::Expr::Named(named) => {
|
||||||
|
// Definitions must be unique, so we bypass multi-inference for named expressions.
|
||||||
|
if !self.multi_inference_state.is_panic()
|
||||||
|
&& let Some(ty) = self.expressions.get(&expression.into())
|
||||||
|
{
|
||||||
|
return *ty;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.infer_named_expression(named)
|
||||||
|
}
|
||||||
ast::Expr::IpyEscapeCommand(_) => {
|
ast::Expr::IpyEscapeCommand(_) => {
|
||||||
todo_type!("Ipy escape command support")
|
todo_type!("Ipy escape command support")
|
||||||
}
|
}
|
||||||
@@ -5068,6 +5226,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
|
|
||||||
ty
|
ty
|
||||||
}
|
}
|
||||||
|
|
||||||
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
|
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
|
||||||
if self.deferred_state.in_string_annotation() {
|
if self.deferred_state.in_string_annotation() {
|
||||||
// Avoid storing the type of expressions that are part of a string annotation because
|
// Avoid storing the type of expressions that are part of a string annotation because
|
||||||
@@ -5075,8 +5234,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
// on the string expression itself that represents the annotation.
|
// on the string expression itself that represents the annotation.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let previous = self.expressions.insert(expression.into(), ty);
|
|
||||||
assert_eq!(previous, None);
|
let db = self.db();
|
||||||
|
|
||||||
|
match self.multi_inference_state {
|
||||||
|
MultiInferenceState::Panic => {
|
||||||
|
let previous = self.expressions.insert(expression.into(), ty);
|
||||||
|
assert_eq!(previous, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiInferenceState::Intersect => {
|
||||||
|
self.expressions
|
||||||
|
.entry(expression.into())
|
||||||
|
.and_modify(|current| {
|
||||||
|
*current = IntersectionType::from_elements(db, [*current, ty]);
|
||||||
|
})
|
||||||
|
.or_insert(ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
|
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
|
||||||
@@ -5297,31 +5472,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
} = dict;
|
} = dict;
|
||||||
|
|
||||||
// Validate `TypedDict` dictionary literal assignments.
|
// Validate `TypedDict` dictionary literal assignments.
|
||||||
if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) {
|
if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict)
|
||||||
let typed_dict_items = typed_dict.items(self.db());
|
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
|
||||||
|
{
|
||||||
for item in items {
|
return ty;
|
||||||
self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
|
|
||||||
|
|
||||||
if let Some(ast::Expr::StringLiteral(ref key)) = item.key
|
|
||||||
&& let Some(key) = key.as_single_part_string()
|
|
||||||
&& let Some(field) = typed_dict_items.get(key.as_str())
|
|
||||||
{
|
|
||||||
self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty)));
|
|
||||||
} else {
|
|
||||||
self.infer_expression(&item.value, TypeContext::default());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
validate_typed_dict_dict_literal(
|
|
||||||
&self.context,
|
|
||||||
typed_dict,
|
|
||||||
dict,
|
|
||||||
dict.into(),
|
|
||||||
|expr| self.expression_type(expr),
|
|
||||||
);
|
|
||||||
|
|
||||||
return Type::TypedDict(typed_dict);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avoid false positives for the functional `TypedDict` form, which is currently
|
// Avoid false positives for the functional `TypedDict` form, which is currently
|
||||||
@@ -5342,6 +5496,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_typed_dict_expression(
|
||||||
|
&mut self,
|
||||||
|
dict: &ast::ExprDict,
|
||||||
|
typed_dict: TypedDictType<'db>,
|
||||||
|
) -> Option<Type<'db>> {
|
||||||
|
let ast::ExprDict {
|
||||||
|
range: _,
|
||||||
|
node_index: _,
|
||||||
|
items,
|
||||||
|
} = dict;
|
||||||
|
|
||||||
|
let typed_dict_items = typed_dict.items(self.db());
|
||||||
|
|
||||||
|
for item in items {
|
||||||
|
self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
|
||||||
|
|
||||||
|
if let Some(ast::Expr::StringLiteral(ref key)) = item.key
|
||||||
|
&& let Some(key) = key.as_single_part_string()
|
||||||
|
&& let Some(field) = typed_dict_items.get(key.as_str())
|
||||||
|
{
|
||||||
|
self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty)));
|
||||||
|
} else {
|
||||||
|
self.infer_expression(&item.value, TypeContext::default());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
validate_typed_dict_dict_literal(&self.context, typed_dict, dict, dict.into(), |expr| {
|
||||||
|
self.expression_type(expr)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
.map(|_| Type::TypedDict(typed_dict))
|
||||||
|
}
|
||||||
|
|
||||||
// Infer the type of a collection literal expression.
|
// Infer the type of a collection literal expression.
|
||||||
fn infer_collection_literal<'expr, const N: usize>(
|
fn infer_collection_literal<'expr, const N: usize>(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -5399,7 +5586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
for elts in elts {
|
for elts in elts {
|
||||||
// An unpacking expression for a dictionary.
|
// An unpacking expression for a dictionary.
|
||||||
if let &[None, Some(value)] = elts.as_slice() {
|
if let &[None, Some(value)] = elts.as_slice() {
|
||||||
let inferred_value_ty = self.infer_expression(value, TypeContext::default());
|
let inferred_value_ty = self.get_or_infer_expression(value, TypeContext::default());
|
||||||
|
|
||||||
// Merge the inferred type of the nested dictionary.
|
// Merge the inferred type of the nested dictionary.
|
||||||
if let Some(specialization) =
|
if let Some(specialization) =
|
||||||
@@ -5420,9 +5607,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
// The inferred type of each element acts as an additional constraint on `T`.
|
// The inferred type of each element acts as an additional constraint on `T`.
|
||||||
for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone())
|
for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone())
|
||||||
{
|
{
|
||||||
let Some(inferred_elt_ty) = self.infer_optional_expression(elt, elt_tcx) else {
|
let Some(elt) = elt else { continue };
|
||||||
continue;
|
|
||||||
};
|
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
|
||||||
|
|
||||||
// Convert any element literals to their promoted type form to avoid excessively large
|
// Convert any element literals to their promoted type form to avoid excessively large
|
||||||
// unions for large nested list literals, which the constraint solver struggles with.
|
// unions for large nested list literals, which the constraint solver struggles with.
|
||||||
@@ -5967,7 +6154,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
let bindings = callable_type
|
let bindings = callable_type
|
||||||
.bindings(self.db())
|
.bindings(self.db())
|
||||||
.match_parameters(self.db(), &call_arguments);
|
.match_parameters(self.db(), &call_arguments);
|
||||||
self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms());
|
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
||||||
|
|
||||||
// Validate `TypedDict` constructor calls after argument type inference
|
// Validate `TypedDict` constructor calls after argument type inference
|
||||||
if let Some(class_literal) = callable_type.into_class_literal() {
|
if let Some(class_literal) = callable_type.into_class_literal() {
|
||||||
@@ -9087,6 +9274,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
// builder only state
|
// builder only state
|
||||||
typevar_binding_context: _,
|
typevar_binding_context: _,
|
||||||
deferred_state: _,
|
deferred_state: _,
|
||||||
|
multi_inference_state: _,
|
||||||
called_functions: _,
|
called_functions: _,
|
||||||
index: _,
|
index: _,
|
||||||
region: _,
|
region: _,
|
||||||
@@ -9149,6 +9337,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
// builder only state
|
// builder only state
|
||||||
typevar_binding_context: _,
|
typevar_binding_context: _,
|
||||||
deferred_state: _,
|
deferred_state: _,
|
||||||
|
multi_inference_state: _,
|
||||||
called_functions: _,
|
called_functions: _,
|
||||||
index: _,
|
index: _,
|
||||||
region: _,
|
region: _,
|
||||||
@@ -9220,6 +9409,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||||||
// Builder only state
|
// Builder only state
|
||||||
typevar_binding_context: _,
|
typevar_binding_context: _,
|
||||||
deferred_state: _,
|
deferred_state: _,
|
||||||
|
multi_inference_state: _,
|
||||||
called_functions: _,
|
called_functions: _,
|
||||||
index: _,
|
index: _,
|
||||||
region: _,
|
region: _,
|
||||||
@@ -9265,6 +9455,26 @@ impl GenericContextError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Dictates the behavior when an expression is inferred multiple times.
|
||||||
|
#[derive(Default, Debug, Clone, Copy)]
|
||||||
|
enum MultiInferenceState {
|
||||||
|
/// Panic if the expression has already been inferred.
|
||||||
|
#[default]
|
||||||
|
Panic,
|
||||||
|
|
||||||
|
/// Store the intersection of all types inferred for the expression.
|
||||||
|
Intersect,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiInferenceState {
|
||||||
|
fn is_panic(self) -> bool {
|
||||||
|
match self {
|
||||||
|
MultiInferenceState::Panic => true,
|
||||||
|
MultiInferenceState::Intersect => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The deferred state of a specific expression in an inference region.
|
/// The deferred state of a specific expression in an inference region.
|
||||||
#[derive(Default, Debug, Clone, Copy)]
|
#[derive(Default, Debug, Clone, Copy)]
|
||||||
enum DeferredExpressionState {
|
enum DeferredExpressionState {
|
||||||
@@ -9538,7 +9748,7 @@ impl<K, V> Default for VecMap<K, V> {
|
|||||||
|
|
||||||
/// Set based on a `Vec`. It doesn't enforce
|
/// Set based on a `Vec`. It doesn't enforce
|
||||||
/// uniqueness on insertion. Instead, it relies on the caller
|
/// uniqueness on insertion. Instead, it relies on the caller
|
||||||
/// that elements are uniuqe. For example, the way we visit definitions
|
/// that elements are unique. For example, the way we visit definitions
|
||||||
/// in the `TypeInference` builder make already implicitly guarantees that each definition
|
/// in the `TypeInference` builder make already implicitly guarantees that each definition
|
||||||
/// is only visited once.
|
/// is only visited once.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
@@ -132,7 +132,8 @@ impl TypedDictAssignmentKind {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Validates assignment of a value to a specific key on a `TypedDict`.
|
/// Validates assignment of a value to a specific key on a `TypedDict`.
|
||||||
/// Returns true if the assignment is valid, false otherwise.
|
///
|
||||||
|
/// Returns true if the assignment is valid, or false otherwise.
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
||||||
context: &InferContext<'db, 'ast>,
|
context: &InferContext<'db, 'ast>,
|
||||||
@@ -157,6 +158,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||||||
Type::string_literal(db, key),
|
Type::string_literal(db, key),
|
||||||
&items,
|
&items,
|
||||||
);
|
);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -240,13 +242,16 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Validates that all required keys are provided in a `TypedDict` construction.
|
/// Validates that all required keys are provided in a `TypedDict` construction.
|
||||||
|
///
|
||||||
/// Reports errors for any keys that are required but not provided.
|
/// Reports errors for any keys that are required but not provided.
|
||||||
|
///
|
||||||
|
/// Returns true if the assignment is valid, or false otherwise.
|
||||||
pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
||||||
context: &InferContext<'db, 'ast>,
|
context: &InferContext<'db, 'ast>,
|
||||||
typed_dict: TypedDictType<'db>,
|
typed_dict: TypedDictType<'db>,
|
||||||
provided_keys: &OrderSet<&str>,
|
provided_keys: &OrderSet<&str>,
|
||||||
error_node: AnyNodeRef<'ast>,
|
error_node: AnyNodeRef<'ast>,
|
||||||
) {
|
) -> bool {
|
||||||
let db = context.db();
|
let db = context.db();
|
||||||
let items = typed_dict.items(db);
|
let items = typed_dict.items(db);
|
||||||
|
|
||||||
@@ -255,7 +260,12 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
|||||||
.filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str()))
|
.filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for missing_key in required_keys.difference(provided_keys) {
|
let missing_keys = required_keys.difference(provided_keys);
|
||||||
|
|
||||||
|
let mut has_missing_key = false;
|
||||||
|
for missing_key in missing_keys {
|
||||||
|
has_missing_key = true;
|
||||||
|
|
||||||
report_missing_typed_dict_key(
|
report_missing_typed_dict_key(
|
||||||
context,
|
context,
|
||||||
error_node,
|
error_node,
|
||||||
@@ -263,6 +273,8 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
|||||||
missing_key,
|
missing_key,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
!has_missing_key
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
|
pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
|
||||||
@@ -373,7 +385,7 @@ fn validate_from_keywords<'db, 'ast>(
|
|||||||
provided_keys
|
provided_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates a `TypedDict` dictionary literal assignment
|
/// Validates a `TypedDict` dictionary literal assignment,
|
||||||
/// e.g. `person: Person = {"name": "Alice", "age": 30}`
|
/// e.g. `person: Person = {"name": "Alice", "age": 30}`
|
||||||
pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
||||||
context: &InferContext<'db, 'ast>,
|
context: &InferContext<'db, 'ast>,
|
||||||
@@ -381,7 +393,8 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
|||||||
dict_expr: &'ast ast::ExprDict,
|
dict_expr: &'ast ast::ExprDict,
|
||||||
error_node: AnyNodeRef<'ast>,
|
error_node: AnyNodeRef<'ast>,
|
||||||
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
|
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
|
||||||
) -> OrderSet<&'ast str> {
|
) -> Result<OrderSet<&'ast str>, OrderSet<&'ast str>> {
|
||||||
|
let mut valid = true;
|
||||||
let mut provided_keys = OrderSet::new();
|
let mut provided_keys = OrderSet::new();
|
||||||
|
|
||||||
// Validate each key-value pair in the dictionary literal
|
// Validate each key-value pair in the dictionary literal
|
||||||
@@ -392,7 +405,8 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
|||||||
provided_keys.insert(key_str);
|
provided_keys.insert(key_str);
|
||||||
|
|
||||||
let value_type = expression_type_fn(&item.value);
|
let value_type = expression_type_fn(&item.value);
|
||||||
validate_typed_dict_key_assignment(
|
|
||||||
|
valid &= validate_typed_dict_key_assignment(
|
||||||
context,
|
context,
|
||||||
typed_dict,
|
typed_dict,
|
||||||
key_str,
|
key_str,
|
||||||
@@ -406,7 +420,11 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
|
valid &= validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
|
||||||
|
|
||||||
provided_keys
|
if valid {
|
||||||
|
Ok(provided_keys)
|
||||||
|
} else {
|
||||||
|
Err(provided_keys)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user