diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 423ad21ae4ac..1b24e8210408 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -37,10 +37,7 @@ private module Input1 implements InputSig1 { class Type = T::Type; - predicate isPseudoType(Type t) { - t instanceof UnknownType or - t instanceof NeverType - } + class UnknownType = T::UnknownType; class TypeParameter = T::TypeParameter; @@ -276,6 +273,382 @@ private module M2 = Make2; import M2 +private module Input3 implements InputSig3 { + private import rust as Rust + + predicate cachedStageRevRef() { + (implicitDerefChainBorrow(_, _, _) implies any()) + or + (exists(resolveCallTarget(_, _)) implies any()) + or + (exists(resolveStructFieldExpr(_, _)) implies any()) + or + (exists(resolveTupleFieldExpr(_, _)) implies any()) + } + + predicate inferType = M3::inferType/2; + + class BoolType extends DataType { + BoolType() { this.getTypeItem() instanceof Builtins::Bool } + } + + class AstNode = Rust::AstNode; + + TypeMention getTypeAnnotation(AstNode n) { + exists(LetStmt let | + n = let.getPat() and + result = let.getTypeRepr() + ) + or + result = n.(SelfParam).getTypeRepr() + or + exists(Param p | + n = p.getPat() and + result = p.getTypeRepr() + ) + or + result = n.(ShorthandSelfParameterMention) + } + + class Expr = Rust::Expr; + + class SwitchExpr extends Rust::MatchExpr { + Expr getExpr() { result = this.getScrutinee() } + + Case getCase(int index) { result = this.getArm(index) } + } + + class Case extends Rust::MatchArm { + AstNode getAPattern() { result = this.getPat() } + + Expr getBody() { result = this.getExpr() } + } + + class ConditionalExpr extends IfExpr { + Expr getThen() { result = super.getThen() } + } + + class BinaryExpr extends Rust::BinaryExpr { + Expr getLeftOperand() { result = super.getLhs() } + + 
Expr getRightOperand() { result = super.getRhs() } + } + + class LogicalAndExpr extends BinaryExpr, Rust::LogicalAndExpr { } + + class LogicalOrExpr extends BinaryExpr, Rust::LogicalOrExpr { } + + abstract class Assignment extends BinaryExpr { } + + class AssignExpr extends Assignment, Rust::AssignmentExpr { } + + class ParenExpr = Rust::ParenExpr; + + class Variable extends Rust::Variable { + AstNode getDefiningNode() { + result = this.getPat().getName() or + result = this.getParameter().(SelfParam) + } + + Expr getAnAccess() { result = super.getAnAccess() } + } + + abstract class LetDeclaration extends AstNode { + abstract predicate isCoercionSite(); + + abstract AstNode getLeftOperand(); + + abstract AstNode getRightOperand(); + } + + private class LetExprLetDeclaration extends LetDeclaration, LetExpr { + override predicate isCoercionSite() { not this.getPat() instanceof IdentPat } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getScrutinee() } + } + + private class LetStmtLetDeclaration extends LetDeclaration, LetStmt { + override predicate isCoercionSite() { + this.hasTypeRepr() or + not identLetStmt(this, _, _) + } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getInitializer() } + } + + class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment; + + class TypePosition = FunctionPosition; + + class Callable extends FunctionCallMatchingInput::Declaration { + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) { + result = + tp.(TypeParamTypeParameter) + .getTypeParam() + .getAdditionalTypeBound(this.getFunction(), _) + .getTypeRepr() + } + } + + class Call extends Expr instanceof FunctionCallMatchingInput::Access { + Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { + result = super.getTypeArgument(apos, path) + } + + AstNode getNodeAt(TypePosition pos) { result = 
super.getNodeAt(pos) } + + /** Gets the target of this call. */ + Callable getTargetCertain() { + exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p | + result.isFunction(i, f) and + p = CallExprImpl::getFunctionPath(this) and + f = resolvePath(p) and + f.isDirectlyFor(i) + ) + } + + Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) } + } + + bindingset[derefChainBorrow] + Type inferCallTypeIn(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) { + result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path) + } + + Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path) { + result = inferFunctionCallTypeNonSelf(n, pos, path) + or + exists(FunctionCallMatchingInput::Access a | + result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and + if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() + then not path.isEmpty() + else any() + ) + } + + predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + n1 = + any(IdentPat ip | + n2 = ip.getName() and + prefix1.isEmpty() and + if ip.isRef() + then + exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + else prefix2.isEmpty() + ) + or + exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | + n1 = dce.getArgList() and + tt.getArity() = dce.getNumberOfSyntacticArguments() and + n2 = dce.getSyntacticPositionalArgument(i) and + prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and + prefix2.isEmpty() + ) + or + exists(ClosureExpr ce, int index | + n1 = ce and + n2 = ce.getParam(index).getPat() and + prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and + prefix2.isEmpty() + ) + } + + Type inferCertainTypeInput(AstNode n, TypePath path) { + result = inferFunctionBodyType(n, path) + or + result = 
inferLiteralType(n, path, true) + or + result = inferRefPatType(n) and + path.isEmpty() + or + result = inferRefExprType(n) and + path.isEmpty() + or + result = inferCertainStructExprType(n, path) + or + result = inferCertainStructPatType(n, path) + or + result = inferRangeExprType(n) and + path.isEmpty() + or + result = inferTupleRootType(n) and + path.isEmpty() + or + result = inferBlockExprType(n, path) + or + result = inferArrayExprType(n) and + path.isEmpty() + or + result = inferCastExprType(n, path) + or + exprHasUnitType(n) and + path.isEmpty() and + result instanceof UnitType + or + isPanicMacroCall(n) and + path.isEmpty() and + result instanceof NeverType + or + n instanceof ClosureExpr and + path.isEmpty() and + result = closureRootType() + } + + predicate typeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + prefix1.isEmpty() and + prefix2.isEmpty() and + ( + exists(MatchExpr me | + n1 = me.getScrutinee() and + n2 = me.getAnArm().getPat() + ) + or + n1 = n2.(OrPat).getAPat() + or + n1 = n2.(ParenPat).getPat() + or + n1 = n2.(LiteralPat).getLiteral() + or + exists(BreakExpr break | + break.getExpr() = n1 and + break.getTarget() = n2.(LoopExpr) + ) + or + n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and + not isPanicMacroCall(n2) + or + n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() + ) + or + n2 = + any(RefExpr re | + n1 = re.getExpr() and + prefix1.isEmpty() and + prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) + ) + or + n2 = + any(RefPat rp | + n1 = rp.getPat() and + prefix1.isEmpty() and + exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + ) + or + exists(int i, int arity | + prefix1.isEmpty() and + prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) + | + arity = n2.(TupleExpr).getNumberOfFields() and + n1 = n2.(TupleExpr).getField(i) + or + arity 
= n2.(TuplePat).getTupleArity() and + n1 = n2.(TuplePat).getField(i) + ) + or + exists(BlockExpr be | + n1 = be and + n2 = be.getStmtList().getTailExpr() and + if be.isAsync() + then + prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and + prefix2.isEmpty() + else ( + prefix1.isEmpty() and + prefix2.isEmpty() + ) + ) + or + // an array repeat expression (`[1; 3]`) has the type of the repeat operand + n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and + prefix1 = TypePath::singleton(getArrayTypeParameter()) and + prefix2.isEmpty() + } + + predicate typeEqualityAsymmetricInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // When `n2` is `*n1` propagate type information from a raw pointer type + // parameter at `n1`. The other direction is handled in + // `inferDereferencedExprPtrType`. + n1 = n2.(DerefExpr).getExpr() and + prefix1 = TypePath::singleton(getPtrTypeParameter()) and + prefix2.isEmpty() + or + n2 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n1) and + prefix2 = closureReturnPath() and + prefix1.isEmpty() + } + + /** + * Holds if `child` is a child of `parent`, and the Rust compiler applies [least + * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of + * `child`. + * + * In this case, we want type information to only flow from `child` to `parent`, + * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial + * explosion in inferred types. 
+ * + * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound + */ + predicate parentChildType(AstNode parent, AstNode child, TypePath prefix) { + child = parent.(IfExpr).getABranch() and + prefix.isEmpty() + or + parent = any(MatchExpr me | child = me.getAnArm().getExpr()) and + prefix.isEmpty() + or + parent = any(ArrayListExpr ale | child = ale.getAnExpr()) and + prefix = TypePath::singleton(getArrayTypeParameter()) + or + bodyReturns(parent, child) and + prefix.isEmpty() + or + exists(Struct s | + child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and + prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and + s = getRangeType(parent) + ) + } + + Type inferTypeInput(AstNode n, TypePath path) { + result = inferAssignmentOperationType(n, path) + or + result = inferConstructionType(n, path) + or + result = inferOperationType(n, path) + or + result = inferFieldExprType(n, path) + or + result = inferTryExprType(n, path) + or + result = inferLiteralType(n, path, false) + or + result = inferAwaitExprType(n, path) + or + result = inferDereferencedExprPtrType(n, path) + or + result = inferForLoopExprType(n, path) + or + result = inferClosureExprType(n, path) + or + result = inferArgList(n, path) + or + result = inferDeconstructionPatType(n, path) + or + result = inferUnknownTypeFromAnnotation(n, path) + } +} + +private module M3 = Make3; + +import M3 + module Consistency { import M2::Consistency @@ -408,16 +781,38 @@ private class AssocFunctionDeclaration extends FunctionDeclaration { } pragma[nomagic] -private TypeMention getCallExprTypeMentionArgument(CallExpr ce, TypeArgumentPosition apos) { - exists(Path p, int i | p = CallExprImpl::getFunctionPath(ce) | - apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i)) and - result = getPathTypeArgument(p, pragma[only_bind_into](i)) +private TypePath getPathToImplSelfTypeParam(TypeParam tp) { + 
exists(ImplItemNode impl | + tp = impl.getTypeParam(_) and + TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) ) } pragma[nomagic] private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, TypePath path) { - result = getCallExprTypeMentionArgument(ce, apos).getTypeAt(path) + exists(Path p, ItemNode resolved, TypeParam tp | + p = CallExprImpl::getFunctionPath(ce) and + resolved = resolvePath(p) and + apos.asTypeParam() = tp + | + // For type parameters of the function we must resolve their + // instantiation from the path. For instance, for `fn bar(a: A) -> A` + // and the path `bar`, we must resolve `A` to `i64`. + exists(int i | + tp = resolved.getTypeParam(pragma[only_bind_into](i)) and + result = getPathTypeArgument(p, pragma[only_bind_into](i)).getTypeAt(path) + ) + or + // For type parameters of the `impl` block we must resolve their + // instantiation from the path. For instance, for `impl for Foo` + // and the path `Foo::bar` we must resolve `A` to `i64`. + exists(ImplItemNode impl, TypePath pathToTp | + resolved = impl.getASuccessor(_) and + tp = impl.getTypeParam(_) and + pathToTp = getPathToImplSelfTypeParam(tp) and + result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(path)) + ) + ) or // Handle constructions that use `Self(...)` syntax exists(Path p, TypePath path0 | @@ -427,29 +822,6 @@ private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, Typ ) } -/** Gets the type annotation that applies to `n`, if any. */ -private TypeMention getTypeAnnotation(AstNode n) { - exists(LetStmt let | - n = let.getPat() and - result = let.getTypeRepr() - ) - or - result = n.(SelfParam).getTypeRepr() - or - exists(Param p | - n = p.getPat() and - result = p.getTypeRepr() - ) -} - -/** Gets the type of `n`, which has an explicit type annotation. 
*/ -pragma[nomagic] -private Type inferAnnotatedType(AstNode n, TypePath path) { - result = getTypeAnnotation(n).getTypeAt(path) - or - result = n.(ShorthandSelfParameterMention).getTypeAt(path) -} - pragma[nomagic] private Type inferFunctionBodyType(AstNode n, TypePath path) { exists(Function f | @@ -465,284 +837,54 @@ private Type inferFunctionBodyType(AstNode n, TypePath path) { * Holds if `me` is a call to the `panic!` macro. * * `panic!` needs special treatment, because it expands to a block expression - * that looks like it should have type `()` instead of the correct `!` type. - */ -pragma[nomagic] -private predicate isPanicMacroCall(MacroExpr me) { - me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic" -} - -// Due to "binding modes" the type of the pattern is not necessarily the -// same as the type of the initializer. However, when the pattern is an -// identifier pattern, its type is guaranteed to be the same as the type of the -// initializer. -private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) { - let.getPat() = lhs and - let.getInitializer() = rhs -} - -/** - * Gets the root type of a closure. - * - * We model closures as `dyn Fn` trait object types. A closure might implement - * only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others, - * giving closures the type `dyn Fn` works well in practice -- even if not - * entirely accurate. - */ -private DynTraitType closureRootType() { - result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs` -} - -/** Gets the path to a closure's return type. */ -private TypePath closureReturnPath() { - result = - TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType())) -} - -/** Gets the path to a closure's `index`th parameter type, where the arity is `arity`. 
*/ -pragma[nomagic] -private TypePath closureParameterPath(int arity, int index) { - result = - TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()), - TypePath::singleton(getTupleTypeParameter(arity, index))) -} - -/** Module for inferring certain type information. */ -module CertainTypeInference { - pragma[nomagic] - private predicate callResolvesTo(CallExpr ce, Path p, Function f) { - p = CallExprImpl::getFunctionPath(ce) and - f = resolvePath(p) - } - - pragma[nomagic] - private Type getCallExprType(CallExpr ce, Path p, FunctionDeclaration f, TypePath path) { - exists(ImplOrTraitItemNodeOption i | - callResolvesTo(ce, p, f) and - result = f.getReturnType(i, path) and - f.isDirectlyFor(i) - ) - } - - pragma[nomagic] - private Type getCertainCallExprType(CallExpr ce, Path p, TypePath tp) { - forex(Function f | callResolvesTo(ce, p, f) | result = getCallExprType(ce, p, f, tp)) - } - - pragma[nomagic] - private TypePath getPathToImplSelfTypeParam(TypeParam tp) { - exists(ImplItemNode impl | - tp = impl.getTypeParam(_) and - TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) - ) - } - - pragma[nomagic] - private Type inferCertainCallExprType(CallExpr ce, TypePath path) { - exists(Type ty, TypePath prefix, Path p | ty = getCertainCallExprType(ce, p, prefix) | - exists(TypePath suffix, TypeParam tp | - tp = ty.(TypeParamTypeParameter).getTypeParam() and - path = prefix.append(suffix) - | - // For type parameters of the `impl` block we must resolve their - // instantiation from the path. For instance, for `impl for Foo` - // and the path `Foo::bar` we must resolve `A` to `i64`. - exists(TypePath pathToTp | - pathToTp = getPathToImplSelfTypeParam(tp) and - result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(suffix)) - ) - or - // For type parameters of the function we must resolve their - // instantiation from the path. 
For instance, for `fn bar(a: A) -> A` - // and the path `bar`, we must resolve `A` to `i64`. - result = getCallExprTypeArgument(ce, TTypeParamTypeArgumentPosition(tp), suffix) - ) - or - not ty instanceof TypeParameter and - result = ty and - path = prefix - ) - } - - private Type inferCertainStructExprType(StructExpr se, TypePath path) { - result = se.getPath().(TypeMention).getTypeAt(path) - } - - private Type inferCertainStructPatType(StructPat sp, TypePath path) { - result = sp.getPath().(TypeMention).getTypeAt(path) - } - - predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - exists(Variable v | n1 = v.getAnAccess() | - n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam) - ) - or - // A `let` statement with a type annotation is a coercion site and hence - // is not a certain type equality. - exists(LetStmt let | - not let.hasTypeRepr() and - identLetStmt(let, n1, n2) - ) - or - exists(LetExpr let | - // Similarly as for let statements, we need to rule out binding modes - // changing the type. 
- let.getPat().(IdentPat) = n1 and - let.getScrutinee() = n2 - ) - or - n1 = n2.(ParenExpr).getExpr() - ) - or - n1 = - any(IdentPat ip | - n2 = ip.getName() and - prefix1.isEmpty() and - if ip.isRef() - then - exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - else prefix2.isEmpty() - ) - or - exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | - n1 = dce.getArgList() and - tt.getArity() = dce.getNumberOfSyntacticArguments() and - n2 = dce.getSyntacticPositionalArgument(i) and - prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and - prefix2.isEmpty() - ) - or - exists(ClosureExpr ce, int index | - n1 = ce and - n2 = ce.getParam(index).getPat() and - prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and - prefix2.isEmpty() - ) - } - - pragma[nomagic] - private Type inferCertainTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferCertainType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - certainTypeEquality(n, prefix1, n2, prefix2) - or - certainTypeEquality(n2, prefix2, n, prefix1) - ) - } - - /** - * Holds if `n` has complete and certain type information and if `n` has the - * resulting type at `path`. 
- */ - cached - Type inferCertainType(AstNode n, TypePath path) { - result = inferAnnotatedType(n, path) and - Stages::TypeInferenceStage::ref() - or - result = inferFunctionBodyType(n, path) - or - result = inferCertainCallExprType(n, path) - or - result = inferCertainTypeEquality(n, path) - or - result = inferLiteralType(n, path, true) - or - result = inferRefPatType(n) and - path.isEmpty() - or - result = inferRefExprType(n) and - path.isEmpty() - or - result = inferLogicalOperationType(n, path) - or - result = inferCertainStructExprType(n, path) - or - result = inferCertainStructPatType(n, path) - or - result = inferRangeExprType(n) and - path.isEmpty() - or - result = inferTupleRootType(n) and - path.isEmpty() - or - result = inferBlockExprType(n, path) - or - result = inferArrayExprType(n) and - path.isEmpty() - or - result = inferCastExprType(n, path) - or - exprHasUnitType(n) and - path.isEmpty() and - result instanceof UnitType - or - isPanicMacroCall(n) and - path.isEmpty() and - result instanceof NeverType - or - n instanceof ClosureExpr and - path.isEmpty() and - result = closureRootType() - or - infersCertainTypeAt(n, path, result.getATypeParameter()) - } + * that looks like it should have type `()` instead of the correct `!` type. + */ +pragma[nomagic] +private predicate isPanicMacroCall(MacroExpr me) { + me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic" +} - /** - * Holds if `n` has complete and certain type information at the type path - * `prefix.tp`. This entails that the type at `prefix` must be the type - * that declares `tp`. - */ - pragma[nomagic] - private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { - exists(TypePath path | - exists(inferCertainType(n, path)) and - path.isSnoc(prefix, tp) - ) - } +// Due to "binding modes" the type of the pattern is not necessarily the +// same as the type of the initializer. 
However, when the pattern is an +// identifier pattern, its type is guaranteed to be the same as the type of the +// initializer. +private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) { + let.getPat() = lhs and + let.getInitializer() = rhs +} - /** - * Holds if `n` has complete and certain type information at `path`. - */ - pragma[nomagic] - predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } +/** + * Gets the root type of a closure. + * + * We model closures as `dyn Fn` trait object types. A closure might implement + * only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others, + * giving closures the type `dyn Fn` works well in practice -- even if not + * entirely accurate. + */ +private DynTraitType closureRootType() { + result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs` +} - /** - * Holds if `n` having type `t` at `path` conflicts with certain type information - * at `prefix`. - */ - bindingset[n, prefix, path, t] - pragma[inline_late] - predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { - inferCertainType(n, path) != t - or - // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also - // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, - // then it must be the case that `T(i+1)` is a type parameter of `certainType`, - // otherwise there is a conflict. - // - // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. - exists(TypePath suffix, TypeParameter tp, Type certainType | - path = prefix.appendInverse(suffix) and - tp = suffix.getHead() and - inferCertainType(n, prefix) = certainType and - not certainType.getATypeParameter() = tp - ) - } +/** Gets the path to a closure's return type. 
*/ +private TypePath closureReturnPath() { + result = + TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType())) } -private Type inferLogicalOperationType(AstNode n, TypePath path) { - exists(Builtins::Bool t, BinaryLogicalOperation be | - n = [be, be.getLhs(), be.getRhs()] and - path.isEmpty() and - result = TDataType(t) - ) +/** Gets the path to a closure's `index`th parameter type, where the arity is `arity`. */ +pragma[nomagic] +private TypePath closureParameterPath(int arity, int index) { + result = + TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()), + TypePath::singleton(getTupleTypeParameter(arity, index))) +} + +private Type inferCertainStructExprType(StructExpr se, TypePath path) { + result = se.getPath().(TypeMention).getTypeAt(path) +} + +private Type inferCertainStructPatType(StructPat sp, TypePath path) { + result = sp.getPath().(TypeMention).getTypeAt(path) } private Type inferAssignmentOperationType(AstNode n, TypePath path) { @@ -780,161 +922,6 @@ private predicate bodyReturns(Expr body, Expr e) { ) } -/** - * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree - * of `n2` at `prefix2` and type information should propagate in both directions - * through the type equality. 
- */ -private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2) - or - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - exists(LetStmt let | - let.getPat() = n1 and - let.getInitializer() = n2 - ) - or - n2 = - any(MatchExpr me | - n1 = me.getAnArm().getExpr() and - me.getNumberOfArms() = 1 - ) - or - exists(LetExpr let | - n1 = let.getScrutinee() and - n2 = let.getPat() - ) - or - exists(MatchExpr me | - n1 = me.getScrutinee() and - n2 = me.getAnArm().getPat() - ) - or - n1 = n2.(OrPat).getAPat() - or - n1 = n2.(ParenPat).getPat() - or - n1 = n2.(LiteralPat).getLiteral() - or - exists(BreakExpr break | - break.getExpr() = n1 and - break.getTarget() = n2.(LoopExpr) - ) - or - exists(AssignmentExpr be | - n1 = be.getLhs() and - n2 = be.getRhs() - ) - or - n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and - not isPanicMacroCall(n2) - or - n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() - or - bodyReturns(n1, n2) and - strictcount(Expr e | bodyReturns(n1, e)) = 1 - ) - or - n2 = - any(RefExpr re | - n1 = re.getExpr() and - prefix1.isEmpty() and - prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) - ) - or - n2 = - any(RefPat rp | - n1 = rp.getPat() and - prefix1.isEmpty() and - exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - ) - or - exists(int i, int arity | - prefix1.isEmpty() and - prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) - | - arity = n2.(TupleExpr).getNumberOfFields() and - n1 = n2.(TupleExpr).getField(i) - or - arity = n2.(TuplePat).getTupleArity() and - n1 = n2.(TuplePat).getField(i) - ) - or - exists(BlockExpr be | - n1 = be and - n2 = be.getStmtList().getTailExpr() and - if be.isAsync() - then - prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and - 
prefix2.isEmpty() - else ( - prefix1.isEmpty() and - prefix2.isEmpty() - ) - ) - or - // an array list expression with only one element (such as `[1]`) has type from that element - n1 = - any(ArrayListExpr ale | - ale.getAnExpr() = n2 and - ale.getNumberOfExprs() = 1 - ) and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() - or - // an array repeat expression (`[1; 3]`) has the type of the repeat operand - n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() -} - -/** - * Holds if `child` is a child of `parent`, and the Rust compiler applies [least - * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of - * `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. - * - * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound - */ -private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - child = parent.(IfExpr).getABranch() and - prefix.isEmpty() - or - parent = - any(MatchExpr me | - child = me.getAnArm().getExpr() and - me.getNumberOfArms() > 1 - ) and - prefix.isEmpty() - or - parent = - any(ArrayListExpr ale | - child = ale.getAnExpr() and - ale.getNumberOfExprs() > 1 - ) and - prefix = TypePath::singleton(getArrayTypeParameter()) - or - bodyReturns(parent, child) and - strictcount(Expr e | bodyReturns(parent, e)) > 1 and - prefix.isEmpty() - or - parent = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = child) and - prefix = closureReturnPath() - or - exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) - ) -} - private Type 
inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { inferType(n, path) = TUnknownType() and // Normally, these are coercion sites, but in case a type is unknown we @@ -948,46 +935,6 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { ) } -/** - * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree - * of `n2` at `prefix2`, but type information should only propagate from `n1` to - * `n2`. - */ -private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - lubCoercion(n2, n1, prefix2) and - prefix1.isEmpty() - or - exists(AstNode mid, TypePath prefixMid, TypePath suffix | - typeEquality(n1, prefixMid, mid, prefix2) or - typeEquality(mid, prefix2, n1, prefixMid) - | - lubCoercion(mid, n2, suffix) and - not lubCoercion(mid, n1, _) and - prefix1 = prefixMid.append(suffix) - ) - or - // When `n2` is `*n1` propagate type information from a raw pointer type - // parameter at `n1`. The other direction is handled in - // `inferDereferencedExprPtrType`. 
- n1 = n2.(DerefExpr).getExpr() and - prefix1 = TypePath::singleton(getPtrTypeParameter()) and - prefix2.isEmpty() -} - -pragma[nomagic] -private Type inferTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - typeEquality(n, prefix1, n2, prefix2) - or - typeEquality(n2, prefix2, n, prefix1) - or - typeEqualityAsymmetric(n2, prefix2, n, prefix1) - ) -} - pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -1163,53 +1110,6 @@ private module ContextTyping { ) } } - - pragma[nomagic] - private predicate hasUnknownTypeAt(AstNode n, TypePath path) { - inferType(n, path) = TUnknownType() - } - - pragma[nomagic] - private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } - - newtype FunctionPositionKind = - SelfKind() or - ReturnKind() or - PositionalKind() - - signature Type inferCallTypeSig(AstNode n, FunctionPositionKind kind, TypePath path); - - /** - * Given a predicate `inferCallType` for inferring the type of a call at a given - * position, this module exposes the predicate `check`, which wraps the input - * predicate and checks that types are only propagated into arguments when they - * are context-typed. 
- */ - module CheckContextTyping { - pragma[nomagic] - private Type inferCallNonReturnType( - AstNode n, FunctionPositionKind kind, TypePath prefix, TypePath path - ) { - result = inferCallType(n, kind, path) and - hasUnknownType(n) and - kind != ReturnKind() and - prefix = path.getAPrefix() - } - - pragma[nomagic] - Type check(AstNode n, TypePath path) { - result = inferCallType(n, ReturnKind(), path) - or - exists(FunctionPositionKind kind, TypePath prefix | - result = inferCallNonReturnType(n, kind, prefix, path) and - hasUnknownTypeAt(n, prefix) - | - // Never propagate type information directly into the receiver, since its type - // must already have been known in order to resolve the call - if kind = SelfKind() then not prefix.isEmpty() else any() - ) - } - } } /** @@ -2677,6 +2577,11 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput FunctionDeclaration getFunction() { result = f } + predicate isFunction(ImplOrTraitItemNodeOption i_, Function f_) { + i_ = i and + f_ = f + } + predicate isAssocFunction(ImplOrTraitItemNode i_, Function f_) { i_ = i.asSome() and f_ = f @@ -2900,22 +2805,20 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } } -private module FunctionCallMatching = MatchingWithEnvironment; - pragma[nomagic] private Type inferFunctionCallType0( FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path ) { exists(TypePath path0 | - n = call.getNodeAt(pos) and exists(string derefChainBorrow | FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) | - result = FunctionCallMatching::inferAccessType(call, derefChainBorrow, pos, path0) - or + n = call.getNodeAt(pos) and call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and result = TUnknownType() + or + result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0) ) | if @@ -2983,31 +2886,6 @@ private Type 
inferFunctionCallTypeSelf( ) } -private Type inferFunctionCallTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(FunctionPosition pos | - result = inferFunctionCallTypeNonSelf(n, pos, path) and - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - ) - or - exists(FunctionCallMatchingInput::Access a | - result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and - if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() - then kind = ContextTyping::SelfKind() - else kind = ContextTyping::PositionalKind() - ) -} - -/** - * Gets the type of `n` at `path`, where `n` is either a function call or an - * argument/receiver of a function call. - */ -private predicate inferFunctionCallType = - ContextTyping::CheckContextTyping::check/2; - abstract private class Constructor extends Addressable { final TypeParameter getTypeParameter(TypeParameterPosition ppos) { typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos) @@ -3166,15 +3044,8 @@ private module ConstructionMatchingInput implements MatchingInputSig { private module ConstructionMatching = Matching; pragma[nomagic] -private Type inferConstructionTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(ConstructionMatchingInput::Access a, FunctionPosition pos | - n = a.getNodeAt(pos) and - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - | +private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { + exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) | result = ConstructionMatching::inferAccessType(a, pos, path) or a.hasUnknownTypeAt(pos, path) and @@ -3183,7 +3054,7 @@ private Type inferConstructionTypePreCheck( } private predicate inferConstructionType = - ContextTyping::CheckContextTyping::check/2; + 
CheckContextTyping::check/2; /** * A matching configuration for resolving types of operations like `a + b`. @@ -3248,23 +3119,15 @@ private module OperationMatchingInput implements MatchingInputSig { private module OperationMatching = Matching; pragma[nomagic] -private Type inferOperationTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(OperationMatchingInput::Access a, FunctionPosition pos | +private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { + exists(OperationMatchingInput::Access a | n = a.getNodeAt(pos) and result = OperationMatching::inferAccessType(a, pos, path) and - if pos.asPosition() = 0 - then kind = ContextTyping::SelfKind() - else - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() + if pos.asPosition() = 0 then not path.isEmpty() else any() ) } -private predicate inferOperationType = - ContextTyping::CheckContextTyping::check/2; +private predicate inferOperationType = CheckContextTyping::check/2; pragma[nomagic] private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) { @@ -3285,6 +3148,19 @@ private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC ) } +/** + * Gets the struct field that the field expression `fe` resolves to, if any. 
+ */ +cached +StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(string name, DataType ty | + ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) + | + result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or + result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) + ) +} + pragma[nomagic] private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain derefChain) { exists(string name | @@ -3293,6 +3169,20 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref ) } +/** + * Gets the tuple field that the field expression `fe` resolves to, if any. + */ +cached +TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(int i | + result = + getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) + .(StructType) + .getTypeItem() + .getTupleField(pragma[only_bind_into](i)) + ) +} + /** * A matching configuration for resolving types of field expressions like `x.field`. */ @@ -3776,170 +3666,49 @@ private Type inferCastExprType(CastExpr ce, TypePath path) { result = ce.getTypeRepr().(TypeMention).getTypeAt(path) } +/** Holds if `n` is implicitly dereferenced and/or borrowed. */ cached -private module Cached { - /** Holds if `n` is implicitly dereferenced and/or borrowed. */ - cached - predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { - exists(BorrowKind bk | - any(AssocFunctionResolution::AssocFunctionCall afc) - .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and - if bk.isNoBorrow() then borrow = false else borrow = true - ) - or - e = - any(FieldExpr fe | - exists(resolveStructFieldExpr(fe, derefChain)) - or - exists(resolveTupleFieldExpr(fe, derefChain)) - ).getContainer() and - not derefChain.isEmpty() and - borrow = false - } - - /** - * Gets an item (function or tuple struct/variant) that `call` resolves to, if - * any. 
- * - * The parameter `dispatch` is `true` if and only if the resolved target is a - * trait item because a precise target could not be determined from the - * types (for instance in the presence of generics or `dyn` types) - */ - cached - Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { - dispatch = false and - result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() - or - exists(ImplOrTraitItemNode i | - i instanceof TraitItemNode and dispatch = true - or - i instanceof ImplItemNode and dispatch = false - | - result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and - not call instanceof CallExprImpl::DynamicCallExpr and - not i instanceof Builtins::BuiltinImpl - ) - } - - /** - * Gets the struct field that the field expression `fe` resolves to, if any. - */ - cached - StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(string name, DataType ty | - ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) - | - result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or - result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) - ) - } - - /** - * Gets the tuple field that the field expression `fe` resolves to, if any. - */ - cached - TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(int i | - result = - getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) - .(StructType) - .getTypeItem() - .getTupleField(pragma[only_bind_into](i)) - ) - } - - /** - * Gets a type at `path` that `n` infers to, if any. - * - * The type inference implementation works by computing all possible types, so - * the result is not necessarily unique. 
For example, in - * - * ```rust - * trait MyTrait { - * fn foo(&self) -> &Self; - * - * fn bar(&self) -> &Self { - * self.foo() - * } - * } - * - * struct MyStruct; - * - * impl MyTrait for MyStruct { - * fn foo(&self) -> &MyStruct { - * self - * } - * } - * - * fn baz() { - * let x = MyStruct; - * x.bar(); - * } - * ``` - * - * the type inference engine will roughly make the following deductions: - * - * 1. `MyStruct` has type `MyStruct`. - * 2. `x` has type `MyStruct` (via 1.). - * 3. The return type of `bar` is `&Self`. - * 3. `x.bar()` has type `&MyStruct` (via 2 and 3, by matching the implicit `Self` - * type parameter with `MyStruct`.). - * 4. The return type of `bar` is `&MyTrait`. - * 5. `x.bar()` has type `&MyTrait` (via 2 and 4). - */ - cached - Type inferType(AstNode n, TypePath path) { - Stages::TypeInferenceStage::ref() and - result = CertainTypeInference::inferCertainType(n, path) - or - // Don't propagate type information into a node which conflicts with certain - // type information. 
- forall(TypePath prefix | - CertainTypeInference::hasInferredCertainType(n, prefix) and - prefix.isPrefixOf(path) - | - not CertainTypeInference::certainTypeConflict(n, prefix, path, result) - ) and - ( - result = inferAssignmentOperationType(n, path) - or - result = inferTypeEquality(n, path) - or - result = inferFunctionCallType(n, path) - or - result = inferConstructionType(n, path) - or - result = inferOperationType(n, path) - or - result = inferFieldExprType(n, path) - or - result = inferTryExprType(n, path) - or - result = inferLiteralType(n, path, false) - or - result = inferAwaitExprType(n, path) - or - result = inferDereferencedExprPtrType(n, path) - or - result = inferForLoopExprType(n, path) - or - result = inferClosureExprType(n, path) - or - result = inferArgList(n, path) - or - result = inferDeconstructionPatType(n, path) +predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { + CachedStage::ref() and + exists(BorrowKind bk | + any(AssocFunctionResolution::AssocFunctionCall afc) + .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and + if bk.isNoBorrow() then borrow = false else borrow = true + ) + or + e = + any(FieldExpr fe | + exists(resolveStructFieldExpr(fe, derefChain)) or - result = inferUnknownTypeFromAnnotation(n, path) - ) - } + exists(resolveTupleFieldExpr(fe, derefChain)) + ).getContainer() and + not derefChain.isEmpty() and + borrow = false } -import Cached - /** - * Gets a type that `n` infers to, if any. + * Gets an item (function or tuple struct/variant) that `call` resolves to, if + * any. 
+ * + * The parameter `dispatch` is `true` if and only if the resolved target is a + * trait item because a precise target could not be determined from the + * types (for instance in the presence of generics or `dyn` types) */ -Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } +cached +Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { + dispatch = false and + result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() + or + exists(ImplOrTraitItemNode i | + i instanceof TraitItemNode and dispatch = true + or + i instanceof ImplItemNode and dispatch = false + | + result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and + not call instanceof CallExprImpl::DynamicCallExpr and + not i instanceof Builtins::BuiltinImpl + ) +} /** Provides predicates for debugging the type inference implementation. */ private module Debug { @@ -3973,11 +3742,10 @@ private module Debug { t = self.getTypeAt(path) } - predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { - n = getRelevantLocatable() and - t = inferFunctionCallType(n, path) - } - + // predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { + // n = getRelevantLocatable() and + // t = inferFunctionCallType(n, path) + // } predicate debugInferConstructionType(AstNode n, TypePath path, Type t) { n = getRelevantLocatable() and t = inferConstructionType(n, path) @@ -3990,7 +3758,7 @@ private module Debug { Type debugInferAnnotatedType(AstNode n, TypePath path) { n = getRelevantLocatable() and - result = inferAnnotatedType(n, path) + result = CertainTypeInference::inferAnnotatedType(n, path) } pragma[nomagic] diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 3344fc45f74f..f30a2c064d12 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ 
b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -10252,6 +10252,7 @@ inferType | main.rs:1412:17:1412:20 | self | TRef.TSlice | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option | | main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & | +| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & | | main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 | @@ -11600,6 +11601,8 @@ inferType | main.rs:2221:18:2221:21 | true | | {EXTERNAL LOCATION} | bool | | main.rs:2223:9:2223:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | +| main.rs:2223:9:2223:15 | S(...) | T | main.rs:2107:5:2107:19 | S | +| main.rs:2223:9:2223:15 | S(...) | T.T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | main.rs:2107:5:2107:19 | S | @@ -11618,6 +11621,8 @@ inferType | main.rs:2224:24:2224:27 | 3i64 | | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | +| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | & | +| main.rs:2225:9:2225:15 | S(...) | T.TRef | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:29 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:29 | ... .my_add(...) 
| T | {EXTERNAL LOCATION} | i64 | | main.rs:2225:11:2225:14 | 1i64 | | {EXTERNAL LOCATION} | i64 | diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 1f4400d8f2d7..c41ed29ebfbc 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -146,17 +146,25 @@ signature module InputSig1 { } /** - * Holds if `t` is a pseudo type. Pseudo types are skipped when checking for - * non-instantiations in `isNotInstantiationOf`. + * A special pseudo type used to represent cases where the actual type needs + * to be inferred from the context. For example, in + * + * ```rust + * let x = Vec::new(); + * x.push(42); + * ``` + * + * the element type of `x` is assigned an unknown type, which allows for type + * information to flow into `x` from the call to `push`. */ - predicate isPseudoType(Type t); + class UnknownType extends Type; /** A type parameter. */ class TypeParameter extends Type; /** - * A type abstraction. I.e., a place in the program where type variables are - * introduced. + * A type abstraction. I.e., a place in the program where type variables may + * be introduced. * * Example in C#: * ```csharp @@ -171,7 +179,7 @@ signature module InputSig1 { * ``` */ class TypeAbstraction { - /** Gets a type parameter introduced by this abstraction. */ + /** Gets a type parameter introduced by this abstraction, if any. */ TypeParameter getATypeParameter(); /** Gets a textual representation of this type abstraction. */ @@ -305,6 +313,8 @@ module Make1 Input1> { * code. 
For example, in * * ```csharp + * class Base { } + * * class C : Base, Interface { } * ``` * @@ -314,7 +324,7 @@ module Make1 Input1> { * `TypePath` | `Type` * ---------- | ------- * `""` | ``Base`1`` - * `"0"` | `T` + * `"B"` | `T` */ signature module InputSig2 { /** @@ -639,7 +649,8 @@ module Make1 Input1> { } private Type getNonPseudoTypeAt(App app, TypePath path) { - result = app.getTypeAt(path) and not isPseudoType(result) + result = app.getTypeAt(path) and + not result instanceof UnknownType } pragma[nomagic] @@ -2100,5 +2111,574 @@ module Make1 Input1> { not exists(tm.getTypeAt(TypePath::nil())) and exists(tm.getLocation()) } } + + /** + * Provides the input to `Make3`. + * + * TODO: Eventually align the AST signature with that of the shared CFG library. + */ + signature module InputSig3 { + /** + * References to cached predicates that should be included in the cached + * stage of type inference. Such predicates should reference `CachedStage::ref`. + */ + default predicate cachedStageRevRef() { none() } + + /** + * Point this predicate to the `inferType` predicate in the output of this module. + * + * Needed to be able to refer to `inferType` in default signature implementations. + */ + Type inferType(AstNode n, TypePath path); + + /** A boolean type. */ + class BoolType extends Type; + + /** An AST node. */ + class AstNode { + /** Gets a textual representation of this AST node. */ + string toString(); + + /** Gets the location of this AST node. */ + Location getLocation(); + } + + /** Gets the type annotation that applies to `n`, if any. */ + TypeMention getTypeAnnotation(AstNode n); + + /** An expression. */ + class Expr extends AstNode; + + /** + * A switch expression. + */ + class SwitchExpr extends Expr { + /** + * Gets the expression being switched on. + */ + Expr getExpr(); + + /** Gets the case at the specified (zero-based) `index`. */ + Case getCase(int index); + } + + /** A case in a switch expression. 
*/ + class Case extends AstNode { + /** Gets a pattern being matched by this case. */ + AstNode getAPattern(); + + /** Gets the body of this case. */ + Expr getBody(); + } + + /** A ternary conditional expression. */ + class ConditionalExpr extends Expr { + /** Gets the condition of this expression. */ + Expr getCondition(); + + /** Gets the true branch of this expression. */ + Expr getThen(); + + /** Gets the false branch of this expression. */ + Expr getElse(); + } + + /** A binary expression. */ + class BinaryExpr extends Expr { + /** Gets the left operand of this binary expression. */ + Expr getLeftOperand(); + + /** Gets the right operand of this binary expression. */ + Expr getRightOperand(); + } + + /** A short-circuiting logical AND expression. */ + class LogicalAndExpr extends BinaryExpr; + + /** A short-circuiting logical OR expression. */ + class LogicalOrExpr extends BinaryExpr; + + /** + * An assignment expression, either compound or simple. + * + * Examples: + * + * ``` + * x = y + * sum += element + * ``` + */ + class Assignment extends BinaryExpr; + + /** A simple assignment expression, for example `x = y`. */ + class AssignExpr extends Assignment; + + /** A parenthesized expression. */ + class ParenExpr extends Expr { + Expr getExpr(); + } + + /** A variable, for example a local variable or a field. */ + class Variable { + /** Gets the AST node that defines this variable. */ + AstNode getDefiningNode(); + + /** Gets an access to this variable. */ + Expr getAnAccess(); + + /** Gets a textual representation of this element. */ + string toString(); + + /** Gets the location of this element. */ + Location getLocation(); + } + + /** + * A `let` declaration, for example a local variable declaration. + */ + class LetDeclaration extends AstNode { + /** + * Holds if this declaration is a coercion site, meaning that the type of the right + * operand may have to be coerced to the type of the left operand. 
+ */ + predicate isCoercionSite(); + + /** Gets the left operand of this declaration. */ + AstNode getLeftOperand(); + + /** Gets the right operand of this declaration. */ + AstNode getRightOperand(); + } + + /** + * A position where a callable can have a declared type. + */ + class TypePosition { + /** Holds if this position represents the return type of a callable. */ + predicate isReturn(); + + /** Gets a textual representation of this position. */ + string toString(); + } + + /** A context needed to resolve calls. */ + bindingset[this] + class CallResolutionContext { + /** Gets a textual representation of this context. */ + bindingset[this] + string toString(); + } + + /** A callable. */ + class Callable { + TypeParameter getTypeParameter(TypeParameterPosition ppos); + + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp); + + /** Gets the declared type of this callable at `path` for position `pos`. */ + Type getDeclaredType(TypePosition pos, TypePath path); + + /** Gets a textual representation of this callable. */ + string toString(); + + /** Gets the location of this callable. */ + Location getLocation(); + } + + class Call extends Expr { + Type getTypeArgument(TypeArgumentPosition apos, TypePath path); + + AstNode getNodeAt(TypePosition pos); + + /** Gets the target of this call. */ + Callable getTargetCertain(); + + /** Gets the target of this call in the call resolution context `ctx`. */ + Callable getTarget(CallResolutionContext ctx); + } + + /** Gets the inferred type of `call` at `path` for position `pos` in context `ctx`. */ + bindingset[ctx] + default Type inferCallTypeIn( + Call call, CallResolutionContext ctx, TypePosition pos, TypePath path + ) { + result = inferType(call.getNodeAt(pos), path) and + exists(ctx) + } + + Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path); + + /** + * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. 
+ */ + predicate certainTypeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** Gets the inferred certain type of `n` at `path`. */ + Type inferCertainTypeInput(AstNode n, TypePath path); + + /** + * Holds if the types of `n1` at `path1` and `n2` at `path2` are possibly equal, + * and type information should be allowed to flow in both directions between them. + */ + predicate typeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** + * Holds if the type tree of `n1` at `path1` should be equal to the type tree + * of `n2` at `path2`, but type information should only propagate from `n1` to + * `n2`. + */ + predicate typeEqualityAsymmetricInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** + * Holds if `child` is a child of `parent` and the type of `parent` at `prefix` can be + * inferred from the type of `child`. + * + * When `child` is unique, we also allow type information to flow from `parent` to `child`. + */ + predicate parentChildType(AstNode parent, AstNode child, TypePath prefix); + + /** Gets the inferred type of `n` at `path`. */ + Type inferTypeInput(AstNode n, TypePath path); + } + + module Make3 { + private import Input3 + + /** Provides logic for inferring certain type information. */ + module CertainTypeInference { + /** Gets the type of `n`, which has an explicit type annotation. 
*/ + pragma[nomagic] + Type inferAnnotatedType(AstNode n, TypePath path) { + result = getTypeAnnotation(n).getTypeAt(path) + } + + predicate certainTypeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + path2.isEmpty() and + ( + exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode()) + or + exists(LetDeclaration let | + not let.isCoercionSite() and + n1 = let.getLeftOperand() and + n2 = let.getRightOperand() + ) + or + n1 = n2.(ParenExpr).getExpr() + ) + or + certainTypeEqualityInput(n1, path1, n2, path2) + } + + pragma[nomagic] + private Type inferCertainTypeEquality(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferCertainType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) + | + certainTypeEquality(n, prefix1, n2, prefix2) + or + certainTypeEquality(n2, prefix2, n, prefix1) + ) + } + + private Type inferLogicalOperationType(AstNode n, TypePath path) { + ( + exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or + exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) //or + // exists(LogicalNotExpr lne | n = [lne, lne.getOperand()]) + ) and + result instanceof BoolType and + path.isEmpty() + } + + pragma[nomagic] + private Type getCertainCallExprType(Call call, TypePath path) { + exists(TypePosition ret | + ret.isReturn() and + forex(Callable target | target = call.getTargetCertain() | + result = target.getDeclaredType(ret, path) + ) + ) + } + + pragma[nomagic] + private Type inferCertainCallExprType(Call call, TypePath path) { + exists(Type ty, TypePath prefix | ty = getCertainCallExprType(call, prefix) | + exists( + Callable target, TypePath suffix, TypeParameterPosition tppos, + TypeArgumentPosition tapos + | + ty = target.getTypeParameter(tppos) and + path = prefix.append(suffix) and + result = call.getTypeArgument(tapos, suffix) and + 
typeArgumentParameterPositionMatch(tapos, tppos) + ) + or + not ty instanceof TypeParameter and + result = ty and + path = prefix + ) + } + + /** Gets the inferred certain type of `n` at `path`. */ + cached + Type inferCertainType(AstNode n, TypePath path) { + CachedStage::ref() and + result = inferAnnotatedType(n, path) + or + result = inferCertainTypeEquality(n, path) + or + result = inferCertainTypeInput(n, path) + or + result = inferLogicalOperationType(n, path) + or + result = inferCertainCallExprType(n, path) + or + infersCertainTypeAt(n, path, result.getATypeParameter()) + } + + /** + * Holds if `n` has complete and certain type information at the type path + * `prefix.tp`. This entails that the type at `prefix` must be the type + * that declares `tp`. + */ + pragma[nomagic] + private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { + exists(TypePath path | + exists(inferCertainType(n, path)) and + path.isSnoc(prefix, tp) + ) + } + + /** + * Holds if `n` has complete and certain type information at `path`. + */ + pragma[nomagic] + predicate hasInferredCertainType(AstNode n, TypePath path) { + exists(inferCertainType(n, path)) + } + + /** + * Holds if `n` having type `t` at `path` conflicts with certain type information + * at `prefix`. + */ + bindingset[n, prefix, path, t] + pragma[inline_late] + predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { + inferCertainType(n, path) != t + or + // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also + // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, + // then it must be the case that `T(i+1)` is a type parameter of `certainType`, + // otherwise there is a conflict. + // + // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. 
+ exists(TypePath suffix, TypeParameter tp, Type certainType | + path = prefix.appendInverse(suffix) and + tp = suffix.getHead() and + inferCertainType(n, prefix) = certainType and + not certainType.getATypeParameter() = tp + ) + } + } + + private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + CertainTypeInference::certainTypeEquality(n1, path1, n2, path2) + or + path1.isEmpty() and + path2.isEmpty() and + ( + exists(Assignment a | + a.getLeftOperand() = n1 and + a.getRightOperand() = n2 + ) + or + exists(LetDeclaration let | + let.getLeftOperand() = n1 and + let.getRightOperand() = n2 + ) + ) + or + typeEqualityInput(n1, path1, n2, path2) + or + n2 = unique(AstNode child | parentChildType(n1, child, path1) | child) and + path2.isEmpty() + } + + private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { + parentChildType(parent, child, prefix) and + strictcount(AstNode child0 | parentChildType(parent, child0, prefix) | child0) > 1 + or + typeEqualityAsymmetricInput(child, TypePath::nil(), parent, prefix) + } + + private predicate typeEqualityAsymmetric( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { + lubCoercion(n2, n1, path2) and + path1.isEmpty() + or + exists(AstNode mid, TypePath pathMid, TypePath suffix | + typeEquality(n1, pathMid, mid, path2) or + typeEquality(mid, path2, n1, pathMid) + | + lubCoercion(mid, n2, suffix) and + not lubCoercion(mid, n1, _) and + path1 = pathMid.append(suffix) + ) + or + typeEqualityAsymmetricInput(n1, path1, n2, path2) + } + + pragma[nomagic] + private Type inferTypeEquality(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) + | + typeEquality(n, prefix1, n2, prefix2) + or + typeEquality(n2, prefix2, n, prefix1) + or + typeEqualityAsymmetric(n2, prefix2, n, prefix1) + ) + } + + /** + * Gets the inferred type of `n` 
at `path`. + */ + cached + Type inferType(AstNode n, TypePath path) { + CachedStage::ref() and + result = CertainTypeInference::inferCertainType(n, path) + or + // Don't propagate type information into a node which conflicts with certain + // type information. + forall(TypePath prefix | + CertainTypeInference::hasInferredCertainType(n, prefix) and + prefix.isPrefixOf(path) + | + not CertainTypeInference::certainTypeConflict(n, prefix, path, result) + ) and + ( + result = inferTypeEquality(n, path) + or + result = CheckContextTyping::check(n, path) + or + result = inferTypeInput(n, path) + ) + } + + private module TypePositionMatchingInput { + class DeclarationPosition = TypePosition; + + class AccessPosition = DeclarationPosition; + + predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) { + apos = dpos + } + } + + /** + * A matching configuration for resolving types of calls. + */ + private module CallMatchingInput implements MatchingWithEnvironmentInputSig { + import TypePositionMatchingInput + + class Declaration = Callable; + + bindingset[decl] + TypeMention getATypeParameterConstraint(TypeParameter tp, Declaration decl) { + result = Input2::getATypeParameterConstraint(tp) and + exists(decl) + or + result = decl.getAdditionalTypeParameterConstraint(tp) + } + + class AccessEnvironment = CallResolutionContext; + + final private class CallFinal = Call; + + class Access extends CallFinal { + bindingset[e] + Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) { + result = inferCallTypeIn(this, e, apos, path) + } + } + } + + private module CallMatching = MatchingWithEnvironment; + + pragma[nomagic] + Type inferCallTypeOut( + Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path + ) { + n = call.getNodeAt(pos) and + result = CallMatching::inferAccessType(call, ctx, pos, path) + } + + pragma[nomagic] + private predicate hasUnknownTypeAt(AstNode n, TypePath path) { + inferType(n, 
path) instanceof UnknownType + } + + pragma[nomagic] + private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } + + signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path); + + /** + * Given a predicate `inferCallType` for inferring the type of a call at a given + * position, this module exposes the predicate `check`, which wraps the input + * predicate and checks that types are only propagated into arguments when they + * are context-typed. + */ + module CheckContextTyping { + pragma[nomagic] + private Type inferCallNonReturnType(AstNode n, TypePath prefix, TypePath path) { + exists(TypePosition pos | + result = inferCallType(n, pos, path) and + hasUnknownType(n) and + not pos.isReturn() and + prefix = path.getAPrefix() + ) + } + + pragma[nomagic] + Type check(AstNode n, TypePath path) { + result = inferCallType(n, any(TypePosition pos | pos.isReturn()), path) + or + exists(TypePath prefix | + result = inferCallNonReturnType(n, prefix, path) and + hasUnknownTypeAt(n, prefix) + ) + } + } + + /** + * Gets the inferred root type of `n`, if any. + */ + Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } + + // todo: consistency checks + /** The cached stage of type inference. */ + cached + module CachedStage { + /** Reference to the cached stage of type inference. */ + cached + predicate ref() { any() } + + /** Reverse references to the predicates that reference `ref()`. 
*/ + cached + predicate revRef() { + (exists(CertainTypeInference::inferCertainType(_, _)) implies any()) + or + (exists(inferType(_, _)) implies any()) + or + cachedStageRevRef() + } + } + } } } diff --git a/swift/ql/test/library-tests/type-inference/type-inference.expected b/swift/ql/test/library-tests/type-inference/type-inference.expected index e69de29bb2d1..79a3599aa1f3 100644 --- a/swift/ql/test/library-tests/type-inference/type-inference.expected +++ b/swift/ql/test/library-tests/type-inference/type-inference.expected @@ -0,0 +1,11 @@ +| context.swift:17:10:17:10 | C.init() | Unexpected result: target=init() | +| context.swift:17:10:17:12 | call to C.init() | Unexpected result: target=init() | +| context.swift:17:14:18:1 | // $ type=C\n | Missing result: type=C | +| context.swift:25:11:25:11 | A.init() | Unexpected result: target=init() | +| context.swift:25:11:25:13 | call to A.init() | Unexpected result: target=init() | +| context.swift:26:19:26:19 | D.init() | Unexpected result: target=init() | +| context.swift:26:19:26:21 | call to D.init() | Unexpected result: target=init() | +| context.swift:26:25:26:25 | B.init() | Unexpected result: target=init() | +| context.swift:26:25:26:27 | call to B.init() | Unexpected result: target=init() | +| file://:0:0:0:0 | A.init() | Unexpected result: target=init() | +| file://:0:0:0:0 | call to A.init() | Unexpected result: target=init() |