From 64dfe335bf1b43a66d7d823ab1f005e1be009ca9 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Tue, 5 May 2026 12:47:32 +0200 Subject: [PATCH 01/12] Rust: Move more type inference logic into shared library --- .../internal/typeinference/TypeInference.qll | 632 +++++++++--------- .../typeinference/internal/TypeInference.qll | 335 +++++++++- 2 files changed, 657 insertions(+), 310 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 423ad21ae4ac..7ecf2276f7b9 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -37,10 +37,7 @@ private module Input1 implements InputSig1 { class Type = T::Type; - predicate isPseudoType(Type t) { - t instanceof UnknownType or - t instanceof NeverType - } + class UnknownType = T::UnknownType; class TypeParameter = T::TypeParameter; @@ -276,6 +273,93 @@ private module M2 = Make2; import M2 +private module Input3 implements InputSig3 { + private import rust as Rust + + predicate cachedStageRef = CachedStage::ref/0; + + predicate cachedStageRevRef() { + (implicitDerefChainBorrow(_, _, _) implies any()) + or + (exists(resolveCallTarget(_, _)) implies any()) + or + (exists(resolveStructFieldExpr(_, _)) implies any()) + or + (exists(resolveTupleFieldExpr(_, _)) implies any()) + } + + class AstNode = Rust::AstNode; + + predicate getTypeAnnotation = getTypeAnnotation_/1; + + /** A variable, for example a local variable or a field. 
*/ + class Variable extends Rust::Variable { + AstNode getDefiningNode() { + result = this.getPat().getName() or + result = this.getParameter().(SelfParam) + } + + AstNode getAnAccess() { result = super.getAnAccess() } + } + + abstract class Assignment extends AstNode { + abstract predicate isCoercionSite(); + + abstract AstNode getLeftOperand(); + + abstract AstNode getRightOperand(); + } + + private class LetExprAssignment extends Assignment, LetExpr { + override predicate isCoercionSite() { not this.getPat() instanceof IdentPat } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getScrutinee() } + } + + private class LetStmtAssignment extends Assignment, LetStmt { + override predicate isCoercionSite() { + this.hasTypeRepr() or + not identLetStmt(this, _, _) + } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getInitializer() } + } + + class ParenExpr extends AstNode, Rust::ParenExpr { + AstNode getExpr() { result = super.getExpr() } + } + + /** A ternary conditional expression. 
*/ + class ConditionalExpr extends AstNode, IfExpr { + AstNode getCondition() { result = super.getCondition() } + + AstNode getThen() { result = super.getThen() } + + AstNode getElse() { result = super.getElse() } + } + + predicate certainTypeEqualityInput = CertainTypeInference_::certainTypeEquality_/4; + + predicate inferCertainTypeInput = CertainTypeInference_::inferCertainType_/2; + + predicate lubCoercionInput = lubCoercion_/3; + + predicate typeEqualityAsymmetricInput = typeEqualityAsymmetric_/4; + + predicate typeEqualityInput = typeEquality_/4; + + predicate inferTypeInput = inferType_/2; +} + +// private import Input3 +private module M3 = Make3; + +import M3 + module Consistency { import M2::Consistency @@ -428,7 +512,7 @@ private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, Typ } /** Gets the type annotation that applies to `n`, if any. */ -private TypeMention getTypeAnnotation(AstNode n) { +private TypeMention getTypeAnnotation_(AstNode n) { exists(LetStmt let | n = let.getPat() and result = let.getTypeRepr() @@ -440,16 +524,15 @@ private TypeMention getTypeAnnotation(AstNode n) { n = p.getPat() and result = p.getTypeRepr() ) -} - -/** Gets the type of `n`, which has an explicit type annotation. */ -pragma[nomagic] -private Type inferAnnotatedType(AstNode n, TypePath path) { - result = getTypeAnnotation(n).getTypeAt(path) or - result = n.(ShorthandSelfParameterMention).getTypeAt(path) + result = n.(ShorthandSelfParameterMention) } +// /** Gets the type of `n`, which has an explicit type annotation. */ +// pragma[nomagic] +// private Type inferAnnotatedType(AstNode n, TypePath path) { +// result = getTypeAnnotation(n).getTypeAt(path) +// } pragma[nomagic] private Type inferFunctionBodyType(AstNode n, TypePath path) { exists(Function f | @@ -508,7 +591,7 @@ private TypePath closureParameterPath(int arity, int index) { } /** Module for inferring certain type information. 
*/ -module CertainTypeInference { +module CertainTypeInference_ { pragma[nomagic] private predicate callResolvesTo(CallExpr ce, Path p, Function f) { p = CallExprImpl::getFunctionPath(ce) and @@ -572,31 +655,31 @@ module CertainTypeInference { result = sp.getPath().(TypeMention).getTypeAt(path) } - predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - exists(Variable v | n1 = v.getAnAccess() | - n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam) - ) - or - // A `let` statement with a type annotation is a coercion site and hence - // is not a certain type equality. - exists(LetStmt let | - not let.hasTypeRepr() and - identLetStmt(let, n1, n2) - ) - or - exists(LetExpr let | - // Similarly as for let statements, we need to rule out binding modes - // changing the type. - let.getPat().(IdentPat) = n1 and - let.getScrutinee() = n2 - ) - or - n1 = n2.(ParenExpr).getExpr() - ) - or + predicate certainTypeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // prefix1.isEmpty() and + // prefix2.isEmpty() and + // ( + // exists(Variable v | n1 = v.getAnAccess() | + // n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam) + // ) + // or + // // // A `let` statement with a type annotation is a coercion site and hence + // // // is not a certain type equality. + // // exists(LetStmt let | + // // not let.hasTypeRepr() and + // // identLetStmt(let, n1, n2) + // // ) + // // or + // // exists(LetExpr let | + // // // Similarly as for let statements, we need to rule out binding modes + // // // changing the type. 
+ // // let.getPat().(IdentPat) = n1 and + // // let.getScrutinee() = n2 + // // ) + // // or + // n1 = n2.(ParenExpr).getExpr() + // ) + // or n1 = any(IdentPat ip | n2 = ip.getName() and @@ -625,33 +708,31 @@ module CertainTypeInference { ) } - pragma[nomagic] - private Type inferCertainTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferCertainType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - certainTypeEquality(n, prefix1, n2, prefix2) - or - certainTypeEquality(n2, prefix2, n, prefix1) - ) - } - + // pragma[nomagic] + // private Type inferCertainTypeEquality(AstNode n, TypePath path) { + // exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + // result = inferCertainType(n2, prefix2.appendInverse(suffix)) and + // path = prefix1.append(suffix) + // | + // certainTypeEquality(n, prefix1, n2, prefix2) + // or + // certainTypeEquality(n2, prefix2, n, prefix1) + // ) + // } /** * Holds if `n` has complete and certain type information and if `n` has the * resulting type at `path`. */ - cached - Type inferCertainType(AstNode n, TypePath path) { - result = inferAnnotatedType(n, path) and - Stages::TypeInferenceStage::ref() - or + Type inferCertainType_(AstNode n, TypePath path) { + // result = inferAnnotatedType(n, path) and + // Stages::TypeInferenceStage::ref() + // or result = inferFunctionBodyType(n, path) or result = inferCertainCallExprType(n, path) or - result = inferCertainTypeEquality(n, path) - or + // result = inferCertainTypeEquality(n, path) + // or result = inferLiteralType(n, path, true) or result = inferRefPatType(n) and @@ -690,51 +771,48 @@ module CertainTypeInference { n instanceof ClosureExpr and path.isEmpty() and result = closureRootType() - or - infersCertainTypeAt(n, path, result.getATypeParameter()) - } - - /** - * Holds if `n` has complete and certain type information at the type path - * `prefix.tp`. 
This entails that the type at `prefix` must be the type - * that declares `tp`. - */ - pragma[nomagic] - private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { - exists(TypePath path | - exists(inferCertainType(n, path)) and - path.isSnoc(prefix, tp) - ) - } - - /** - * Holds if `n` has complete and certain type information at `path`. - */ - pragma[nomagic] - predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } - - /** - * Holds if `n` having type `t` at `path` conflicts with certain type information - * at `prefix`. - */ - bindingset[n, prefix, path, t] - pragma[inline_late] - predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { - inferCertainType(n, path) != t - or - // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also - // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, - // then it must be the case that `T(i+1)` is a type parameter of `certainType`, - // otherwise there is a conflict. - // - // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. - exists(TypePath suffix, TypeParameter tp, Type certainType | - path = prefix.appendInverse(suffix) and - tp = suffix.getHead() and - inferCertainType(n, prefix) = certainType and - not certainType.getATypeParameter() = tp - ) - } + // or + // infersCertainTypeAt(n, path, result.getATypeParameter()) + } + // /** + // * Holds if `n` has complete and certain type information at the type path + // * `prefix.tp`. This entails that the type at `prefix` must be the type + // * that declares `tp`. + // */ + // pragma[nomagic] + // private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { + // exists(TypePath path | + // exists(inferCertainType(n, path)) and + // path.isSnoc(prefix, tp) + // ) + // } + // /** + // * Holds if `n` has complete and certain type information at `path`. 
+ // */ + // pragma[nomagic] + // predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } + // /** + // * Holds if `n` having type `t` at `path` conflicts with certain type information + // * at `prefix`. + // */ + // bindingset[n, prefix, path, t] + // pragma[inline_late] + // predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { + // inferCertainType(n, path) != t + // or + // // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also + // // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, + // // then it must be the case that `T(i+1)` is a type parameter of `certainType`, + // // otherwise there is a conflict. + // // + // // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. + // exists(TypePath suffix, TypeParameter tp, Type certainType | + // path = prefix.appendInverse(suffix) and + // tp = suffix.getHead() and + // inferCertainType(n, prefix) = certainType and + // not certainType.getATypeParameter() = tp + // ) + // } } private Type inferLogicalOperationType(AstNode n, TypePath path) { @@ -785,28 +863,28 @@ private predicate bodyReturns(Expr body, Expr e) { * of `n2` at `prefix2` and type information should propagate in both directions * through the type equality. 
*/ -private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2) - or +private predicate typeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2) + // or prefix1.isEmpty() and prefix2.isEmpty() and ( - exists(LetStmt let | - let.getPat() = n1 and - let.getInitializer() = n2 - ) - or + // exists(LetStmt let | + // let.getPat() = n1 and + // let.getInitializer() = n2 + // ) + // or n2 = any(MatchExpr me | n1 = me.getAnArm().getExpr() and me.getNumberOfArms() = 1 ) or - exists(LetExpr let | - n1 = let.getScrutinee() and - n2 = let.getPat() - ) - or + // exists(LetExpr let | + // n1 = let.getScrutinee() and + // n2 = let.getPat() + // ) + // or exists(MatchExpr me | n1 = me.getScrutinee() and n2 = me.getAnArm().getPat() @@ -903,10 +981,10 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat * * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound */ -private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - child = parent.(IfExpr).getABranch() and - prefix.isEmpty() - or +private predicate lubCoercion_(AstNode parent, AstNode child, TypePath prefix) { + // child = parent.(IfExpr).getABranch() and + // prefix.isEmpty() + // or parent = any(MatchExpr me | child = me.getAnArm().getExpr() and @@ -953,19 +1031,19 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { * of `n2` at `prefix2`, but type information should only propagate from `n1` to * `n2`. 
*/ -private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - lubCoercion(n2, n1, prefix2) and - prefix1.isEmpty() - or - exists(AstNode mid, TypePath prefixMid, TypePath suffix | - typeEquality(n1, prefixMid, mid, prefix2) or - typeEquality(mid, prefix2, n1, prefixMid) - | - lubCoercion(mid, n2, suffix) and - not lubCoercion(mid, n1, _) and - prefix1 = prefixMid.append(suffix) - ) - or +private predicate typeEqualityAsymmetric_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // lubCoercion(n2, n1, prefix2) and + // prefix1.isEmpty() + // or + // exists(AstNode mid, TypePath prefixMid, TypePath suffix | + // typeEquality(n1, prefixMid, mid, prefix2) or + // typeEquality(mid, prefix2, n1, prefixMid) + // | + // lubCoercion(mid, n2, suffix) and + // not lubCoercion(mid, n1, _) and + // prefix1 = prefixMid.append(suffix) + // ) + // or // When `n2` is `*n1` propagate type information from a raw pointer type // parameter at `n1`. The other direction is handled in // `inferDereferencedExprPtrType`. 
@@ -974,20 +1052,19 @@ private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n prefix2.isEmpty() } -pragma[nomagic] -private Type inferTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - typeEquality(n, prefix1, n2, prefix2) - or - typeEquality(n2, prefix2, n, prefix1) - or - typeEqualityAsymmetric(n2, prefix2, n, prefix1) - ) -} - +// pragma[nomagic] +// private Type inferTypeEquality(AstNode n, TypePath path) { +// exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | +// result = inferType(n2, prefix2.appendInverse(suffix)) and +// path = prefix1.append(suffix) +// | +// typeEquality(n, prefix1, n2, prefix2) +// or +// typeEquality(n2, prefix2, n, prefix1) +// or +// typeEqualityAsymmetric(n2, prefix2, n, prefix1) +// ) +// } pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -3776,170 +3853,121 @@ private Type inferCastExprType(CastExpr ce, TypePath path) { result = ce.getTypeRepr().(TypeMention).getTypeAt(path) } +/** Holds if `n` is implicitly dereferenced and/or borrowed. */ cached -private module Cached { - /** Holds if `n` is implicitly dereferenced and/or borrowed. */ - cached - predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { - exists(BorrowKind bk | - any(AssocFunctionResolution::AssocFunctionCall afc) - .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and - if bk.isNoBorrow() then borrow = false else borrow = true - ) - or - e = - any(FieldExpr fe | - exists(resolveStructFieldExpr(fe, derefChain)) - or - exists(resolveTupleFieldExpr(fe, derefChain)) - ).getContainer() and - not derefChain.isEmpty() and - borrow = false - } - - /** - * Gets an item (function or tuple struct/variant) that `call` resolves to, if - * any. 
- * - * The parameter `dispatch` is `true` if and only if the resolved target is a - * trait item because a precise target could not be determined from the - * types (for instance in the presence of generics or `dyn` types) - */ - cached - Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { - dispatch = false and - result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() - or - exists(ImplOrTraitItemNode i | - i instanceof TraitItemNode and dispatch = true +predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { + CachedStage::ref() and + exists(BorrowKind bk | + any(AssocFunctionResolution::AssocFunctionCall afc) + .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and + if bk.isNoBorrow() then borrow = false else borrow = true + ) + or + e = + any(FieldExpr fe | + exists(resolveStructFieldExpr(fe, derefChain)) or - i instanceof ImplItemNode and dispatch = false - | - result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and - not call instanceof CallExprImpl::DynamicCallExpr and - not i instanceof Builtins::BuiltinImpl - ) - } - - /** - * Gets the struct field that the field expression `fe` resolves to, if any. - */ - cached - StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(string name, DataType ty | - ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) - | - result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or - result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) - ) - } - - /** - * Gets the tuple field that the field expression `fe` resolves to, if any. 
- */ - cached - TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(int i | - result = - getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) - .(StructType) - .getTypeItem() - .getTupleField(pragma[only_bind_into](i)) - ) - } + exists(resolveTupleFieldExpr(fe, derefChain)) + ).getContainer() and + not derefChain.isEmpty() and + borrow = false +} - /** - * Gets a type at `path` that `n` infers to, if any. - * - * The type inference implementation works by computing all possible types, so - * the result is not necessarily unique. For example, in - * - * ```rust - * trait MyTrait { - * fn foo(&self) -> &Self; - * - * fn bar(&self) -> &Self { - * self.foo() - * } - * } - * - * struct MyStruct; - * - * impl MyTrait for MyStruct { - * fn foo(&self) -> &MyStruct { - * self - * } - * } - * - * fn baz() { - * let x = MyStruct; - * x.bar(); - * } - * ``` - * - * the type inference engine will roughly make the following deductions: - * - * 1. `MyStruct` has type `MyStruct`. - * 2. `x` has type `MyStruct` (via 1.). - * 3. The return type of `bar` is `&Self`. - * 3. `x.bar()` has type `&MyStruct` (via 2 and 3, by matching the implicit `Self` - * type parameter with `MyStruct`.). - * 4. The return type of `bar` is `&MyTrait`. - * 5. `x.bar()` has type `&MyTrait` (via 2 and 4). - */ - cached - Type inferType(AstNode n, TypePath path) { - Stages::TypeInferenceStage::ref() and - result = CertainTypeInference::inferCertainType(n, path) +/** + * Gets an item (function or tuple struct/variant) that `call` resolves to, if + * any. 
+ * + * The parameter `dispatch` is `true` if and only if the resolved target is a + * trait item because a precise target could not be determined from the + * types (for instance in the presence of generics or `dyn` types) + */ +cached +Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { + dispatch = false and + result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() + or + exists(ImplOrTraitItemNode i | + i instanceof TraitItemNode and dispatch = true or - // Don't propagate type information into a node which conflicts with certain - // type information. - forall(TypePath prefix | - CertainTypeInference::hasInferredCertainType(n, prefix) and - prefix.isPrefixOf(path) - | - not CertainTypeInference::certainTypeConflict(n, prefix, path, result) - ) and - ( - result = inferAssignmentOperationType(n, path) - or - result = inferTypeEquality(n, path) - or - result = inferFunctionCallType(n, path) - or - result = inferConstructionType(n, path) - or - result = inferOperationType(n, path) - or - result = inferFieldExprType(n, path) - or - result = inferTryExprType(n, path) - or - result = inferLiteralType(n, path, false) - or - result = inferAwaitExprType(n, path) - or - result = inferDereferencedExprPtrType(n, path) - or - result = inferForLoopExprType(n, path) - or - result = inferClosureExprType(n, path) - or - result = inferArgList(n, path) - or - result = inferDeconstructionPatType(n, path) - or - result = inferUnknownTypeFromAnnotation(n, path) - ) - } + i instanceof ImplItemNode and dispatch = false + | + result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and + not call instanceof CallExprImpl::DynamicCallExpr and + not i instanceof Builtins::BuiltinImpl + ) } -import Cached +/** + * Gets the struct field that the field expression `fe` resolves to, if any. 
+ */ +cached +StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(string name, DataType ty | + ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) + | + result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or + result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) + ) +} /** - * Gets a type that `n` infers to, if any. + * Gets the tuple field that the field expression `fe` resolves to, if any. */ -Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } +cached +TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(int i | + result = + getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) + .(StructType) + .getTypeItem() + .getTupleField(pragma[only_bind_into](i)) + ) +} + +private Type inferType_(AstNode n, TypePath path) { + // // Stages::TypeInferenceStage::ref() and + // // result = CertainTypeInference::inferCertainType(n, path) + // // or + // // Don't propagate type information into a node which conflicts with certain + // // type information. 
+ // forall(TypePath prefix | + // CertainTypeInference::hasInferredCertainType(n, prefix) and + // prefix.isPrefixOf(path) + // | + // not CertainTypeInference::certainTypeConflict(n, prefix, path, result) + // ) and + // ( + result = inferAssignmentOperationType(n, path) + or + // result = inferTypeEquality(n, path) + // or + result = inferFunctionCallType(n, path) + or + result = inferConstructionType(n, path) + or + result = inferOperationType(n, path) + or + result = inferFieldExprType(n, path) + or + result = inferTryExprType(n, path) + or + result = inferLiteralType(n, path, false) + or + result = inferAwaitExprType(n, path) + or + result = inferDereferencedExprPtrType(n, path) + or + result = inferForLoopExprType(n, path) + or + result = inferClosureExprType(n, path) + or + result = inferArgList(n, path) + or + result = inferDeconstructionPatType(n, path) + or + result = inferUnknownTypeFromAnnotation(n, path) + // ) +} /** Provides predicates for debugging the type inference implementation. */ private module Debug { @@ -3990,7 +4018,7 @@ private module Debug { Type debugInferAnnotatedType(AstNode n, TypePath path) { n = getRelevantLocatable() and - result = inferAnnotatedType(n, path) + result = CertainTypeInference::inferAnnotatedType(n, path) } pragma[nomagic] diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index cf82d77b5e1d..18998d3d974f 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -146,17 +146,25 @@ signature module InputSig1 { } /** - * Holds if `t` is a pseudo type. Pseudo types are skipped when checking for - * non-instantiations in `isNotInstantiationOf`. + * A special pseudo type used to represent cases where the actual type needs + * to be inferred from the context. 
For example, in + * + * ```rust + * let x = Vec::new(); + * x.push(42); + * ``` + * + * the element type of `x` is assigned an unknown type, which allows for type + * information to flow into `x` from the call to `push`. */ - predicate isPseudoType(Type t); + class UnknownType extends Type; /** A type parameter. */ class TypeParameter extends Type; /** - * A type abstraction. I.e., a place in the program where type variables are - * introduced. + * A type abstraction. I.e., a place in the program where type variables may + * be introduced. * * Example in C#: * ```csharp @@ -171,7 +179,7 @@ signature module InputSig1 { * ``` */ class TypeAbstraction { - /** Gets a type parameter introduced by this abstraction. */ + /** Gets a type parameter introduced by this abstraction, if any. */ TypeParameter getATypeParameter(); /** Gets a textual representation of this type abstraction. */ @@ -332,6 +340,8 @@ module Make1 Input1> { * code. For example, in * * ```csharp + * class Base { } + * * class C : Base, Interface { } * ``` * @@ -341,7 +351,7 @@ module Make1 Input1> { * `TypePath` | `Type` * ---------- | ------- * `""` | ``Base`1`` - * `"0"` | `T` + * `"B"` | `T` */ signature module InputSig2 { /** @@ -666,7 +676,8 @@ module Make1 Input1> { } private Type getNonPseudoTypeAt(App app, TypePath path) { - result = app.getTypeAt(path) and not isPseudoType(result) + result = app.getTypeAt(path) and + not result instanceof UnknownType } pragma[nomagic] @@ -2127,5 +2138,313 @@ module Make1 Input1> { not exists(tm.getTypeAt(TypePath::nil())) and exists(tm.getLocation()) } } + + /** + * Provides the input to `Make3`. + */ + signature module InputSig3 { + /** + * Reference to the cached stage of type inference. Should be instantiated + * with `CachedStage::ref()`. + */ + predicate cachedStageRef(); + + /** + * Reference to the cached stage of type inference. Should be instantiated + * with `CachedStage::ref()`. 
+ */ + default predicate cachedStageRevRef() { none() } + + /** An AST node. */ + class AstNode { + /** Gets a textual representation of this AST node. */ + string toString(); + + /** Gets the location of this AST node. */ + Location getLocation(); + } + + /** Gets the type annotation that applies to `n`, if any. */ + TypeMention getTypeAnnotation(AstNode n); + + /** A variable, for example a local variable or a field. */ + class Variable { + AstNode getDefiningNode(); + + AstNode getAnAccess(); + + /** Gets a textual representation of this element. */ + string toString(); + + /** Gets the location of this element. */ + Location getLocation(); + } + + /** + * An assignment where type information can flow from one operand to the + * other. + */ + class Assignment extends AstNode { + /** + * Holds if this assignment is a coercion site, meaning that the type of the right + * operand may have to be coerced to the type of the left operand. + */ + predicate isCoercionSite(); + + /** Gets the left operand of this assignment. */ + AstNode getLeftOperand(); + + /** Gets the right operand of this assignment. */ + AstNode getRightOperand(); + } + + /** A parenthesized expression. */ + class ParenExpr extends AstNode { + AstNode getExpr(); + } + + /** A ternary conditional expression. */ + class ConditionalExpr extends AstNode { + /** Gets the condition of this expression. */ + AstNode getCondition(); + + /** Gets the true branch of this expression. */ + AstNode getThen(); + + /** Gets the false branch of this expression. */ + AstNode getElse(); + } + + /** + * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. + */ + predicate certainTypeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** Gets the inferred certain type of `n` at `path`. 
*/ + Type inferCertainTypeInput(AstNode n, TypePath path); + + /** + * Holds if `child` is a child of `parent`, and a least upper bound (LUB) coercion + * may be applied to infer the type of `parent` from the type of `child`. + * + * In this case, we want type information to only flow from `child` to `parent`, + * to avoid either (a) having to model LUB coercions, or (b) risking combinatorial + * explosion in inferred types. + */ + predicate lubCoercionInput(AstNode parent, AstNode child, TypePath prefix); + + /** + * Holds if the type tree of `n1` at `path1` should be equal to the type tree + * of `n2` at `path2`, but type information should only propagate from `n1` to + * `n2`. + */ + predicate typeEqualityAsymmetricInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** + * Holds if the types of `n1` at `path1` and `n2` at `path2` are possibly equal. + */ + predicate typeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** Gets the inferred type of `n` at `path`. */ + Type inferTypeInput(AstNode n, TypePath path); + } + + module Make3 { + private import Input3 + + /** Provides logic for inferring certain type information. */ + module CertainTypeInference { + /** Gets the type of `n`, which has an explicit type annotation. 
*/ + pragma[nomagic] + Type inferAnnotatedType(AstNode n, TypePath path) { + result = getTypeAnnotation(n).getTypeAt(path) + } + + predicate certainTypeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + path2.isEmpty() and + ( + exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode()) + or + exists(Assignment a | + not a.isCoercionSite() and + n1 = a.getLeftOperand() and + n2 = a.getRightOperand() + ) + or + n1 = n2.(ParenExpr).getExpr() + ) + or + certainTypeEqualityInput(n1, path1, n2, path2) + } + + pragma[nomagic] + private Type inferCertainTypeEquality(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferCertainType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) + | + certainTypeEquality(n, prefix1, n2, prefix2) + or + certainTypeEquality(n2, prefix2, n, prefix1) + ) + } + + /** Gets the inferred certain type of `n` at `path`. */ + cached + Type inferCertainType(AstNode n, TypePath path) { + result = inferAnnotatedType(n, path) and + cachedStageRef() + or + result = inferCertainTypeEquality(n, path) + or + result = inferCertainTypeInput(n, path) + or + infersCertainTypeAt(n, path, result.getATypeParameter()) + } + + /** + * Holds if `n` has complete and certain type information at the type path + * `prefix.tp`. This entails that the type at `prefix` must be the type + * that declares `tp`. + */ + pragma[nomagic] + private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { + exists(TypePath path | + exists(inferCertainType(n, path)) and + path.isSnoc(prefix, tp) + ) + } + + /** + * Holds if `n` has complete and certain type information at `path`. + */ + pragma[nomagic] + predicate hasInferredCertainType(AstNode n, TypePath path) { + exists(inferCertainType(n, path)) + } + + /** + * Holds if `n` having type `t` at `path` conflicts with certain type information + * at `prefix`. 
+ */ + bindingset[n, prefix, path, t] + pragma[inline_late] + predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { + inferCertainType(n, path) != t + or + // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also + // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, + // then it must be the case that `T(i+1)` is a type parameter of `certainType`, + // otherwise there is a conflict. + // + // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. + exists(TypePath suffix, TypeParameter tp, Type certainType | + path = prefix.appendInverse(suffix) and + tp = suffix.getHead() and + inferCertainType(n, prefix) = certainType and + not certainType.getATypeParameter() = tp + ) + } + } + + private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + CertainTypeInference::certainTypeEquality(n1, path1, n2, path2) + or + path1.isEmpty() and + path2.isEmpty() and + exists(Assignment a | + a.getLeftOperand() = n1 and + a.getRightOperand() = n2 + ) + or + typeEqualityInput(n1, path1, n2, path2) + } + + private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { + parent = any(ConditionalExpr ce | child = [ce.getThen(), ce.getElse()]) and + prefix.isEmpty() + or + lubCoercionInput(parent, child, prefix) + } + + private predicate typeEqualityAsymmetric( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { + lubCoercion(n2, n1, path2) and + path1.isEmpty() + or + exists(AstNode mid, TypePath pathMid, TypePath suffix | + typeEquality(n1, pathMid, mid, path2) or + typeEquality(mid, path2, n1, pathMid) + | + lubCoercion(mid, n2, suffix) and + not lubCoercion(mid, n1, _) and + path1 = pathMid.append(suffix) + ) + or + typeEqualityAsymmetricInput(n1, path1, n2, path2) + } + + pragma[nomagic] + private Type inferTypeEquality(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, 
prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) + | + typeEquality(n, prefix1, n2, prefix2) + or + typeEquality(n2, prefix2, n, prefix1) + or + typeEqualityAsymmetric(n2, prefix2, n, prefix1) + ) + } + + /** + * Gets the inferred type of `n` at `path`. + */ + cached + Type inferType(AstNode n, TypePath path) { + cachedStageRef() and + result = CertainTypeInference::inferCertainType(n, path) + or + // Don't propagate type information into a node which conflicts with certain + // type information. + forall(TypePath prefix | + CertainTypeInference::hasInferredCertainType(n, prefix) and + prefix.isPrefixOf(path) + | + not CertainTypeInference::certainTypeConflict(n, prefix, path, result) + ) and + ( + result = inferTypeEquality(n, path) + or + result = inferTypeInput(n, path) + ) + } + + /** + * Gets the inferred root type of `n`, if any. + */ + Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } + + /** The cached stage of type inference. */ + cached + module CachedStage { + /** Reference to the cached stage of type inference. */ + cached + predicate ref() { any() } + + /** Reverse references to the predicates that reference `ref()`. 
*/ + cached + predicate revRef() { + (exists(CertainTypeInference::inferCertainType(_, _)) implies any()) + or + (exists(inferType(_, _)) implies any()) + or + cachedStageRevRef() + } + } + } } } From c067d4782e3dd49df672f1cd48a7c6eda2a7733c Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Tue, 5 May 2026 13:27:39 +0200 Subject: [PATCH 02/12] wip --- .../internal/typeinference/TypeInference.qll | 231 ++++-------------- .../typeinference/internal/TypeInference.qll | 2 +- 2 files changed, 44 insertions(+), 189 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 7ecf2276f7b9..e1cbbda6f4f4 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -290,9 +290,22 @@ private module Input3 implements InputSig3 { class AstNode = Rust::AstNode; - predicate getTypeAnnotation = getTypeAnnotation_/1; + TypeMention getTypeAnnotation(AstNode n) { + exists(LetStmt let | + n = let.getPat() and + result = let.getTypeRepr() + ) + or + result = n.(SelfParam).getTypeRepr() + or + exists(Param p | + n = p.getPat() and + result = p.getTypeRepr() + ) + or + result = n.(ShorthandSelfParameterMention) + } - /** A variable, for example a local variable or a field. */ class Variable extends Rust::Variable { AstNode getDefiningNode() { result = this.getPat().getName() or @@ -511,28 +524,6 @@ private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, Typ ) } -/** Gets the type annotation that applies to `n`, if any. 
*/ -private TypeMention getTypeAnnotation_(AstNode n) { - exists(LetStmt let | - n = let.getPat() and - result = let.getTypeRepr() - ) - or - result = n.(SelfParam).getTypeRepr() - or - exists(Param p | - n = p.getPat() and - result = p.getTypeRepr() - ) - or - result = n.(ShorthandSelfParameterMention) -} - -// /** Gets the type of `n`, which has an explicit type annotation. */ -// pragma[nomagic] -// private Type inferAnnotatedType(AstNode n, TypePath path) { -// result = getTypeAnnotation(n).getTypeAt(path) -// } pragma[nomagic] private Type inferFunctionBodyType(AstNode n, TypePath path) { exists(Function f | @@ -656,30 +647,6 @@ module CertainTypeInference_ { } predicate certainTypeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - // prefix1.isEmpty() and - // prefix2.isEmpty() and - // ( - // exists(Variable v | n1 = v.getAnAccess() | - // n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam) - // ) - // or - // // // A `let` statement with a type annotation is a coercion site and hence - // // // is not a certain type equality. - // // exists(LetStmt let | - // // not let.hasTypeRepr() and - // // identLetStmt(let, n1, n2) - // // ) - // // or - // // exists(LetExpr let | - // // // Similarly as for let statements, we need to rule out binding modes - // // // changing the type. 
- // // let.getPat().(IdentPat) = n1 and - // // let.getScrutinee() = n2 - // // ) - // // or - // n1 = n2.(ParenExpr).getExpr() - // ) - // or n1 = any(IdentPat ip | n2 = ip.getName() and @@ -708,31 +675,15 @@ module CertainTypeInference_ { ) } - // pragma[nomagic] - // private Type inferCertainTypeEquality(AstNode n, TypePath path) { - // exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - // result = inferCertainType(n2, prefix2.appendInverse(suffix)) and - // path = prefix1.append(suffix) - // | - // certainTypeEquality(n, prefix1, n2, prefix2) - // or - // certainTypeEquality(n2, prefix2, n, prefix1) - // ) - // } /** * Holds if `n` has complete and certain type information and if `n` has the * resulting type at `path`. */ Type inferCertainType_(AstNode n, TypePath path) { - // result = inferAnnotatedType(n, path) and - // Stages::TypeInferenceStage::ref() - // or result = inferFunctionBodyType(n, path) or result = inferCertainCallExprType(n, path) or - // result = inferCertainTypeEquality(n, path) - // or result = inferLiteralType(n, path, true) or result = inferRefPatType(n) and @@ -771,48 +722,7 @@ module CertainTypeInference_ { n instanceof ClosureExpr and path.isEmpty() and result = closureRootType() - // or - // infersCertainTypeAt(n, path, result.getATypeParameter()) - } - // /** - // * Holds if `n` has complete and certain type information at the type path - // * `prefix.tp`. This entails that the type at `prefix` must be the type - // * that declares `tp`. - // */ - // pragma[nomagic] - // private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { - // exists(TypePath path | - // exists(inferCertainType(n, path)) and - // path.isSnoc(prefix, tp) - // ) - // } - // /** - // * Holds if `n` has complete and certain type information at `path`. 
- // */ - // pragma[nomagic] - // predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } - // /** - // * Holds if `n` having type `t` at `path` conflicts with certain type information - // * at `prefix`. - // */ - // bindingset[n, prefix, path, t] - // pragma[inline_late] - // predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { - // inferCertainType(n, path) != t - // or - // // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also - // // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, - // // then it must be the case that `T(i+1)` is a type parameter of `certainType`, - // // otherwise there is a conflict. - // // - // // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. - // exists(TypePath suffix, TypeParameter tp, Type certainType | - // path = prefix.appendInverse(suffix) and - // tp = suffix.getHead() and - // inferCertainType(n, prefix) = certainType and - // not certainType.getATypeParameter() = tp - // ) - // } + } } private Type inferLogicalOperationType(AstNode n, TypePath path) { @@ -864,27 +774,15 @@ private predicate bodyReturns(Expr body, Expr e) { * through the type equality. 
*/ private predicate typeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - // CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2) - // or prefix1.isEmpty() and prefix2.isEmpty() and ( - // exists(LetStmt let | - // let.getPat() = n1 and - // let.getInitializer() = n2 - // ) - // or n2 = any(MatchExpr me | n1 = me.getAnArm().getExpr() and me.getNumberOfArms() = 1 ) or - // exists(LetExpr let | - // n1 = let.getScrutinee() and - // n2 = let.getPat() - // ) - // or exists(MatchExpr me | n1 = me.getScrutinee() and n2 = me.getAnArm().getPat() @@ -982,9 +880,6 @@ private predicate typeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePa * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound */ private predicate lubCoercion_(AstNode parent, AstNode child, TypePath prefix) { - // child = parent.(IfExpr).getABranch() and - // prefix.isEmpty() - // or parent = any(MatchExpr me | child = me.getAnArm().getExpr() and @@ -1032,18 +927,6 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { * `n2`. */ private predicate typeEqualityAsymmetric_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - // lubCoercion(n2, n1, prefix2) and - // prefix1.isEmpty() - // or - // exists(AstNode mid, TypePath prefixMid, TypePath suffix | - // typeEquality(n1, prefixMid, mid, prefix2) or - // typeEquality(mid, prefix2, n1, prefixMid) - // | - // lubCoercion(mid, n2, suffix) and - // not lubCoercion(mid, n1, _) and - // prefix1 = prefixMid.append(suffix) - // ) - // or // When `n2` is `*n1` propagate type information from a raw pointer type // parameter at `n1`. The other direction is handled in // `inferDereferencedExprPtrType`. 
@@ -1052,19 +935,6 @@ private predicate typeEqualityAsymmetric_(AstNode n1, TypePath prefix1, AstNode prefix2.isEmpty() } -// pragma[nomagic] -// private Type inferTypeEquality(AstNode n, TypePath path) { -// exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | -// result = inferType(n2, prefix2.appendInverse(suffix)) and -// path = prefix1.append(suffix) -// | -// typeEquality(n, prefix1, n2, prefix2) -// or -// typeEquality(n2, prefix2, n, prefix1) -// or -// typeEqualityAsymmetric(n2, prefix2, n, prefix1) -// ) -// } pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -3362,6 +3232,19 @@ private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC ) } +/** + * Gets the struct field that the field expression `fe` resolves to, if any. + */ +cached +StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(string name, DataType ty | + ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) + | + result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or + result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) + ) +} + pragma[nomagic] private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain derefChain) { exists(string name | @@ -3370,6 +3253,20 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref ) } +/** + * Gets the tuple field that the field expression `fe` resolves to, if any. + */ +cached +TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { + exists(int i | + result = + getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) + .(StructType) + .getTypeItem() + .getTupleField(pragma[only_bind_into](i)) + ) +} + /** * A matching configuration for resolving types of field expressions like `x.field`. 
*/ @@ -3897,50 +3794,9 @@ Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { ) } -/** - * Gets the struct field that the field expression `fe` resolves to, if any. - */ -cached -StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(string name, DataType ty | - ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) - | - result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or - result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) - ) -} - -/** - * Gets the tuple field that the field expression `fe` resolves to, if any. - */ -cached -TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(int i | - result = - getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) - .(StructType) - .getTypeItem() - .getTupleField(pragma[only_bind_into](i)) - ) -} - private Type inferType_(AstNode n, TypePath path) { - // // Stages::TypeInferenceStage::ref() and - // // result = CertainTypeInference::inferCertainType(n, path) - // // or - // // Don't propagate type information into a node which conflicts with certain - // // type information. - // forall(TypePath prefix | - // CertainTypeInference::hasInferredCertainType(n, prefix) and - // prefix.isPrefixOf(path) - // | - // not CertainTypeInference::certainTypeConflict(n, prefix, path, result) - // ) and - // ( result = inferAssignmentOperationType(n, path) or - // result = inferTypeEquality(n, path) - // or result = inferFunctionCallType(n, path) or result = inferConstructionType(n, path) @@ -3966,7 +3822,6 @@ private Type inferType_(AstNode n, TypePath path) { result = inferDeconstructionPatType(n, path) or result = inferUnknownTypeFromAnnotation(n, path) - // ) } /** Provides predicates for debugging the type inference implementation. 
*/ diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 18998d3d974f..b544e8495483 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2235,7 +2235,7 @@ module Make1 Input1> { /** * Holds if the type tree of `n1` at `path1` should be equal to the type tree - * of `n2` at `prefix2`, but type information should only propagate from `n1` to + * of `n2` at `path2`, but type information should only propagate from `n1` to * `n2`. */ predicate typeEqualityAsymmetricInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); From 3629bb751213bea43f548671ef42a70d1a4e48ef Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Tue, 5 May 2026 13:45:25 +0200 Subject: [PATCH 03/12] wip2 --- .../internal/typeinference/TypeInference.qll | 428 +++++++++--------- .../typeinference/internal/TypeInference.qll | 18 +- 2 files changed, 211 insertions(+), 235 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index e1cbbda6f4f4..269e4835803e 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -276,8 +276,6 @@ import M2 private module Input3 implements InputSig3 { private import rust as Rust - predicate cachedStageRef = CachedStage::ref/0; - predicate cachedStageRevRef() { (implicitDerefChainBorrow(_, _, _) implies any()) or @@ -342,6 +340,14 @@ private module Input3 implements InputSig3 { override AstNode getRightOperand() { result = this.getInitializer() } } + private class AssignmentExprAssignment extends Assignment, AssignmentExpr { + override predicate isCoercionSite() { any() } + + override AstNode getLeftOperand() { result = this.getLhs() } + + override AstNode 
getRightOperand() { result = this.getRhs() } + } + class ParenExpr extends AstNode, Rust::ParenExpr { AstNode getExpr() { result = super.getExpr() } } @@ -355,20 +361,207 @@ private module Input3 implements InputSig3 { AstNode getElse() { result = super.getElse() } } - predicate certainTypeEqualityInput = CertainTypeInference_::certainTypeEquality_/4; + predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + n1 = + any(IdentPat ip | + n2 = ip.getName() and + prefix1.isEmpty() and + if ip.isRef() + then + exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + else prefix2.isEmpty() + ) + or + exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | + n1 = dce.getArgList() and + tt.getArity() = dce.getNumberOfSyntacticArguments() and + n2 = dce.getSyntacticPositionalArgument(i) and + prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and + prefix2.isEmpty() + ) + or + exists(ClosureExpr ce, int index | + n1 = ce and + n2 = ce.getParam(index).getPat() and + prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and + prefix2.isEmpty() + ) + } - predicate inferCertainTypeInput = CertainTypeInference_::inferCertainType_/2; + predicate inferCertainTypeInput = CertainTypeInferenceInput::inferCertainTypeInput/2; - predicate lubCoercionInput = lubCoercion_/3; + /** + * Holds if `child` is a child of `parent`, and the Rust compiler applies [least + * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of + * `child`. + * + * In this case, we want type information to only flow from `child` to `parent`, + * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial + * explosion in inferred types. 
+ * + * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound + */ + predicate lubCoercionInput(AstNode parent, AstNode child, TypePath prefix) { + parent = + any(MatchExpr me | + child = me.getAnArm().getExpr() and + me.getNumberOfArms() > 1 + ) and + prefix.isEmpty() + or + parent = + any(ArrayListExpr ale | + child = ale.getAnExpr() and + ale.getNumberOfExprs() > 1 + ) and + prefix = TypePath::singleton(getArrayTypeParameter()) + or + bodyReturns(parent, child) and + strictcount(Expr e | bodyReturns(parent, e)) > 1 and + prefix.isEmpty() + or + parent = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = child) and + prefix = closureReturnPath() + or + exists(Struct s | + child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and + prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and + s = getRangeType(parent) + ) + } - predicate typeEqualityAsymmetricInput = typeEqualityAsymmetric_/4; + predicate typeEqualityAsymmetricInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // When `n2` is `*n1` propagate type information from a raw pointer type + // parameter at `n1`. The other direction is handled in + // `inferDereferencedExprPtrType`. 
+ n1 = n2.(DerefExpr).getExpr() and + prefix1 = TypePath::singleton(getPtrTypeParameter()) and + prefix2.isEmpty() + } - predicate typeEqualityInput = typeEquality_/4; + predicate typeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + prefix1.isEmpty() and + prefix2.isEmpty() and + ( + n2 = + any(MatchExpr me | + n1 = me.getAnArm().getExpr() and + me.getNumberOfArms() = 1 + ) + or + exists(MatchExpr me | + n1 = me.getScrutinee() and + n2 = me.getAnArm().getPat() + ) + or + n1 = n2.(OrPat).getAPat() + or + n1 = n2.(ParenPat).getPat() + or + n1 = n2.(LiteralPat).getLiteral() + or + exists(BreakExpr break | + break.getExpr() = n1 and + break.getTarget() = n2.(LoopExpr) + ) + or + n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and + not isPanicMacroCall(n2) + or + n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() + or + bodyReturns(n1, n2) and + strictcount(Expr e | bodyReturns(n1, e)) = 1 + ) + or + n2 = + any(RefExpr re | + n1 = re.getExpr() and + prefix1.isEmpty() and + prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) + ) + or + n2 = + any(RefPat rp | + n1 = rp.getPat() and + prefix1.isEmpty() and + exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + ) + or + exists(int i, int arity | + prefix1.isEmpty() and + prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) + | + arity = n2.(TupleExpr).getNumberOfFields() and + n1 = n2.(TupleExpr).getField(i) + or + arity = n2.(TuplePat).getTupleArity() and + n1 = n2.(TuplePat).getField(i) + ) + or + exists(BlockExpr be | + n1 = be and + n2 = be.getStmtList().getTailExpr() and + if be.isAsync() + then + prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and + prefix2.isEmpty() + else ( + prefix1.isEmpty() and + prefix2.isEmpty() + ) + ) + or + // an array list expression with only one element (such as `[1]`) has type from 
that element + n1 = + any(ArrayListExpr ale | + ale.getAnExpr() = n2 and + ale.getNumberOfExprs() = 1 + ) and + prefix1 = TypePath::singleton(getArrayTypeParameter()) and + prefix2.isEmpty() + or + // an array repeat expression (`[1; 3]`) has the type of the repeat operand + n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and + prefix1 = TypePath::singleton(getArrayTypeParameter()) and + prefix2.isEmpty() + } - predicate inferTypeInput = inferType_/2; + Type inferTypeInput(AstNode n, TypePath path) { + result = inferAssignmentOperationType(n, path) + or + result = inferFunctionCallType(n, path) + or + result = inferConstructionType(n, path) + or + result = inferOperationType(n, path) + or + result = inferFieldExprType(n, path) + or + result = inferTryExprType(n, path) + or + result = inferLiteralType(n, path, false) + or + result = inferAwaitExprType(n, path) + or + result = inferDereferencedExprPtrType(n, path) + or + result = inferForLoopExprType(n, path) + or + result = inferClosureExprType(n, path) + or + result = inferArgList(n, path) + or + result = inferDeconstructionPatType(n, path) + or + result = inferUnknownTypeFromAnnotation(n, path) + } } -// private import Input3 private module M3 = Make3; import M3 @@ -582,7 +775,7 @@ private TypePath closureParameterPath(int arity, int index) { } /** Module for inferring certain type information. 
*/ -module CertainTypeInference_ { +private module CertainTypeInferenceInput { pragma[nomagic] private predicate callResolvesTo(CallExpr ce, Path p, Function f) { p = CallExprImpl::getFunctionPath(ce) and @@ -646,40 +839,11 @@ module CertainTypeInference_ { result = sp.getPath().(TypeMention).getTypeAt(path) } - predicate certainTypeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - n1 = - any(IdentPat ip | - n2 = ip.getName() and - prefix1.isEmpty() and - if ip.isRef() - then - exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - else prefix2.isEmpty() - ) - or - exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | - n1 = dce.getArgList() and - tt.getArity() = dce.getNumberOfSyntacticArguments() and - n2 = dce.getSyntacticPositionalArgument(i) and - prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and - prefix2.isEmpty() - ) - or - exists(ClosureExpr ce, int index | - n1 = ce and - n2 = ce.getParam(index).getPat() and - prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and - prefix2.isEmpty() - ) - } - /** * Holds if `n` has complete and certain type information and if `n` has the * resulting type at `path`. */ - Type inferCertainType_(AstNode n, TypePath path) { + Type inferCertainTypeInput(AstNode n, TypePath path) { result = inferFunctionBodyType(n, path) or result = inferCertainCallExprType(n, path) @@ -768,146 +932,6 @@ private predicate bodyReturns(Expr body, Expr e) { ) } -/** - * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree - * of `n2` at `prefix2` and type information should propagate in both directions - * through the type equality. 
- */ -private predicate typeEquality_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - n2 = - any(MatchExpr me | - n1 = me.getAnArm().getExpr() and - me.getNumberOfArms() = 1 - ) - or - exists(MatchExpr me | - n1 = me.getScrutinee() and - n2 = me.getAnArm().getPat() - ) - or - n1 = n2.(OrPat).getAPat() - or - n1 = n2.(ParenPat).getPat() - or - n1 = n2.(LiteralPat).getLiteral() - or - exists(BreakExpr break | - break.getExpr() = n1 and - break.getTarget() = n2.(LoopExpr) - ) - or - exists(AssignmentExpr be | - n1 = be.getLhs() and - n2 = be.getRhs() - ) - or - n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and - not isPanicMacroCall(n2) - or - n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() - or - bodyReturns(n1, n2) and - strictcount(Expr e | bodyReturns(n1, e)) = 1 - ) - or - n2 = - any(RefExpr re | - n1 = re.getExpr() and - prefix1.isEmpty() and - prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) - ) - or - n2 = - any(RefPat rp | - n1 = rp.getPat() and - prefix1.isEmpty() and - exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - ) - or - exists(int i, int arity | - prefix1.isEmpty() and - prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) - | - arity = n2.(TupleExpr).getNumberOfFields() and - n1 = n2.(TupleExpr).getField(i) - or - arity = n2.(TuplePat).getTupleArity() and - n1 = n2.(TuplePat).getField(i) - ) - or - exists(BlockExpr be | - n1 = be and - n2 = be.getStmtList().getTailExpr() and - if be.isAsync() - then - prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and - prefix2.isEmpty() - else ( - prefix1.isEmpty() and - prefix2.isEmpty() - ) - ) - or - // an array list expression with only one element (such as `[1]`) has type from that element - n1 = - any(ArrayListExpr ale | - ale.getAnExpr() = n2 and - 
ale.getNumberOfExprs() = 1 - ) and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() - or - // an array repeat expression (`[1; 3]`) has the type of the repeat operand - n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() -} - -/** - * Holds if `child` is a child of `parent`, and the Rust compiler applies [least - * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of - * `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. - * - * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound - */ -private predicate lubCoercion_(AstNode parent, AstNode child, TypePath prefix) { - parent = - any(MatchExpr me | - child = me.getAnArm().getExpr() and - me.getNumberOfArms() > 1 - ) and - prefix.isEmpty() - or - parent = - any(ArrayListExpr ale | - child = ale.getAnExpr() and - ale.getNumberOfExprs() > 1 - ) and - prefix = TypePath::singleton(getArrayTypeParameter()) - or - bodyReturns(parent, child) and - strictcount(Expr e | bodyReturns(parent, e)) > 1 and - prefix.isEmpty() - or - parent = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = child) and - prefix = closureReturnPath() - or - exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) - ) -} - private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { inferType(n, path) = TUnknownType() and // Normally, these are coercion sites, but in case a type is unknown we @@ -921,20 +945,6 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { ) } -/** - * Holds if the type tree of `n1` at `prefix1` 
should be equal to the type tree - * of `n2` at `prefix2`, but type information should only propagate from `n1` to - * `n2`. - */ -private predicate typeEqualityAsymmetric_(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - // When `n2` is `*n1` propagate type information from a raw pointer type - // parameter at `n1`. The other direction is handled in - // `inferDereferencedExprPtrType`. - n1 = n2.(DerefExpr).getExpr() and - prefix1 = TypePath::singleton(getPtrTypeParameter()) and - prefix2.isEmpty() -} - pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -3794,36 +3804,6 @@ Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { ) } -private Type inferType_(AstNode n, TypePath path) { - result = inferAssignmentOperationType(n, path) - or - result = inferFunctionCallType(n, path) - or - result = inferConstructionType(n, path) - or - result = inferOperationType(n, path) - or - result = inferFieldExprType(n, path) - or - result = inferTryExprType(n, path) - or - result = inferLiteralType(n, path, false) - or - result = inferAwaitExprType(n, path) - or - result = inferDereferencedExprPtrType(n, path) - or - result = inferForLoopExprType(n, path) - or - result = inferClosureExprType(n, path) - or - result = inferArgList(n, path) - or - result = inferDeconstructionPatType(n, path) - or - result = inferUnknownTypeFromAnnotation(n, path) -} - /** Provides predicates for debugging the type inference implementation. 
*/ private module Debug { Locatable getRelevantLocatable() { diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index b544e8495483..b3298834b8ce 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2144,14 +2144,8 @@ module Make1 Input1> { */ signature module InputSig3 { /** - * Reference to the cached stage of type inference. Should be instantiated - * with `CachedStage::ref()`. - */ - predicate cachedStageRef(); - - /** - * Reference to the cached stage of type inference. Should be instantiated - * with `CachedStage::ref()`. + * References to cached predicates that should be included to the cached + * stage of type inference. Such predicates should reference `CachedStage::ref`. */ default predicate cachedStageRevRef() { none() } @@ -2169,8 +2163,10 @@ module Make1 Input1> { /** A variable, for example a local variable or a field. */ class Variable { + /** Gets the AST node that defines this variable. */ AstNode getDefiningNode(); + /** Gets an access to this variable. */ AstNode getAnAccess(); /** Gets a textual representation of this element. */ @@ -2293,8 +2289,8 @@ module Make1 Input1> { /** Gets the inferred certain type of `n` at `path`. 
*/ cached Type inferCertainType(AstNode n, TypePath path) { - result = inferAnnotatedType(n, path) and - cachedStageRef() + CachedStage::ref() and + result = inferAnnotatedType(n, path) or result = inferCertainTypeEquality(n, path) or @@ -2405,7 +2401,7 @@ module Make1 Input1> { */ cached Type inferType(AstNode n, TypePath path) { - cachedStageRef() and + CachedStage::ref() and result = CertainTypeInference::inferCertainType(n, path) or // Don't propagate type information into a node which conflicts with certain From a95a6194cc23280b463d96de69cd9a25b4ed220a Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Tue, 5 May 2026 15:30:10 +0200 Subject: [PATCH 04/12] wip3 --- .../internal/typeinference/TypeInference.qll | 120 +++++++----------- .../typeinference/internal/TypeInference.qll | 38 ++---- 2 files changed, 58 insertions(+), 100 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 269e4835803e..c679c0ffaa17 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -352,15 +352,6 @@ private module Input3 implements InputSig3 { AstNode getExpr() { result = super.getExpr() } } - /** A ternary conditional expression. 
*/ - class ConditionalExpr extends AstNode, IfExpr { - AstNode getCondition() { result = super.getCondition() } - - AstNode getThen() { result = super.getThen() } - - AstNode getElse() { result = super.getElse() } - } - predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { n1 = any(IdentPat ip | @@ -392,65 +383,10 @@ private module Input3 implements InputSig3 { predicate inferCertainTypeInput = CertainTypeInferenceInput::inferCertainTypeInput/2; - /** - * Holds if `child` is a child of `parent`, and the Rust compiler applies [least - * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of - * `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. - * - * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound - */ - predicate lubCoercionInput(AstNode parent, AstNode child, TypePath prefix) { - parent = - any(MatchExpr me | - child = me.getAnArm().getExpr() and - me.getNumberOfArms() > 1 - ) and - prefix.isEmpty() - or - parent = - any(ArrayListExpr ale | - child = ale.getAnExpr() and - ale.getNumberOfExprs() > 1 - ) and - prefix = TypePath::singleton(getArrayTypeParameter()) - or - bodyReturns(parent, child) and - strictcount(Expr e | bodyReturns(parent, e)) > 1 and - prefix.isEmpty() - or - parent = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = child) and - prefix = closureReturnPath() - or - exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) - ) - } - - predicate typeEqualityAsymmetricInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - // When `n2` is `*n1` propagate type information from a raw pointer type - 
// parameter at `n1`. The other direction is handled in - // `inferDereferencedExprPtrType`. - n1 = n2.(DerefExpr).getExpr() and - prefix1 = TypePath::singleton(getPtrTypeParameter()) and - prefix2.isEmpty() - } - predicate typeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { prefix1.isEmpty() and prefix2.isEmpty() and ( - n2 = - any(MatchExpr me | - n1 = me.getAnArm().getExpr() and - me.getNumberOfArms() = 1 - ) - or exists(MatchExpr me | n1 = me.getScrutinee() and n2 = me.getAnArm().getPat() @@ -471,9 +407,6 @@ private module Input3 implements InputSig3 { not isPanicMacroCall(n2) or n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() - or - bodyReturns(n1, n2) and - strictcount(Expr e | bodyReturns(n1, e)) = 1 ) or n2 = @@ -516,21 +449,56 @@ private module Input3 implements InputSig3 { ) ) or - // an array list expression with only one element (such as `[1]`) has type from that element - n1 = - any(ArrayListExpr ale | - ale.getAnExpr() = n2 and - ale.getNumberOfExprs() = 1 - ) and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() - or // an array repeat expression (`[1; 3]`) has the type of the repeat operand n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and prefix1 = TypePath::singleton(getArrayTypeParameter()) and prefix2.isEmpty() } + predicate typeEqualityAsymmetricInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // When `n2` is `*n1` propagate type information from a raw pointer type + // parameter at `n1`. The other direction is handled in + // `inferDereferencedExprPtrType`. 
+ n1 = n2.(DerefExpr).getExpr() and + prefix1 = TypePath::singleton(getPtrTypeParameter()) and + prefix2.isEmpty() + or + n2 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n1) and + prefix2 = closureReturnPath() and + prefix1.isEmpty() + } + + /** + * Holds if `child` is a child of `parent`, and the Rust compiler applies [least + * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of + * `child`. + * + * In this case, we want type information to only flow from `child` to `parent`, + * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial + * explosion in inferred types. + * + * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound + */ + predicate parentChildType(AstNode parent, AstNode child, TypePath prefix) { + child = parent.(IfExpr).getABranch() and + prefix.isEmpty() + or + parent = any(MatchExpr me | child = me.getAnArm().getExpr()) and + prefix.isEmpty() + or + parent = any(ArrayListExpr ale | child = ale.getAnExpr()) and + prefix = TypePath::singleton(getArrayTypeParameter()) + or + bodyReturns(parent, child) and + prefix.isEmpty() + or + exists(Struct s | + child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and + prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and + s = getRangeType(parent) + ) + } + Type inferTypeInput(AstNode n, TypePath path) { result = inferAssignmentOperationType(n, path) or diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index b3298834b8ce..c8c11e7de78e 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2199,18 +2199,6 @@ module Make1 Input1> { AstNode getExpr(); } - /** A ternary conditional expression. 
*/ - class ConditionalExpr extends AstNode { - /** Gets the condition of this expression. */ - AstNode getCondition(); - - /** Gets the true branch of this expression. */ - AstNode getThen(); - - /** Gets the false branch of this expression. */ - AstNode getElse(); - } - /** * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. */ @@ -2220,14 +2208,10 @@ module Make1 Input1> { Type inferCertainTypeInput(AstNode n, TypePath path); /** - * Holds if `child` is a child of `parent`, and a least upper bound (LUB) coercion - * may be applied to infer the type of `parent` from the type of `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. + * Holds if the types of `n1` at `path1` and `n2` at `path2` are possibly equal, + * and type information should be allowed to flow in both directions between them. */ - predicate lubCoercionInput(AstNode parent, AstNode child, TypePath prefix); + predicate typeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); /** * Holds if the type tree of `n1` at `path1` should be equal to the type tree @@ -2237,9 +2221,12 @@ module Make1 Input1> { predicate typeEqualityAsymmetricInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); /** - * Holds if the types of `n1` at `path1` and `n2` at `path2` are possibly equal. + * Holds if `child` is a child of `parent` and the type of `parent` at `prefix` can be + * inferred from the type of `child`. + * + * When `child` is unique, we also allow type information to flow from `parent` to `child`. */ - predicate typeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + predicate parentChildType(AstNode parent, AstNode child, TypePath prefix); /** Gets the inferred type of `n` at `path`. 
*/ Type inferTypeInput(AstNode n, TypePath path); @@ -2355,13 +2342,16 @@ module Make1 Input1> { ) or typeEqualityInput(n1, path1, n2, path2) + or + n2 = unique(AstNode child | parentChildType(n1, child, path1) | child) and + path2.isEmpty() } private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - parent = any(ConditionalExpr ce | child = [ce.getThen(), ce.getElse()]) and - prefix.isEmpty() + parentChildType(parent, child, prefix) and + strictcount(AstNode child0 | parentChildType(parent, child0, prefix) | child0) > 1 or - lubCoercionInput(parent, child, prefix) + typeEqualityAsymmetricInput(child, TypePath::nil(), parent, prefix) } private predicate typeEqualityAsymmetric( From ed321bf955702abcac36868968b5269ee5962053 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Wed, 6 May 2026 10:22:00 +0200 Subject: [PATCH 05/12] wip4 --- .../internal/typeinference/TypeInference.qll | 62 +++++----- .../typeinference/internal/TypeInference.qll | 106 ++++++++++++++---- 2 files changed, 123 insertions(+), 45 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index c679c0ffaa17..fb8a78fa05f1 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -286,6 +286,10 @@ private module Input3 implements InputSig3 { (exists(resolveTupleFieldExpr(_, _)) implies any()) } + class BoolType extends DataType { + BoolType() { this.getTypeItem() instanceof Builtins::Bool } + } + class AstNode = Rust::AstNode; TypeMention getTypeAnnotation(AstNode n) { @@ -304,16 +308,44 @@ private module Input3 implements InputSig3 { result = n.(ShorthandSelfParameterMention) } + class Expr = Rust::Expr; + + class ConditionalExpr extends AstNode, IfExpr { + Expr getCondition() { result = super.getCondition() } + + Expr getThen() { result = super.getThen() } + + Expr getElse() { result 
= super.getElse() } + } + + class BinaryExpr extends AstNode, Rust::BinaryExpr { + Expr getLeftOperand() { result = super.getLhs() } + + Expr getRightOperand() { result = super.getRhs() } + } + + class LogicalAndExpr extends BinaryExpr, Rust::LogicalAndExpr { } + + class LogicalOrExpr extends BinaryExpr, Rust::LogicalOrExpr { } + + abstract class Assignment extends BinaryExpr { } + + class AssignExpr extends Assignment, Rust::AssignmentExpr { } + + class ParenExpr extends AstNode, Rust::ParenExpr { + AstNode getExpr() { result = super.getExpr() } + } + class Variable extends Rust::Variable { AstNode getDefiningNode() { result = this.getPat().getName() or result = this.getParameter().(SelfParam) } - AstNode getAnAccess() { result = super.getAnAccess() } + Expr getAnAccess() { result = super.getAnAccess() } } - abstract class Assignment extends AstNode { + abstract class LetDeclaration extends AstNode { abstract predicate isCoercionSite(); abstract AstNode getLeftOperand(); @@ -321,7 +353,7 @@ private module Input3 implements InputSig3 { abstract AstNode getRightOperand(); } - private class LetExprAssignment extends Assignment, LetExpr { + private class LetExprLetDeclaration extends LetDeclaration, LetExpr { override predicate isCoercionSite() { not this.getPat() instanceof IdentPat } override AstNode getLeftOperand() { result = this.getPat() } @@ -329,7 +361,7 @@ private module Input3 implements InputSig3 { override AstNode getRightOperand() { result = this.getScrutinee() } } - private class LetStmtAssignment extends Assignment, LetStmt { + private class LetStmtLetDeclaration extends LetDeclaration, LetStmt { override predicate isCoercionSite() { this.hasTypeRepr() or not identLetStmt(this, _, _) @@ -340,18 +372,6 @@ private module Input3 implements InputSig3 { override AstNode getRightOperand() { result = this.getInitializer() } } - private class AssignmentExprAssignment extends Assignment, AssignmentExpr { - override predicate isCoercionSite() { any() } - - 
override AstNode getLeftOperand() { result = this.getLhs() } - - override AstNode getRightOperand() { result = this.getRhs() } - } - - class ParenExpr extends AstNode, Rust::ParenExpr { - AstNode getExpr() { result = super.getExpr() } - } - predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { n1 = any(IdentPat ip | @@ -824,8 +844,6 @@ private module CertainTypeInferenceInput { result = inferRefExprType(n) and path.isEmpty() or - result = inferLogicalOperationType(n, path) - or result = inferCertainStructExprType(n, path) or result = inferCertainStructPatType(n, path) @@ -857,14 +875,6 @@ private module CertainTypeInferenceInput { } } -private Type inferLogicalOperationType(AstNode n, TypePath path) { - exists(Builtins::Bool t, BinaryLogicalOperation be | - n = [be, be.getLhs(), be.getRhs()] and - path.isEmpty() and - result = TDataType(t) - ) -} - private Type inferAssignmentOperationType(AstNode n, TypePath path) { n instanceof AssignmentOperation and path.isEmpty() and diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index c8c11e7de78e..5dad50dac2ee 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2141,6 +2141,8 @@ module Make1 Input1> { /** * Provides the input to `Make3`. + * + * TODO: Eventually align the AST signature with that of the shared CFG library. */ signature module InputSig3 { /** @@ -2149,6 +2151,9 @@ module Make1 Input1> { */ default predicate cachedStageRevRef() { none() } + /** A boolean type. */ + class BoolType extends Type; + /** An AST node. */ class AstNode { /** Gets a textual representation of this AST node. */ @@ -2161,13 +2166,63 @@ module Make1 Input1> { /** Gets the type annotation that applies to `n`, if any. */ TypeMention getTypeAnnotation(AstNode n); + /** An expression. 
*/ + class Expr extends AstNode; + + /** A ternary conditional expression. */ + class ConditionalExpr extends Expr { + /** Gets the condition of this expression. */ + Expr getCondition(); + + /** Gets the true branch of this expression. */ + Expr getThen(); + + /** Gets the false branch of this expression. */ + Expr getElse(); + } + + /** A binary expression. */ + class BinaryExpr extends Expr { + /** Gets the left operand of this binary expression. */ + Expr getLeftOperand(); + + /** Gets the right operand of this binary expression. */ + Expr getRightOperand(); + } + + /** A short-circuiting logical AND expression. */ + class LogicalAndExpr extends BinaryExpr; + + /** A short-circuiting logical OR expression. */ + class LogicalOrExpr extends BinaryExpr; + + /** + * An assignment expression, either compound or simple. + * + * Examples: + * + * ``` + * x = y + * sum += element + * ``` + */ + class Assignment extends BinaryExpr; + + /** A simple assignment expression, for example `x = y`. */ + class AssignExpr extends Assignment; + + /** A parenthesized expression. */ + class ParenExpr extends AstNode { + AstNode getExpr(); + } + /** A variable, for example a local variable or a field. */ class Variable { /** Gets the AST node that defines this variable. */ AstNode getDefiningNode(); /** Gets an access to this variable. */ - AstNode getAnAccess(); + Expr getAnAccess(); /** Gets a textual representation of this element. */ string toString(); @@ -2177,28 +2232,22 @@ module Make1 Input1> { } /** - * An assignment where type information can flow from one operand to the - * other. + * A `let` declaration, for example a local variable declaration. */ - class Assignment extends AstNode { + class LetDeclaration extends AstNode { /** - * Holds if this assignment is a coercion site, meaning that the type of the right + * Holds if this declaration is a coercion site, meaning that the type of the right * operand may have to be coerced to the type of the left operand. 
*/ predicate isCoercionSite(); - /** Gets the left operand of this binary expression. */ + /** Gets the left operand of this declaration. */ AstNode getLeftOperand(); - /** Gets the right operand of this binary expression. */ + /** Gets the right operand of this declaration. */ AstNode getRightOperand(); } - /** A parenthesized expression. */ - class ParenExpr extends AstNode { - AstNode getExpr(); - } - /** * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. */ @@ -2249,10 +2298,10 @@ module Make1 Input1> { ( exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode()) or - exists(Assignment a | - not a.isCoercionSite() and - n1 = a.getLeftOperand() and - n2 = a.getRightOperand() + exists(LetDeclaration let | + not let.isCoercionSite() and + n1 = let.getLeftOperand() and + n2 = let.getRightOperand() ) or n1 = n2.(ParenExpr).getExpr() @@ -2273,6 +2322,16 @@ module Make1 Input1> { ) } + private Type inferLogicalOperationType(AstNode n, TypePath path) { + ( + exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or + exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) //or + // exists(LogicalNotExpr lne | n = [lne, lne.getOperand()]) + ) and + result instanceof BoolType and + path.isEmpty() + } + /** Gets the inferred certain type of `n` at `path`. 
*/ cached Type inferCertainType(AstNode n, TypePath path) { @@ -2283,6 +2342,8 @@ module Make1 Input1> { or result = inferCertainTypeInput(n, path) or + result = inferLogicalOperationType(n, path) + or infersCertainTypeAt(n, path, result.getATypeParameter()) } @@ -2336,9 +2397,16 @@ module Make1 Input1> { or path1.isEmpty() and path2.isEmpty() and - exists(Assignment a | - a.getLeftOperand() = n1 and - a.getRightOperand() = n2 + ( + exists(Assignment a | + a.getLeftOperand() = n1 and + a.getRightOperand() = n2 + ) + or + exists(LetDeclaration let | + let.getLeftOperand() = n1 and + let.getRightOperand() = n2 + ) ) or typeEqualityInput(n1, path1, n2, path2) From 447c8fab08401e75add597721188a76efcea484b Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Wed, 6 May 2026 11:41:30 +0200 Subject: [PATCH 06/12] wip6 --- .../internal/typeinference/TypeInference.qll | 143 +++++++++--------- .../typeinference/internal/TypeInference.qll | 55 ++++++- .../type-inference/type-inference.expected | 11 ++ 3 files changed, 136 insertions(+), 73 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index fb8a78fa05f1..4d2df25c7fe9 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -310,15 +310,11 @@ private module Input3 implements InputSig3 { class Expr = Rust::Expr; - class ConditionalExpr extends AstNode, IfExpr { - Expr getCondition() { result = super.getCondition() } - + class ConditionalExpr extends IfExpr { Expr getThen() { result = super.getThen() } - - Expr getElse() { result = super.getElse() } } - class BinaryExpr extends AstNode, Rust::BinaryExpr { + class BinaryExpr extends Rust::BinaryExpr { Expr getLeftOperand() { result = super.getLhs() } Expr getRightOperand() { result = super.getRhs() } @@ -332,9 +328,7 @@ private module Input3 implements InputSig3 { class AssignExpr 
extends Assignment, Rust::AssignmentExpr { } - class ParenExpr extends AstNode, Rust::ParenExpr { - AstNode getExpr() { result = super.getExpr() } - } + class ParenExpr = Rust::ParenExpr; class Variable extends Rust::Variable { AstNode getDefiningNode() { @@ -372,6 +366,43 @@ private module Input3 implements InputSig3 { override AstNode getRightOperand() { result = this.getInitializer() } } + class CallTarget extends FunctionCallMatchingInput::Declaration { + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) { + result = + tp.(TypeParamTypeParameter) + .getTypeParam() + .getAdditionalTypeBound(this.getFunction(), _) + .getTypeRepr() + } + + Type getReturnType(TypePath path) { + exists(FunctionPosition pos | + pos.isReturn() and + result = super.getDeclaredType(pos, path) + ) + } + + Type getParameterType(int index, TypePath path) { + none() // todo + } + } + + class Call extends Expr instanceof FunctionCallMatchingInput::Access { + Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { + result = super.getTypeArgument(apos, path) + } + + /** Gets the target of this call. 
*/ + CallTarget getTargetCertain() { + exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p | + result.isFunction(i, f) and + p = CallExprImpl::getFunctionPath(this) and + f = resolvePath(p) and + f.isDirectlyFor(i) + ) + } + } + predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { n1 = any(IdentPat ip | @@ -686,16 +717,38 @@ private class AssocFunctionDeclaration extends FunctionDeclaration { } pragma[nomagic] -private TypeMention getCallExprTypeMentionArgument(CallExpr ce, TypeArgumentPosition apos) { - exists(Path p, int i | p = CallExprImpl::getFunctionPath(ce) | - apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i)) and - result = getPathTypeArgument(p, pragma[only_bind_into](i)) +private TypePath getPathToImplSelfTypeParam(TypeParam tp) { + exists(ImplItemNode impl | + tp = impl.getTypeParam(_) and + TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) ) } pragma[nomagic] private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, TypePath path) { - result = getCallExprTypeMentionArgument(ce, apos).getTypeAt(path) + exists(Path p, ItemNode resolved, TypeParam tp | + p = CallExprImpl::getFunctionPath(ce) and + resolved = resolvePath(p) and + apos.asTypeParam() = tp + | + // For type parameters of the function we must resolve their + // instantiation from the path. For instance, for `fn bar(a: A) -> A` + // and the path `bar`, we must resolve `A` to `i64`. + exists(int i | + tp = resolved.getTypeParam(pragma[only_bind_into](i)) and + result = getPathTypeArgument(p, pragma[only_bind_into](i)).getTypeAt(path) + ) + or + // For type parameters of the `impl` block we must resolve their + // instantiation from the path. For instance, for `impl for Foo` + // and the path `Foo::bar` we must resolve `A` to `i64`. 
+ exists(ImplItemNode impl, TypePath pathToTp | + resolved = impl.getASuccessor(_) and + tp = impl.getTypeParam(_) and + pathToTp = getPathToImplSelfTypeParam(tp) and + result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(path)) + ) + ) or // Handle constructions that use `Self(...)` syntax exists(Path p, TypePath path0 | @@ -764,61 +817,6 @@ private TypePath closureParameterPath(int arity, int index) { /** Module for inferring certain type information. */ private module CertainTypeInferenceInput { - pragma[nomagic] - private predicate callResolvesTo(CallExpr ce, Path p, Function f) { - p = CallExprImpl::getFunctionPath(ce) and - f = resolvePath(p) - } - - pragma[nomagic] - private Type getCallExprType(CallExpr ce, Path p, FunctionDeclaration f, TypePath path) { - exists(ImplOrTraitItemNodeOption i | - callResolvesTo(ce, p, f) and - result = f.getReturnType(i, path) and - f.isDirectlyFor(i) - ) - } - - pragma[nomagic] - private Type getCertainCallExprType(CallExpr ce, Path p, TypePath tp) { - forex(Function f | callResolvesTo(ce, p, f) | result = getCallExprType(ce, p, f, tp)) - } - - pragma[nomagic] - private TypePath getPathToImplSelfTypeParam(TypeParam tp) { - exists(ImplItemNode impl | - tp = impl.getTypeParam(_) and - TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) - ) - } - - pragma[nomagic] - private Type inferCertainCallExprType(CallExpr ce, TypePath path) { - exists(Type ty, TypePath prefix, Path p | ty = getCertainCallExprType(ce, p, prefix) | - exists(TypePath suffix, TypeParam tp | - tp = ty.(TypeParamTypeParameter).getTypeParam() and - path = prefix.append(suffix) - | - // For type parameters of the `impl` block we must resolve their - // instantiation from the path. For instance, for `impl for Foo` - // and the path `Foo::bar` we must resolve `A` to `i64`. 
- exists(TypePath pathToTp | - pathToTp = getPathToImplSelfTypeParam(tp) and - result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(suffix)) - ) - or - // For type parameters of the function we must resolve their - // instantiation from the path. For instance, for `fn bar(a: A) -> A` - // and the path `bar`, we must resolve `A` to `i64`. - result = getCallExprTypeArgument(ce, TTypeParamTypeArgumentPosition(tp), suffix) - ) - or - not ty instanceof TypeParameter and - result = ty and - path = prefix - ) - } - private Type inferCertainStructExprType(StructExpr se, TypePath path) { result = se.getPath().(TypeMention).getTypeAt(path) } @@ -834,8 +832,6 @@ private module CertainTypeInferenceInput { Type inferCertainTypeInput(AstNode n, TypePath path) { result = inferFunctionBodyType(n, path) or - result = inferCertainCallExprType(n, path) - or result = inferLiteralType(n, path, true) or result = inferRefPatType(n) and @@ -2612,6 +2608,11 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput FunctionDeclaration getFunction() { result = f } + predicate isFunction(ImplOrTraitItemNodeOption i_, Function f_) { + i_ = i and + f_ = f + } + predicate isAssocFunction(ImplOrTraitItemNode i_, Function f_) { i_ = i.asSome() and f_ = f diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 5dad50dac2ee..4d5f078505c3 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2212,8 +2212,8 @@ module Make1 Input1> { class AssignExpr extends Assignment; /** A parenthesized expression. */ - class ParenExpr extends AstNode { - AstNode getExpr(); + class ParenExpr extends Expr { + Expr getExpr(); } /** A variable, for example a local variable or a field. 
*/ @@ -2248,6 +2248,29 @@ module Make1 Input1> { AstNode getRightOperand(); } + class CallTarget { + TypeParameter getTypeParameter(TypeParameterPosition ppos); + + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp); + + Type getReturnType(TypePath path); + + Type getParameterType(int index, TypePath path); + + /** Gets a textual representation of this element. */ + string toString(); + + /** Gets the location of this element. */ + Location getLocation(); + } + + class Call extends Expr { + Type getTypeArgument(TypeArgumentPosition apos, TypePath path); + + /** Gets the target of this call. */ + CallTarget getTargetCertain(); + } + /** * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. */ @@ -2332,6 +2355,32 @@ module Make1 Input1> { path.isEmpty() } + pragma[nomagic] + private Type getCertainCallExprType(Call call, TypePath path) { + forex(CallTarget target | target = call.getTargetCertain() | + result = target.getReturnType(path) + ) + } + + pragma[nomagic] + private Type inferCertainCallExprType(Call call, TypePath path) { + exists(Type ty, TypePath prefix | ty = getCertainCallExprType(call, prefix) | + exists( + CallTarget target, TypePath suffix, TypeParameterPosition tppos, + TypeArgumentPosition tapos + | + ty = target.getTypeParameter(tppos) and + path = prefix.append(suffix) and + result = call.getTypeArgument(tapos, suffix) and + typeArgumentParameterPositionMatch(tapos, tppos) + ) + or + not ty instanceof TypeParameter and + result = ty and + path = prefix + ) + } + /** Gets the inferred certain type of `n` at `path`. 
*/ cached Type inferCertainType(AstNode n, TypePath path) { @@ -2344,6 +2393,8 @@ module Make1 Input1> { or result = inferLogicalOperationType(n, path) or + result = inferCertainCallExprType(n, path) + or infersCertainTypeAt(n, path, result.getATypeParameter()) } diff --git a/swift/ql/test/library-tests/type-inference/type-inference.expected b/swift/ql/test/library-tests/type-inference/type-inference.expected index e69de29bb2d1..79a3599aa1f3 100644 --- a/swift/ql/test/library-tests/type-inference/type-inference.expected +++ b/swift/ql/test/library-tests/type-inference/type-inference.expected @@ -0,0 +1,11 @@ +| context.swift:17:10:17:10 | C.init() | Unexpected result: target=init() | +| context.swift:17:10:17:12 | call to C.init() | Unexpected result: target=init() | +| context.swift:17:14:18:1 | // $ type=C\n | Missing result: type=C | +| context.swift:25:11:25:11 | A.init() | Unexpected result: target=init() | +| context.swift:25:11:25:13 | call to A.init() | Unexpected result: target=init() | +| context.swift:26:19:26:19 | D.init() | Unexpected result: target=init() | +| context.swift:26:19:26:21 | call to D.init() | Unexpected result: target=init() | +| context.swift:26:25:26:25 | B.init() | Unexpected result: target=init() | +| context.swift:26:25:26:27 | call to B.init() | Unexpected result: target=init() | +| file://:0:0:0:0 | A.init() | Unexpected result: target=init() | +| file://:0:0:0:0 | call to A.init() | Unexpected result: target=init() | From 441ebad711243af9056c272a14c0ebd2550f41c8 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Wed, 6 May 2026 15:55:24 +0200 Subject: [PATCH 07/12] wip7 --- .../internal/typeinference/TypeInference.qll | 161 +++++------------- .../type-inference/type-inference.expected | 5 + .../typeinference/internal/TypeInference.qll | 154 +++++++++++++++-- 3 files changed, 191 insertions(+), 129 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll 
b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 4d2df25c7fe9..a4ba95f30253 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -286,6 +286,8 @@ private module Input3 implements InputSig3 { (exists(resolveTupleFieldExpr(_, _)) implies any()) } + predicate inferType = M3::inferType/2; + class BoolType extends DataType { BoolType() { this.getTypeItem() instanceof Builtins::Bool } } @@ -366,7 +368,11 @@ private module Input3 implements InputSig3 { override AstNode getRightOperand() { result = this.getInitializer() } } - class CallTarget extends FunctionCallMatchingInput::Declaration { + class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment; + + class TypePosition = FunctionPosition; + + class Callable extends FunctionCallMatchingInput::Declaration { TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) { result = tp.(TypeParamTypeParameter) @@ -374,17 +380,6 @@ private module Input3 implements InputSig3 { .getAdditionalTypeBound(this.getFunction(), _) .getTypeRepr() } - - Type getReturnType(TypePath path) { - exists(FunctionPosition pos | - pos.isReturn() and - result = super.getDeclaredType(pos, path) - ) - } - - Type getParameterType(int index, TypePath path) { - none() // todo - } } class Call extends Expr instanceof FunctionCallMatchingInput::Access { @@ -392,8 +387,10 @@ private module Input3 implements InputSig3 { result = super.getTypeArgument(apos, path) } + AstNode getNodeAt(TypePosition pos) { result = super.getNodeAt(pos) } + /** Gets the target of this call. 
*/ - CallTarget getTargetCertain() { + Callable getTargetCertain() { exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p | result.isFunction(i, f) and p = CallExprImpl::getFunctionPath(this) and @@ -401,6 +398,24 @@ private module Input3 implements InputSig3 { f.isDirectlyFor(i) ) } + + Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) } + } + + bindingset[derefChainBorrow] + Type inferCallTypeIn(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) { + result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path) + } + + Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path) { + result = inferFunctionCallTypeNonSelf(n, pos, path) + or + exists(FunctionCallMatchingInput::Access a | + result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and + if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() + then not path.isEmpty() + else any() + ) } predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { @@ -553,8 +568,6 @@ private module Input3 implements InputSig3 { Type inferTypeInput(AstNode n, TypePath path) { result = inferAssignmentOperationType(n, path) or - result = inferFunctionCallType(n, path) - or result = inferConstructionType(n, path) or result = inferOperationType(n, path) @@ -1094,53 +1107,6 @@ private module ContextTyping { ) } } - - pragma[nomagic] - private predicate hasUnknownTypeAt(AstNode n, TypePath path) { - inferType(n, path) = TUnknownType() - } - - pragma[nomagic] - private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } - - newtype FunctionPositionKind = - SelfKind() or - ReturnKind() or - PositionalKind() - - signature Type inferCallTypeSig(AstNode n, FunctionPositionKind kind, TypePath path); - - /** - * Given a predicate `inferCallType` for inferring the type of a call at a given - * position, this module exposes the predicate `check`, which wraps the 
input - * predicate and checks that types are only propagated into arguments when they - * are context-typed. - */ - module CheckContextTyping { - pragma[nomagic] - private Type inferCallNonReturnType( - AstNode n, FunctionPositionKind kind, TypePath prefix, TypePath path - ) { - result = inferCallType(n, kind, path) and - hasUnknownType(n) and - kind != ReturnKind() and - prefix = path.getAPrefix() - } - - pragma[nomagic] - Type check(AstNode n, TypePath path) { - result = inferCallType(n, ReturnKind(), path) - or - exists(FunctionPositionKind kind, TypePath prefix | - result = inferCallNonReturnType(n, kind, prefix, path) and - hasUnknownTypeAt(n, prefix) - | - // Never propagate type information directly into the receiver, since its type - // must already have been known in order to resolve the call - if kind = SelfKind() then not prefix.isEmpty() else any() - ) - } - } } /** @@ -2836,22 +2802,20 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } } -private module FunctionCallMatching = MatchingWithEnvironment; - pragma[nomagic] private Type inferFunctionCallType0( FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path ) { exists(TypePath path0 | - n = call.getNodeAt(pos) and exists(string derefChainBorrow | FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) | - result = FunctionCallMatching::inferAccessType(call, derefChainBorrow, pos, path0) - or + n = call.getNodeAt(pos) and call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and result = TUnknownType() + or + result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0) ) | if @@ -2919,31 +2883,6 @@ private Type inferFunctionCallTypeSelf( ) } -private Type inferFunctionCallTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(FunctionPosition pos | - result = inferFunctionCallTypeNonSelf(n, pos, path) and - if 
pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - ) - or - exists(FunctionCallMatchingInput::Access a | - result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and - if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() - then kind = ContextTyping::SelfKind() - else kind = ContextTyping::PositionalKind() - ) -} - -/** - * Gets the type of `n` at `path`, where `n` is either a function call or an - * argument/receiver of a function call. - */ -private predicate inferFunctionCallType = - ContextTyping::CheckContextTyping::check/2; - abstract private class Constructor extends Addressable { final TypeParameter getTypeParameter(TypeParameterPosition ppos) { typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos) @@ -3102,15 +3041,8 @@ private module ConstructionMatchingInput implements MatchingInputSig { private module ConstructionMatching = Matching; pragma[nomagic] -private Type inferConstructionTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(ConstructionMatchingInput::Access a, FunctionPosition pos | - n = a.getNodeAt(pos) and - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - | +private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { + exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) | result = ConstructionMatching::inferAccessType(a, pos, path) or a.hasUnknownTypeAt(pos, path) and @@ -3119,7 +3051,7 @@ private Type inferConstructionTypePreCheck( } private predicate inferConstructionType = - ContextTyping::CheckContextTyping::check/2; + CheckContextTyping::check/2; /** * A matching configuration for resolving types of operations like `a + b`. 
@@ -3184,23 +3116,15 @@ private module OperationMatchingInput implements MatchingInputSig { private module OperationMatching = Matching; pragma[nomagic] -private Type inferOperationTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(OperationMatchingInput::Access a, FunctionPosition pos | +private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { + exists(OperationMatchingInput::Access a | n = a.getNodeAt(pos) and result = OperationMatching::inferAccessType(a, pos, path) and - if pos.asPosition() = 0 - then kind = ContextTyping::SelfKind() - else - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() + if pos.asPosition() = 0 then not path.isEmpty() else any() ) } -private predicate inferOperationType = - ContextTyping::CheckContextTyping::check/2; +private predicate inferOperationType = CheckContextTyping::check/2; pragma[nomagic] private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) { @@ -3815,11 +3739,10 @@ private module Debug { t = self.getTypeAt(path) } - predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { - n = getRelevantLocatable() and - t = inferFunctionCallType(n, path) - } - + // predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { + // n = getRelevantLocatable() and + // t = inferFunctionCallType(n, path) + // } predicate debugInferConstructionType(AstNode n, TypePath path, Type t) { n = getRelevantLocatable() and t = inferConstructionType(n, path) diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 3344fc45f74f..f30a2c064d12 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -10252,6 +10252,7 @@ inferType | main.rs:1412:17:1412:20 | self | 
TRef.TSlice | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option | | main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & | +| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & | | main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 | @@ -11600,6 +11601,8 @@ inferType | main.rs:2221:18:2221:21 | true | | {EXTERNAL LOCATION} | bool | | main.rs:2223:9:2223:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | +| main.rs:2223:9:2223:15 | S(...) | T | main.rs:2107:5:2107:19 | S | +| main.rs:2223:9:2223:15 | S(...) | T.T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | main.rs:2107:5:2107:19 | S | @@ -11618,6 +11621,8 @@ inferType | main.rs:2224:24:2224:27 | 3i64 | | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | +| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | & | +| main.rs:2225:9:2225:15 | S(...) | T.TRef | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:29 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:29 | ... .my_add(...) 
| T | {EXTERNAL LOCATION} | i64 | | main.rs:2225:11:2225:14 | 1i64 | | {EXTERNAL LOCATION} | i64 | diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 4d5f078505c3..aeb6cc6ed3b1 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2151,6 +2151,13 @@ module Make1 Input1> { */ default predicate cachedStageRevRef() { none() } + /** + * Point this predicate to the `inferType` predicate in the output of this module. + * + * Needed to be able to refer to `inferType` in default signature implementations. + */ + Type inferType(AstNode n, TypePath path); + /** A boolean type. */ class BoolType extends Type; @@ -2248,29 +2255,64 @@ module Make1 Input1> { AstNode getRightOperand(); } - class CallTarget { + /** + * A position where a callable can have a declared type. + */ + class TypePosition { + /** Holds if this position represents the return type of a callable. */ + predicate isReturn(); + + /** Gets a textual representation of this position. */ + string toString(); + } + + /** A context needed to resolve calls. */ + bindingset[this] + class CallResolutionContext { + /** Gets a textual representation of this context. */ + bindingset[this] + string toString(); + } + + /** A callable. */ + class Callable { TypeParameter getTypeParameter(TypeParameterPosition ppos); TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp); - Type getReturnType(TypePath path); + /* Gets the declared type of this callable at `path` for position `pos`. */ + Type getDeclaredType(TypePosition pos, TypePath path); - Type getParameterType(int index, TypePath path); - - /** Gets a textual representation of this element. */ + /** Gets a textual representation of this callable. */ string toString(); - /** Gets the location of this element. 
*/ + /** Gets the location of this callable. */ Location getLocation(); } class Call extends Expr { Type getTypeArgument(TypeArgumentPosition apos, TypePath path); + AstNode getNodeAt(TypePosition pos); + /** Gets the target of this call. */ - CallTarget getTargetCertain(); + Callable getTargetCertain(); + + /** Gets the target of this call. */ + Callable getTarget(CallResolutionContext ctx); } + /** Gets the inferred type `call` at `path` for position `pos` in context `ctx`. */ + bindingset[ctx] + default Type inferCallTypeIn( + Call call, CallResolutionContext ctx, TypePosition pos, TypePath path + ) { + result = inferType(call.getNodeAt(pos), path) and + exists(ctx) + } + + Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path); + /** * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. */ @@ -2357,8 +2399,11 @@ module Make1 Input1> { pragma[nomagic] private Type getCertainCallExprType(Call call, TypePath path) { - forex(CallTarget target | target = call.getTargetCertain() | - result = target.getReturnType(path) + exists(TypePosition ret | + ret.isReturn() and + forex(Callable target | target = call.getTargetCertain() | + result = target.getDeclaredType(ret, path) + ) ) } @@ -2366,7 +2411,7 @@ module Make1 Input1> { private Type inferCertainCallExprType(Call call, TypePath path) { exists(Type ty, TypePath prefix | ty = getCertainCallExprType(call, prefix) | exists( - CallTarget target, TypePath suffix, TypeParameterPosition tppos, + Callable target, TypePath suffix, TypeParameterPosition tppos, TypeArgumentPosition tapos | ty = target.getTypeParameter(tppos) and @@ -2524,15 +2569,104 @@ module Make1 Input1> { ( result = inferTypeEquality(n, path) or + result = CheckContextTyping::check(n, path) + or result = inferTypeInput(n, path) ) } + private module TypePositionMatchingInput { + class DeclarationPosition = TypePosition; + + class AccessPosition = DeclarationPosition; + + predicate 
accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) { + apos = dpos + } + } + + /** + * A matching configuration for resolving types of calls. + */ + private module CallMatchingInput implements MatchingWithEnvironmentInputSig { + import TypePositionMatchingInput + + class Declaration = Callable; + + bindingset[decl] + TypeMention getATypeParameterConstraint(TypeParameter tp, Declaration decl) { + result = Input2::getATypeParameterConstraint(tp) and + exists(decl) + or + result = decl.getAdditionalTypeParameterConstraint(tp) + } + + class AccessEnvironment = CallResolutionContext; + + final private class CallFinal = Call; + + class Access extends CallFinal { + bindingset[e] + Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) { + result = inferCallTypeIn(this, e, apos, path) + } + } + } + + private module CallMatching = MatchingWithEnvironment; + + pragma[nomagic] + Type inferCallTypeOut( + Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path + ) { + n = call.getNodeAt(pos) and + result = CallMatching::inferAccessType(call, ctx, pos, path) + } + + pragma[nomagic] + private predicate hasUnknownTypeAt(AstNode n, TypePath path) { + inferType(n, path) instanceof UnknownType + } + + pragma[nomagic] + private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } + + signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path); + + /** + * Given a predicate `inferCallType` for inferring the type of a call at a given + * position, this module exposes the predicate `check`, which wraps the input + * predicate and checks that types are only propagated into arguments when they + * are context-typed. 
+ */ + module CheckContextTyping { + pragma[nomagic] + private Type inferCallNonReturnType(AstNode n, TypePath prefix, TypePath path) { + exists(TypePosition pos | + result = inferCallType(n, pos, path) and + hasUnknownType(n) and + not pos.isReturn() and + prefix = path.getAPrefix() + ) + } + + pragma[nomagic] + Type check(AstNode n, TypePath path) { + result = inferCallType(n, any(TypePosition pos | pos.isReturn()), path) + or + exists(TypePath prefix | + result = inferCallNonReturnType(n, prefix, path) and + hasUnknownTypeAt(n, prefix) + ) + } + } + /** * Gets the inferred root type of `n`, if any. */ Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } + // todo: consistency checks /** The cached stage of type inference. */ cached module CachedStage { From 60ceeab984bb2bbe3892b1ef8c275c02637da743 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Thu, 7 May 2026 09:03:20 +0200 Subject: [PATCH 08/12] wip8 --- .../internal/typeinference/TypeInference.qll | 111 +++++++++--------- .../typeinference/internal/TypeInference.qll | 22 ++++ 2 files changed, 79 insertions(+), 54 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index a4ba95f30253..1b24e8210408 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -312,6 +312,18 @@ private module Input3 implements InputSig3 { class Expr = Rust::Expr; + class SwitchExpr extends Rust::MatchExpr { + Expr getExpr() { result = this.getScrutinee() } + + Case getCase(int index) { result = this.getArm(index) } + } + + class Case extends Rust::MatchArm { + AstNode getAPattern() { result = this.getPat() } + + Expr getBody() { result = this.getExpr() } + } + class ConditionalExpr extends IfExpr { Expr getThen() { result = super.getThen() } } @@ -447,7 +459,46 @@ private module Input3 implements InputSig3 { ) } - 
predicate inferCertainTypeInput = CertainTypeInferenceInput::inferCertainTypeInput/2; + Type inferCertainTypeInput(AstNode n, TypePath path) { + result = inferFunctionBodyType(n, path) + or + result = inferLiteralType(n, path, true) + or + result = inferRefPatType(n) and + path.isEmpty() + or + result = inferRefExprType(n) and + path.isEmpty() + or + result = inferCertainStructExprType(n, path) + or + result = inferCertainStructPatType(n, path) + or + result = inferRangeExprType(n) and + path.isEmpty() + or + result = inferTupleRootType(n) and + path.isEmpty() + or + result = inferBlockExprType(n, path) + or + result = inferArrayExprType(n) and + path.isEmpty() + or + result = inferCastExprType(n, path) + or + exprHasUnitType(n) and + path.isEmpty() and + result instanceof UnitType + or + isPanicMacroCall(n) and + path.isEmpty() and + result instanceof NeverType + or + n instanceof ClosureExpr and + path.isEmpty() and + result = closureRootType() + } predicate typeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { prefix1.isEmpty() and @@ -828,60 +879,12 @@ private TypePath closureParameterPath(int arity, int index) { TypePath::singleton(getTupleTypeParameter(arity, index))) } -/** Module for inferring certain type information. */ -private module CertainTypeInferenceInput { - private Type inferCertainStructExprType(StructExpr se, TypePath path) { - result = se.getPath().(TypeMention).getTypeAt(path) - } - - private Type inferCertainStructPatType(StructPat sp, TypePath path) { - result = sp.getPath().(TypeMention).getTypeAt(path) - } +private Type inferCertainStructExprType(StructExpr se, TypePath path) { + result = se.getPath().(TypeMention).getTypeAt(path) +} - /** - * Holds if `n` has complete and certain type information and if `n` has the - * resulting type at `path`. 
- */ - Type inferCertainTypeInput(AstNode n, TypePath path) { - result = inferFunctionBodyType(n, path) - or - result = inferLiteralType(n, path, true) - or - result = inferRefPatType(n) and - path.isEmpty() - or - result = inferRefExprType(n) and - path.isEmpty() - or - result = inferCertainStructExprType(n, path) - or - result = inferCertainStructPatType(n, path) - or - result = inferRangeExprType(n) and - path.isEmpty() - or - result = inferTupleRootType(n) and - path.isEmpty() - or - result = inferBlockExprType(n, path) - or - result = inferArrayExprType(n) and - path.isEmpty() - or - result = inferCastExprType(n, path) - or - exprHasUnitType(n) and - path.isEmpty() and - result instanceof UnitType - or - isPanicMacroCall(n) and - path.isEmpty() and - result instanceof NeverType - or - n instanceof ClosureExpr and - path.isEmpty() and - result = closureRootType() - } +private Type inferCertainStructPatType(StructPat sp, TypePath path) { + result = sp.getPath().(TypeMention).getTypeAt(path) } private Type inferAssignmentOperationType(AstNode n, TypePath path) { diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index aeb6cc6ed3b1..c3adee01c7f8 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2176,6 +2176,28 @@ module Make1 Input1> { /** An expression. */ class Expr extends AstNode; + /** + * A switch expression. + */ + class SwitchExpr extends Expr { + /** + * Gets the expression being switched on. + */ + Expr getExpr(); + + /** Gets the case at the specified (zero-based) `index`. */ + Case getCase(int index); + } + + /** A case in a switch expression. */ + class Case extends AstNode { + /** Gets a pattern being matched by this case. */ + AstNode getAPattern(); + + /** Gets the body of this case. 
*/ + Expr getBody(); + } + /** A ternary conditional expression. */ class ConditionalExpr extends Expr { /** Gets the condition of this expression. */ From 0e414a557e8ed9828cb0617e5ff92021d870f0ed Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Thu, 7 May 2026 10:59:57 +0200 Subject: [PATCH 09/12] wip9 --- .../internal/typeinference/TypeInference.qll | 100 ++++---- .../type-inference/type-inference.expected | 24 -- .../type-inference/type-inference.ql | 4 +- .../typeinference/internal/TypeInference.qll | 220 ++++++++++++------ 4 files changed, 203 insertions(+), 145 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 1b24e8210408..02b98f9990a3 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -312,7 +312,7 @@ private module Input3 implements InputSig3 { class Expr = Rust::Expr; - class SwitchExpr extends Rust::MatchExpr { + class Switch extends Rust::MatchExpr { Expr getExpr() { result = this.getScrutinee() } Case getCase(int index) { result = this.getArm(index) } @@ -321,7 +321,7 @@ private module Input3 implements InputSig3 { class Case extends Rust::MatchArm { AstNode getAPattern() { result = this.getPat() } - Expr getBody() { result = this.getExpr() } + AstNode getBody() { result = this.getExpr() } } class ConditionalExpr extends IfExpr { @@ -430,7 +430,7 @@ private module Input3 implements InputSig3 { ) } - predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + predicate inferStepSymmetricCertain(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { n1 = any(IdentPat ip | n2 = ip.getName() and @@ -459,7 +459,7 @@ private module Input3 implements InputSig3 { ) } - Type inferCertainTypeInput(AstNode n, TypePath path) { + Type inferTypeCertainInput(AstNode n, TypePath path) { result = 
inferFunctionBodyType(n, path) or result = inferLiteralType(n, path, true) @@ -500,15 +500,10 @@ private module Input3 implements InputSig3 { result = closureRootType() } - predicate typeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + predicate inferStepSymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { prefix1.isEmpty() and prefix2.isEmpty() and ( - exists(MatchExpr me | - n1 = me.getScrutinee() and - n2 = me.getAnArm().getPat() - ) - or n1 = n2.(OrPat).getAPat() or n1 = n2.(ParenPat).getPat() @@ -572,7 +567,7 @@ private module Input3 implements InputSig3 { prefix2.isEmpty() } - predicate typeEqualityAsymmetricInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + predicate inferStep(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { // When `n2` is `*n1` propagate type information from a raw pointer type // parameter at `n1`. The other direction is handled in // `inferDereferencedExprPtrType`. @@ -596,23 +591,21 @@ private module Input3 implements InputSig3 { * * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound */ - predicate parentChildType(AstNode parent, AstNode child, TypePath prefix) { - child = parent.(IfExpr).getABranch() and - prefix.isEmpty() - or - parent = any(MatchExpr me | child = me.getAnArm().getExpr()) and - prefix.isEmpty() - or - parent = any(ArrayListExpr ale | child = ale.getAnExpr()) and - prefix = TypePath::singleton(getArrayTypeParameter()) - or - bodyReturns(parent, child) and - prefix.isEmpty() - or - exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) + predicate inferLubStep(AstNode child, TypePath path1, AstNode parent, TypePath prefix) { + path1.isEmpty() and + ( + parent = any(ArrayListExpr ale | child = ale.getAnExpr()) and + prefix = 
TypePath::singleton(getArrayTypeParameter()) + or + bodyReturns(parent, child) and + prefix.isEmpty() + or + exists(Struct s | + child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and + prefix = + TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and + s = getRangeType(parent) + ) ) } @@ -652,14 +645,14 @@ import M3 module Consistency { import M2::Consistency - private Type inferCertainTypeAdj(AstNode n, TypePath path) { - result = CertainTypeInference::inferCertainType(n, path) and + private Type inferTypeCertainAdj(AstNode n, TypePath path) { + result = inferTypeCertain(n, path) and not result = TNeverType() } predicate nonUniqueCertainType(AstNode n, TypePath path, Type t) { - strictcount(inferCertainTypeAdj(n, path)) > 1 and - t = inferCertainTypeAdj(n, path) and + strictcount(inferTypeCertainAdj(n, path)) > 1 and + t = inferTypeCertainAdj(n, path) and // Suppress the inconsistency if `n` is a self parameter and the type // mention for the self type has multiple types for a path. not exists(ImplItemNode impl, TypePath selfTypePath | @@ -922,8 +915,8 @@ private predicate bodyReturns(Expr body, Expr e) { ) } -private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { - inferType(n, path) = TUnknownType() and +pragma[nomagic] +private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePath prefix) { // Normally, these are coercion sites, but in case a type is unknown we // allow for type information to flow from the type annotation. 
exists(TypeMention tm | result = tm.getTypeAt(path) | @@ -932,6 +925,14 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr() or tm = getReturnTypeMention(any(Function f | n = f.getBody())) + ) and + prefix = path.getAPrefix() +} + +private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { + exists(TypePath prefix | + result = inferUnknownTypeFromAnnotationCand(n, path, prefix) and + hasUnknownTypeAt(n, prefix) ) } @@ -3019,7 +3020,7 @@ private module ConstructionMatchingInput implements MatchingInputSig { or exists(TypePath suffix | suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and - result = CertainTypeInference::inferCertainType(this, suffix) + result = inferTypeCertain(this, suffix) ) } @@ -3628,6 +3629,15 @@ private Type inferForLoopExprType(AstNode n, TypePath path) { ) } +pragma[nomagic] +private Type inferClosureExprBodyTypeCand(AstNode n, TypePath path, TypePath prefix) { + exists(ClosureExpr ce | + n = ce.getClosureBody() and + result = inferType(ce, closureReturnPath().appendInverse(path)) and + prefix = path.getAPrefix() + ) +} + pragma[nomagic] private Type inferClosureExprType(AstNode n, TypePath path) { exists(ClosureExpr ce | @@ -3650,6 +3660,11 @@ private Type inferClosureExprType(AstNode n, TypePath path) { path.isEmpty() ) ) + or + exists(TypePath prefix | + result = inferClosureExprBodyTypeCand(n, path, prefix) and + hasUnknownTypeAt(n, prefix) + ) } pragma[nomagic] @@ -3716,7 +3731,7 @@ private module Debug { exists(string filepath, int startline, int startcolumn, int endline, int endcolumn | result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and filepath.matches("%/main.rs") and - startline = 103 + startline = 1102 ) } @@ -3756,11 +3771,10 @@ private module Debug { tm.getTypeAt(path) = type } - Type debugInferAnnotatedType(AstNode n, TypePath path) { - n = getRelevantLocatable() 
and - result = CertainTypeInference::inferAnnotatedType(n, path) - } - + // Type debugInferAnnotatedType(AstNode n, TypePath path) { + // n = getRelevantLocatable() and + // result = inferAnnotatedType(n, path) + // } pragma[nomagic] private int countTypesAtPath(AstNode n, TypePath path, Type t) { t = inferType(n, path) and @@ -3809,9 +3823,9 @@ private module Debug { c = max(countTypePaths(_, _, _)) } - Type debugInferCertainType(AstNode n, TypePath path) { + Type debuginferTypeCertain(AstNode n, TypePath path) { n = getRelevantLocatable() and - result = CertainTypeInference::inferCertainType(n, path) + result = inferTypeCertain(n, path) } Type debugInferCertainNonUniqueType(AstNode n, TypePath path) { diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index f30a2c064d12..9075fc17f3eb 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -8961,10 +8961,8 @@ inferType | main.rs:826:16:826:16 | 3 | | {EXTERNAL LOCATION} | i32 | | main.rs:826:16:826:20 | ... > ... | | {EXTERNAL LOCATION} | bool | | main.rs:826:20:826:20 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:826:22:828:13 | { ... } | | main.rs:820:20:820:22 | Tr2 | | main.rs:827:17:827:20 | self | | {EXTERNAL LOCATION} | & | | main.rs:827:17:827:20 | self | TRef | main.rs:820:5:832:5 | Self [trait MyTrait2] | -| main.rs:827:17:827:25 | self.m1() | | main.rs:820:20:820:22 | Tr2 | | main.rs:828:20:830:13 | { ... } | | main.rs:820:20:820:22 | Tr2 | | main.rs:829:17:829:31 | ...::m1(...) | | main.rs:820:20:820:22 | Tr2 | | main.rs:829:26:829:30 | * ... | | main.rs:820:5:832:5 | Self [trait MyTrait2] | @@ -11482,13 +11480,9 @@ inferType | main.rs:2099:13:2103:13 | if value {...} else {...} | | {EXTERNAL LOCATION} | i64 | | main.rs:2099:16:2099:20 | value | | {EXTERNAL LOCATION} | bool | | main.rs:2099:22:2101:13 | { ... 
} | | {EXTERNAL LOCATION} | i32 | -| main.rs:2099:22:2101:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2100:17:2100:17 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2100:17:2100:17 | 1 | | {EXTERNAL LOCATION} | i64 | | main.rs:2101:20:2103:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2101:20:2103:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2102:17:2102:17 | 0 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2102:17:2102:17 | 0 | | {EXTERNAL LOCATION} | i64 | | main.rs:2113:19:2113:22 | SelfParam | | main.rs:2107:5:2107:19 | S | | main.rs:2113:19:2113:22 | SelfParam | T | main.rs:2109:10:2109:17 | T | | main.rs:2113:25:2113:29 | other | | main.rs:2107:5:2107:19 | S | @@ -11543,13 +11537,9 @@ inferType | main.rs:2154:13:2158:13 | if value {...} else {...} | | {EXTERNAL LOCATION} | i64 | | main.rs:2154:16:2154:20 | value | | {EXTERNAL LOCATION} | bool | | main.rs:2154:22:2156:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2154:22:2156:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2155:17:2155:17 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2155:17:2155:17 | 1 | | {EXTERNAL LOCATION} | i64 | | main.rs:2156:20:2158:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2156:20:2158:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2157:17:2157:17 | 0 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2157:17:2157:17 | 0 | | {EXTERNAL LOCATION} | i64 | | main.rs:2164:21:2164:25 | value | | main.rs:2162:19:2162:19 | T | | main.rs:2164:31:2164:31 | x | | main.rs:2162:5:2165:5 | Self [trait MyFrom2] | | main.rs:2169:21:2169:25 | value | | {EXTERNAL LOCATION} | i64 | @@ -11715,9 +11705,7 @@ inferType | main.rs:2265:21:2265:31 | [...] 
| TArray | {EXTERNAL LOCATION} | u8 | | main.rs:2265:22:2265:24 | 1u8 | | {EXTERNAL LOCATION} | u8 | | main.rs:2265:27:2265:27 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2265:27:2265:27 | 2 | | {EXTERNAL LOCATION} | u8 | | main.rs:2265:30:2265:30 | 3 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2265:30:2265:30 | 3 | | {EXTERNAL LOCATION} | u8 | | main.rs:2266:9:2266:25 | for ... in ... { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2266:13:2266:13 | u | | {EXTERNAL LOCATION} | i32 | | main.rs:2266:13:2266:13 | u | | {EXTERNAL LOCATION} | u8 | @@ -11743,11 +11731,8 @@ inferType | main.rs:2271:31:2271:39 | [...] | TArray | {EXTERNAL LOCATION} | i32 | | main.rs:2271:31:2271:39 | [...] | TArray | {EXTERNAL LOCATION} | u32 | | main.rs:2271:32:2271:32 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:32:2271:32 | 1 | | {EXTERNAL LOCATION} | u32 | | main.rs:2271:35:2271:35 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:35:2271:35 | 2 | | {EXTERNAL LOCATION} | u32 | | main.rs:2271:38:2271:38 | 3 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:38:2271:38 | 3 | | {EXTERNAL LOCATION} | u32 | | main.rs:2272:9:2272:25 | for ... in ... { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2272:13:2272:13 | u | | {EXTERNAL LOCATION} | u32 | | main.rs:2272:18:2272:22 | vals3 | | {EXTERNAL LOCATION} | [;] | @@ -11888,7 +11873,6 @@ inferType | main.rs:2308:19:2308:25 | 0u8..10 | Idx | {EXTERNAL LOCATION} | i32 | | main.rs:2308:19:2308:25 | 0u8..10 | Idx | {EXTERNAL LOCATION} | u8 | | main.rs:2308:24:2308:25 | 10 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2308:24:2308:25 | 10 | | {EXTERNAL LOCATION} | u8 | | main.rs:2308:28:2308:29 | { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2309:13:2309:17 | range | | {EXTERNAL LOCATION} | Range | | main.rs:2309:13:2309:17 | range | Idx | {EXTERNAL LOCATION} | i32 | @@ -12641,11 +12625,9 @@ inferType | main.rs:2583:12:2583:12 | b | | {EXTERNAL LOCATION} | bool | | main.rs:2583:14:2586:9 | { ... 
} | | {EXTERNAL LOCATION} | Box | | main.rs:2583:14:2586:9 | { ... } | A | {EXTERNAL LOCATION} | Global | -| main.rs:2583:14:2586:9 | { ... } | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2583:14:2586:9 | { ... } | T | main.rs:2551:5:2552:19 | S | | main.rs:2583:14:2586:9 | { ... } | T.T | main.rs:2551:5:2552:19 | S | | main.rs:2583:14:2586:9 | { ... } | T.T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2583:14:2586:9 | { ... } | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2584:17:2584:17 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2584:17:2584:17 | x | T | main.rs:2551:5:2552:19 | S | | main.rs:2584:17:2584:17 | x | T.T | {EXTERNAL LOCATION} | i32 | @@ -12656,26 +12638,20 @@ inferType | main.rs:2584:21:2584:26 | x.m2() | T.T | {EXTERNAL LOCATION} | i32 | | main.rs:2585:13:2585:23 | ...::new(...) | | {EXTERNAL LOCATION} | Box | | main.rs:2585:13:2585:23 | ...::new(...) | A | {EXTERNAL LOCATION} | Global | -| main.rs:2585:13:2585:23 | ...::new(...) | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2585:13:2585:23 | ...::new(...) | T | main.rs:2551:5:2552:19 | S | | main.rs:2585:13:2585:23 | ...::new(...) | T.T | main.rs:2551:5:2552:19 | S | | main.rs:2585:13:2585:23 | ...::new(...) | T.T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2585:13:2585:23 | ...::new(...) | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2585:22:2585:22 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2585:22:2585:22 | x | T | main.rs:2551:5:2552:19 | S | | main.rs:2585:22:2585:22 | x | T.T | {EXTERNAL LOCATION} | i32 | | main.rs:2586:16:2588:9 | { ... } | | {EXTERNAL LOCATION} | Box | | main.rs:2586:16:2588:9 | { ... } | A | {EXTERNAL LOCATION} | Global | -| main.rs:2586:16:2588:9 | { ... } | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2586:16:2588:9 | { ... } | T | main.rs:2551:5:2552:19 | S | | main.rs:2586:16:2588:9 | { ... } | T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2586:16:2588:9 | { ... 
} | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2587:13:2587:23 | ...::new(...) | | {EXTERNAL LOCATION} | Box | | main.rs:2587:13:2587:23 | ...::new(...) | A | {EXTERNAL LOCATION} | Global | -| main.rs:2587:13:2587:23 | ...::new(...) | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2587:13:2587:23 | ...::new(...) | T | main.rs:2551:5:2552:19 | S | | main.rs:2587:13:2587:23 | ...::new(...) | T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2587:13:2587:23 | ...::new(...) | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2587:22:2587:22 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2587:22:2587:22 | x | T | {EXTERNAL LOCATION} | i32 | | main.rs:2593:22:2597:5 | { ... } | | {EXTERNAL LOCATION} | () | diff --git a/rust/ql/test/library-tests/type-inference/type-inference.ql b/rust/ql/test/library-tests/type-inference/type-inference.ql index 8dcc34ad8001..374884ec4574 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.ql +++ b/rust/ql/test/library-tests/type-inference/type-inference.ql @@ -12,7 +12,7 @@ private predicate relevantNode(AstNode n) { } query predicate inferCertainType(AstNode n, TypePath path, Type t) { - t = TypeInference::CertainTypeInference::inferCertainType(n, path) and + t = TypeInference::inferTypeCertain(n, path) and t != TUnknownType() and relevantNode(n) } @@ -70,7 +70,7 @@ module TypeTest implements TestSig { ( tag = "type" or - t = TypeInference::CertainTypeInference::inferCertainType(n, path) and + t = TypeInference::inferTypeCertain(n, path) and tag = "certainType" ) and location = n.getLocation() and diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index c3adee01c7f8..0b10f4b1d98c 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2177,9 +2177,9 @@ module Make1 Input1> { class Expr extends AstNode; /** - * 
A switch expression. + * A switch. */ - class SwitchExpr extends Expr { + class Switch extends AstNode { /** * Gets the expression being switched on. */ @@ -2189,13 +2189,13 @@ module Make1 Input1> { Case getCase(int index); } - /** A case in a switch expression. */ + /** A case in a switch. */ class Case extends AstNode { /** Gets a pattern being matched by this case. */ AstNode getAPattern(); /** Gets the body of this case. */ - Expr getBody(); + AstNode getBody(); } /** A ternary conditional expression. */ @@ -2336,25 +2336,41 @@ module Make1 Input1> { Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path); /** - * Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal. + * Holds if `n1` having certain type `t` at `path1` implies that `n2` has + * certain type `t` at `path2`, but not necessarily the other way around. */ - predicate certainTypeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + default predicate inferStepCertain(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + none() + } + + /** + * Holds if `n1` having certain type `t` at `path1` implies that `n2` has + * certain type `t` at `path2`, and vice versa. + */ + default predicate inferStepSymmetricCertain( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { + none() + } - /** Gets the inferred certain type of `n` at `path`. */ - Type inferCertainTypeInput(AstNode n, TypePath path); + /** + * Gets the inferred certain type of `n` at `path`. + * + * This predicate will be included directly in the exposed `inferTypeCertain` predicate. + */ + default Type inferTypeCertainInput(AstNode n, TypePath path) { none() } /** - * Holds if the types of `n1` at `path1` and `n2` at `path2` are possibly equal, - * and type information should be allowed to flow in both directions between them. + * Holds if `n1` having type `t` at `path1` implies that `n2` has type `t` at `path2`, + * but not necessarily the other way around. 
*/ - predicate typeEqualityInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + predicate inferStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2); /** - * Holds if the type tree of `n1` at `path1` should be equal to the type tree - * of `n2` at `path2`, but type information should only propagate from `n1` to - * `n2`. + * Holds if `n1` having type `t` at `path1` implies that `n2` has type `t` at `path2`, + * and vice versa. */ - predicate typeEqualityAsymmetricInput(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + predicate inferStepSymmetric(AstNode n1, TypePath path1, AstNode n2, TypePath path2); /** * Holds if `child` is a child of `parent` and the type of `parent` at `prefix` can be @@ -2362,9 +2378,15 @@ module Make1 Input1> { * * When `child` is unique, we also allow type information to flow from `parent` to `child`. */ - predicate parentChildType(AstNode parent, AstNode child, TypePath prefix); + default predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + none() + } - /** Gets the inferred type of `n` at `path`. */ + /** + * Gets the inferred type of `n` at `path`. + * + * This predicate will be included directly in the exposed `inferType` predicate. + */ Type inferTypeInput(AstNode n, TypePath path); } @@ -2372,14 +2394,16 @@ module Make1 Input1> { private import Input3 /** Provides logic for inferring certain type information. */ - module CertainTypeInference { + private module Certain { /** Gets the type of `n`, which has an explicit type annotation. 
*/ pragma[nomagic] Type inferAnnotatedType(AstNode n, TypePath path) { result = getTypeAnnotation(n).getTypeAt(path) } - predicate certainTypeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + private predicate stepSymmetricCertain( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { path1.isEmpty() and path2.isEmpty() and ( @@ -2394,26 +2418,30 @@ module Make1 Input1> { n1 = n2.(ParenExpr).getExpr() ) or - certainTypeEqualityInput(n1, path1, n2, path2) + inferStepSymmetricCertain(n1, path1, n2, path2) + } + + predicate stepCertain(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + stepSymmetricCertain(n1, path1, n2, path2) + or + stepSymmetricCertain(n2, path2, n1, path1) + or + inferStepCertain(n1, path1, n2, path2) } pragma[nomagic] - private Type inferCertainTypeEquality(AstNode n, TypePath path) { + private Type inferTypeFromStepCertain(AstNode n, TypePath path) { exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferCertainType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - certainTypeEquality(n, prefix1, n2, prefix2) - or - certainTypeEquality(n2, prefix2, n, prefix1) + result = inferTypeCertain(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) and + stepCertain(n2, prefix2, n, prefix1) ) } private Type inferLogicalOperationType(AstNode n, TypePath path) { ( exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or - exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) //or - // exists(LogicalNotExpr lne | n = [lne, lne.getOperand()]) + exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) ) and result instanceof BoolType and path.isEmpty() @@ -2450,13 +2478,13 @@ module Make1 Input1> { /** Gets the inferred certain type of `n` at `path`. 
*/ cached - Type inferCertainType(AstNode n, TypePath path) { + Type inferTypeCertain(AstNode n, TypePath path) { CachedStage::ref() and result = inferAnnotatedType(n, path) or - result = inferCertainTypeEquality(n, path) + result = inferTypeFromStepCertain(n, path) or - result = inferCertainTypeInput(n, path) + result = inferTypeCertainInput(n, path) or result = inferLogicalOperationType(n, path) or @@ -2473,7 +2501,7 @@ module Make1 Input1> { pragma[nomagic] private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { exists(TypePath path | - exists(inferCertainType(n, path)) and + exists(inferTypeCertain(n, path)) and path.isSnoc(prefix, tp) ) } @@ -2483,7 +2511,7 @@ module Make1 Input1> { */ pragma[nomagic] predicate hasInferredCertainType(AstNode n, TypePath path) { - exists(inferCertainType(n, path)) + exists(inferTypeCertain(n, path)) } /** @@ -2493,7 +2521,7 @@ module Make1 Input1> { bindingset[n, prefix, path, t] pragma[inline_late] predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { - inferCertainType(n, path) != t + inferTypeCertain(n, path) != t or // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, @@ -2504,15 +2532,27 @@ module Make1 Input1> { exists(TypePath suffix, TypeParameter tp, Type certainType | path = prefix.appendInverse(suffix) and tp = suffix.getHead() and - inferCertainType(n, prefix) = certainType and + inferTypeCertain(n, prefix) = certainType and not certainType.getATypeParameter() = tp ) } } - private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { - CertainTypeInference::certainTypeEquality(n1, path1, n2, path2) + predicate inferTypeCertain = Certain::inferTypeCertain/2; + + private predicate lubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + path2.isEmpty() and + ( + n1 = n2.(Switch).getCase(_).getBody() + or 
+ n2 = any(ConditionalExpr ce | n1 = [ce.getThen(), ce.getElse()]) + ) or + inferLubStep(n1, path1, n2, path2) + } + + private predicate stepSymmetric(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { path1.isEmpty() and path2.isEmpty() and ( @@ -2525,50 +2565,76 @@ module Make1 Input1> { let.getLeftOperand() = n1 and let.getRightOperand() = n2 ) + or + exists(Switch switch | + n1 = switch.getExpr() and + n2 = switch.getCase(_).getAPattern() + ) ) or - typeEqualityInput(n1, path1, n2, path2) + inferStepSymmetric(n1, path1, n2, path2) + // or + // n2 = unique(AstNode child | parentChildType(n1, child, path1) | child) and + // path2.isEmpty() + } + + // private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { + // parentChildType(parent, child, prefix) and + // strictcount(AstNode child0 | parentChildType(parent, child0, prefix) | child0) > 1 + // or + // inferStep(child, TypePath::nil(), parent, prefix) + // } + private predicate step(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + // lubCoercion(n2, n1, path2) and + // path1.isEmpty() + // or + // exists(AstNode mid, TypePath pathMid, TypePath suffix | + // typeEquality(n1, pathMid, mid, path2) or + // typeEquality(mid, path2, n1, pathMid) + // | + // lubCoercion(mid, n2, suffix) and + // not lubCoercion(mid, n1, _) and + // path1 = pathMid.append(suffix) + // ) + // or + inferStep(n1, path1, n2, path2) or - n2 = unique(AstNode child | parentChildType(n1, child, path1) | child) and - path2.isEmpty() - } - - private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - parentChildType(parent, child, prefix) and - strictcount(AstNode child0 | parentChildType(parent, child0, prefix) | child0) > 1 + stepSymmetric(n1, path1, n2, path2) or - typeEqualityAsymmetricInput(child, TypePath::nil(), parent, prefix) + stepSymmetric(n2, path2, n1, path1) + or + Certain::stepCertain(n1, path1, n2, path2) + or + lubStep(n1, path1, n2, path2) + or + n2 = unique(AstNode n 
| lubStep(n, _, n1, _) | n) and + lubStep(n2, path2, n1, path1) } - private predicate typeEqualityAsymmetric( - AstNode n1, TypePath path1, AstNode n2, TypePath path2 - ) { - lubCoercion(n2, n1, path2) and - path1.isEmpty() - or - exists(AstNode mid, TypePath pathMid, TypePath suffix | - typeEquality(n1, pathMid, mid, path2) or - typeEquality(mid, path2, n1, pathMid) - | - lubCoercion(mid, n2, suffix) and - not lubCoercion(mid, n1, _) and - path1 = pathMid.append(suffix) + pragma[nomagic] + private Type inferTypeFromStep(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) and + step(n2, prefix2, n, prefix1) ) - or - typeEqualityAsymmetricInput(n1, path1, n2, path2) } pragma[nomagic] - private Type inferTypeEquality(AstNode n, TypePath path) { + private Type inferTypeFromReverseLubStepCand(AstNode n, TypePath path, TypePath prefix) { exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | result = inferType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - typeEquality(n, prefix1, n2, prefix2) - or - typeEquality(n2, prefix2, n, prefix1) - or - typeEqualityAsymmetric(n2, prefix2, n, prefix1) + path = prefix1.append(suffix) and + lubStep(n, prefix1, n2, prefix2) and + prefix = path.getAPrefix() + ) + } + + pragma[nomagic] + private Type inferTypeFromReverseLub(AstNode n, TypePath path) { + exists(TypePath prefix | + result = inferTypeFromReverseLubStepCand(n, path, prefix) and + hasUnknownTypeAt(n, prefix) ) } @@ -2578,18 +2644,20 @@ module Make1 Input1> { cached Type inferType(AstNode n, TypePath path) { CachedStage::ref() and - result = CertainTypeInference::inferCertainType(n, path) + result = inferTypeCertain(n, path) or // Don't propagate type information into a node which conflicts with certain // type information. 
forall(TypePath prefix | - CertainTypeInference::hasInferredCertainType(n, prefix) and + Certain::hasInferredCertainType(n, prefix) and prefix.isPrefixOf(path) | - not CertainTypeInference::certainTypeConflict(n, prefix, path, result) + not Certain::certainTypeConflict(n, prefix, path, result) ) and ( - result = inferTypeEquality(n, path) + result = inferTypeFromStep(n, path) + or + result = inferTypeFromReverseLub(n, path) or result = CheckContextTyping::check(n, path) or @@ -2646,7 +2714,7 @@ module Make1 Input1> { } pragma[nomagic] - private predicate hasUnknownTypeAt(AstNode n, TypePath path) { + predicate hasUnknownTypeAt(AstNode n, TypePath path) { inferType(n, path) instanceof UnknownType } @@ -2699,7 +2767,7 @@ module Make1 Input1> { /** Reverse references to the predicates that reference `ref()`. */ cached predicate revRef() { - (exists(CertainTypeInference::inferCertainType(_, _)) implies any()) + (exists(inferTypeCertain(_, _)) implies any()) or (exists(inferType(_, _)) implies any()) or From 69e7f6fbce14b87acecc3670cb0c44cfe54fe459 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Thu, 7 May 2026 11:59:02 +0200 Subject: [PATCH 10/12] wip10 --- .../internal/typeinference/TypeInference.qll | 25 ++--- .../typeinference/internal/TypeInference.qll | 98 ++++++++++++++----- 2 files changed, 80 insertions(+), 43 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 02b98f9990a3..ac9adc1f96ab 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -916,7 +916,7 @@ private predicate bodyReturns(Expr body, Expr e) { } pragma[nomagic] -private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePath prefix) { +private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) { // Normally, these are coercion sites, but in case a type is 
unknown we // allow for type information to flow from the type annotation. exists(TypeMention tm | result = tm.getTypeAt(path) | @@ -925,17 +925,12 @@ private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePa tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr() or tm = getReturnTypeMention(any(Function f | n = f.getBody())) - ) and - prefix = path.getAPrefix() -} - -private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { - exists(TypePath prefix | - result = inferUnknownTypeFromAnnotationCand(n, path, prefix) and - hasUnknownTypeAt(n, prefix) ) } +private predicate inferUnknownTypeFromAnnotation = + TopDownTyping::inferType/2; + pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -2819,7 +2814,7 @@ private Type inferFunctionCallType0( call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and result = TUnknownType() or - result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0) + result = inferCallTypeOut(call, derefChainBorrow, pos, n, path0) ) | if @@ -3630,11 +3625,10 @@ private Type inferForLoopExprType(AstNode n, TypePath path) { } pragma[nomagic] -private Type inferClosureExprBodyTypeCand(AstNode n, TypePath path, TypePath prefix) { +private Type inferClosureExprBodyTypeTopDown(AstNode n, TypePath path) { exists(ClosureExpr ce | n = ce.getClosureBody() and - result = inferType(ce, closureReturnPath().appendInverse(path)) and - prefix = path.getAPrefix() + result = inferType(ce, closureReturnPath().appendInverse(path)) ) } @@ -3661,10 +3655,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) { ) ) or - exists(TypePath prefix | - result = inferClosureExprBodyTypeCand(n, path, prefix) and - hasUnknownTypeAt(n, prefix) - ) + result = TopDownTyping::inferType(n, path) } pragma[nomagic] diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll 
b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 0b10f4b1d98c..b8a1e1a806ec 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -147,7 +147,7 @@ signature module InputSig1 { /** * A special pseudo type used to represent cases where the actual type needs - * to be inferred from the context. For example, in + * to be inferred from the context in a top-down manner. For example, in * * ```rust * let x = Vec::new(); @@ -2146,13 +2146,14 @@ module Make1 Input1> { */ signature module InputSig3 { /** - * References to cached predicates that should be included to the cached - * stage of type inference. Such predicates should reference `CachedStage::ref`. + * A predicate used to reference cached predicates that should be included to the + * cached stage of type inference. Such predicates should themselves reference + * `CachedStage::ref`. */ default predicate cachedStageRevRef() { none() } /** - * Point this predicate to the `inferType` predicate in the output of this module. + * Point this predicate to the `inferType` predicate from the output of this module. * * Needed to be able to refer to `inferType` in default signature implementations. */ @@ -2278,7 +2279,8 @@ module Make1 Input1> { } /** - * A position where a callable can have a declared type. + * A position where a callable can have a declared type and a call can have + * an inferred type. */ class TypePosition { /** Holds if this position represents the return type of a callable. */ @@ -2288,7 +2290,14 @@ module Make1 Input1> { string toString(); } - /** A context needed to resolve calls. */ + /** + * A context needed to resolve calls. + * + * For example, in Rust, we need an additional context to represent the + * candidate receiver type when resolving method calls. + * + * When not used, simply instantiate this class with `Unit`. 
+ */ bindingset[this] class CallResolutionContext { /** Gets a textual representation of this context. */ @@ -2298,11 +2307,18 @@ module Make1 Input1> { /** A callable. */ class Callable { + /** Gets the type parameter at position `ppos` of this callable, if any. */ TypeParameter getTypeParameter(TypeParameterPosition ppos); + /** + * Gets an additional type parameter constraint for the given type parameter, + * which applies to this callable. For example, in Rust, a function can apply + * additional constraints on type parameters belonging to the `impl` block + * that the function is defined in. + */ TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp); - /* Gets the declared type of this callable at `path` for position `pos`. */ + /** Gets the declared type of this callable at `path` for position `pos`. */ Type getDeclaredType(TypePosition pos, TypePath path); /** Gets a textual representation of this callable. */ @@ -2312,19 +2328,31 @@ module Make1 Input1> { Location getLocation(); } + /** A call expression. */ class Call extends Expr { + /** Gets the explicit type argument at position `apos` and `path` for this call, if any. */ Type getTypeArgument(TypeArgumentPosition apos, TypePath path); + /** Gets the AST node corresponding to the position `pos` of this call. */ AstNode getNodeAt(TypePosition pos); - /** Gets the target of this call. */ + /** + * Gets the target of this call, to be used when inferring certain types. + */ Callable getTargetCertain(); - /** Gets the target of this call. */ + /** Gets the target of this call in the given context. */ Callable getTarget(CallResolutionContext ctx); } - /** Gets the inferred type `call` at `path` for position `pos` in context `ctx`. */ + /** + * Gets the inferred type of `call` at `path` and position `pos` in context `ctx`. 
+ * + * By default, this is the inferred type of the node at the given position, but + * in for example Rust, the inferred type of the receiver of a method call needs + * to take the call context into account, in order to use the correct candidate + * receiver type. + */ bindingset[ctx] default Type inferCallTypeIn( Call call, CallResolutionContext ctx, TypePosition pos, TypePath path @@ -2556,9 +2584,9 @@ module Make1 Input1> { path1.isEmpty() and path2.isEmpty() and ( - exists(Assignment a | - a.getLeftOperand() = n1 and - a.getRightOperand() = n2 + exists(AssignExpr ae | + ae.getLeftOperand() = n1 and + ae.getRightOperand() = n2 ) or exists(LetDeclaration let | @@ -2621,22 +2649,16 @@ module Make1 Input1> { } pragma[nomagic] - private Type inferTypeFromReverseLubStepCand(AstNode n, TypePath path, TypePath prefix) { + private Type inferTypeFromLubStepTopDown(AstNode n, TypePath path) { exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | result = inferType(n2, prefix2.appendInverse(suffix)) and path = prefix1.append(suffix) and - lubStep(n, prefix1, n2, prefix2) and - prefix = path.getAPrefix() + lubStep(n, prefix1, n2, prefix2) ) } - pragma[nomagic] - private Type inferTypeFromReverseLub(AstNode n, TypePath path) { - exists(TypePath prefix | - result = inferTypeFromReverseLubStepCand(n, path, prefix) and - hasUnknownTypeAt(n, prefix) - ) - } + private predicate inferTypeFromReverseLub = + TopDownTyping::inferType/2; /** * Gets the inferred type of `n` at `path`. 
@@ -2705,22 +2727,46 @@ module Make1 Input1> { private module CallMatching = MatchingWithEnvironment; - pragma[nomagic] Type inferCallTypeOut( - Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path + Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path ) { n = call.getNodeAt(pos) and result = CallMatching::inferAccessType(call, ctx, pos, path) } pragma[nomagic] - predicate hasUnknownTypeAt(AstNode n, TypePath path) { + private predicate hasUnknownTypeAt(AstNode n, TypePath path) { inferType(n, path) instanceof UnknownType } pragma[nomagic] private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } + signature Type inferTypeTopDownSig(AstNode n, TypePath path); + + /** + * Given a predicate `inferTypeTopDown` for inferring the type of an AST node `n` + * top-down from a context, this module exposes the predicate `inferType`, which + * restricts type information to only flow top-down into `n` when `n` has an + * explicit unknown type. 
+ */ + module TopDownTyping { + pragma[nomagic] + private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) { + result = inferTypeTopDown(n, path) and + hasUnknownType(n) and + prefix = path.getAPrefix() + } + + pragma[nomagic] + Type inferType(AstNode n, TypePath path) { + exists(TypePath prefix | + result = inferTypeTopDown(n, prefix, path) and + hasUnknownTypeAt(n, prefix) + ) + } + } + signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path); /** From 43e54530807d1d8767b939f116f1cb18ddfcc299 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Thu, 7 May 2026 15:18:43 +0200 Subject: [PATCH 11/12] wip11 --- .../internal/typeinference/TypeInference.qll | 159 ++++++++---------- .../PathResolutionConsistency.expected | 2 + .../type-inference/type-inference.expected | 5 - .../typeinference/internal/TypeInference.qll | 116 +++++++++---- 4 files changed, 155 insertions(+), 127 deletions(-) create mode 100644 rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index ac9adc1f96ab..910e988d9be4 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -415,15 +415,30 @@ private module Input3 implements InputSig3 { } bindingset[derefChainBorrow] - Type inferCallTypeIn(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) { + Type inferCallTypeBottomUp(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) { result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path) } - Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path) { - result = inferFunctionCallTypeNonSelf(n, pos, path) + Type inferCallReturnType(AstNode n, TypePath path) { + exists(Call call, TypePath path0 | + result = 
inferCallReturnType(call, _, n, path0) and + if + // index expression `x[i]` desugars to `*x.index(i)`, so we must account for + // the implicit deref + call instanceof IndexExpr + then path0.isCons(getRefTypeParameter(_), path) + else path = path0 + ) + } + + Type inferCallArgumentTypeTopDown(AstNode n, TypePath path) { + exists(FunctionCallMatchingInput::Access call, FunctionPosition pos | + result = inferCallArgumentTypeTopDown(call, pos, n, _, _, path) and + not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) + ) or exists(FunctionCallMatchingInput::Access a | - result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and + result = inferFunctionCallSelfArgumentTypeTopDown(a, n, DerefChain::nil(), path) and if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() then not path.isEmpty() else any() @@ -459,7 +474,7 @@ private module Input3 implements InputSig3 { ) } - Type inferTypeCertainInput(AstNode n, TypePath path) { + Type inferTypeCertainSpecific(AstNode n, TypePath path) { result = inferFunctionBodyType(n, path) or result = inferLiteralType(n, path, true) @@ -580,36 +595,31 @@ private module Input3 implements InputSig3 { prefix1.isEmpty() } - /** - * Holds if `child` is a child of `parent`, and the Rust compiler applies [least - * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of - * `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. 
- * - * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound - */ - predicate inferLubStep(AstNode child, TypePath path1, AstNode parent, TypePath prefix) { + predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { path1.isEmpty() and ( - parent = any(ArrayListExpr ale | child = ale.getAnExpr()) and - prefix = TypePath::singleton(getArrayTypeParameter()) + n2 = any(ArrayListExpr ale | n1 = ale.getAnExpr()) and + path2 = TypePath::singleton(getArrayTypeParameter()) or - bodyReturns(parent, child) and - prefix.isEmpty() + bodyReturns(n2, n1) and + path2.isEmpty() or exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = + n1 = [n2.(RangeExpr).getStart(), n2.(RangeExpr).getEnd()] and + path2 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) + s = getRangeType(n2) ) ) } - Type inferTypeInput(AstNode n, TypePath path) { + Type inferTypeTopDown(AstNode n, TypePath path) { + result = inferTypeFromAnnotationTopDown(n, path) + or + result = inferClosureExprBodyTypeTopDown(n, path) + } + + Type inferTypeSpecific(AstNode n, TypePath path) { result = inferAssignmentOperationType(n, path) or result = inferConstructionType(n, path) @@ -634,7 +644,7 @@ private module Input3 implements InputSig3 { or result = inferDeconstructionPatType(n, path) or - result = inferUnknownTypeFromAnnotation(n, path) + result = inferUnknownType(n, path) } } @@ -928,9 +938,6 @@ private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) { ) } -private predicate inferUnknownTypeFromAnnotation = - TopDownTyping::inferType/2; - pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -2802,36 +2809,13 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } pragma[nomagic] -private Type inferFunctionCallType0( +private Type 
inferCallArgumentTypeTopDown( FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path ) { - exists(TypePath path0 | - exists(string derefChainBorrow | - FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) - | - n = call.getNodeAt(pos) and - call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and - result = TUnknownType() - or - result = inferCallTypeOut(call, derefChainBorrow, pos, n, path0) - ) - | - if - // index expression `x[i]` desugars to `*x.index(i)`, so we must account for - // the implicit deref - pos.isReturn() and - call instanceof IndexExpr - then path0.isCons(getRefTypeParameter(_), path) - else path = path0 - ) -} - -pragma[nomagic] -private Type inferFunctionCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePath path) { - exists(FunctionCallMatchingInput::Access call | - result = inferFunctionCallType0(call, pos, n, _, _, path) and - not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) + exists(string derefChainBorrow | + FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) and + result = inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path) ) } @@ -2843,12 +2827,12 @@ private Type inferFunctionCallTypeNonSelf(AstNode n, FunctionPosition pos, TypeP * empty, at which point the inferred type can be applied back to `n`. 
*/ pragma[nomagic] -private Type inferFunctionCallTypeSelf( +private Type inferFunctionCallSelfArgumentTypeTopDown( FunctionCallMatchingInput::Access call, AstNode n, DerefChain derefChain, TypePath path ) { exists(FunctionPosition pos, BorrowKind borrow, TypePath path0 | call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) and - result = inferFunctionCallType0(call, pos, n, derefChain, borrow, path0) + result = inferCallArgumentTypeTopDown(call, pos, n, derefChain, borrow, path0) | borrow.isNoBorrow() and path = path0 @@ -2865,7 +2849,7 @@ private Type inferFunctionCallTypeSelf( DerefChain derefChain0, Type t0, TypePath path0, DerefImplItemNode impl, Type selfParamType, TypePath selfPath | - t0 = inferFunctionCallTypeSelf(call, n, derefChain0, path0) and + t0 = inferFunctionCallSelfArgumentTypeTopDown(call, n, derefChain0, path0) and derefChain0.isCons(impl, derefChain) and selfParamType = impl.resolveSelfTypeAt(selfPath) | @@ -3041,17 +3025,37 @@ private module ConstructionMatching = Matching; pragma[nomagic] private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { - exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) | + exists(ConstructionMatchingInput::Access a | + n = a.getNodeAt(pos) and result = ConstructionMatching::inferAccessType(a, pos, path) - or - a.hasUnknownTypeAt(pos, path) and - result = TUnknownType() ) } private predicate inferConstructionType = CheckContextTyping::check/2; +pragma[nomagic] +private Type inferUnknownType(AstNode n, TypePath path) { + result = TUnknownType() and + ( + exists(FunctionCallMatchingInput::Access call, FunctionPosition pos | + n = call.getNodeAt(pos) and + call.hasUnknownTypeAt(_, pos, path) + ) + or + exists(ConstructionMatchingInput::Access a, FunctionPosition pos | + n = a.getNodeAt(pos) and + a.hasUnknownTypeAt(pos, path) + ) + or + exists(Param p | + not p.hasTypeRepr() and + n = p.getPat() and + path.isEmpty() + ) + ) +} + /** * A matching 
configuration for resolving types of operations like `a + b`. */ @@ -3633,29 +3637,14 @@ private Type inferClosureExprBodyTypeTopDown(AstNode n, TypePath path) { } pragma[nomagic] -private Type inferClosureExprType(AstNode n, TypePath path) { - exists(ClosureExpr ce | - n = ce and - ( - path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and - result.(TupleType).getArity() = ce.getNumberOfParams() - or - exists(TypePath path0 | - result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and - path = closureReturnPath().append(path0) - ) - ) - or - exists(Param p | - p = ce.getAParam() and - not p.hasTypeRepr() and - n = p.getPat() and - result = TUnknownType() and - path.isEmpty() - ) - ) +private Type inferClosureExprType(ClosureExpr ce, TypePath path) { + path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and + result.(TupleType).getArity() = ce.getNumberOfParams() or - result = TopDownTyping::inferType(n, path) + exists(TypePath suffix | + result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(suffix) and + path = closureReturnPath().append(suffix) + ) } pragma[nomagic] diff --git a/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected b/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected new file mode 100644 index 000000000000..cfad81d2796a --- /dev/null +++ b/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected @@ -0,0 +1,2 @@ +multipleResolvedTargets +| main.rs:218:20:218:25 | ... != ... 
| diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 9075fc17f3eb..94b98d92f7da 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -10250,7 +10250,6 @@ inferType | main.rs:1412:17:1412:20 | self | TRef.TSlice | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option | | main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & | -| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & | | main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T | | main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 | @@ -11591,8 +11590,6 @@ inferType | main.rs:2221:18:2221:21 | true | | {EXTERNAL LOCATION} | bool | | main.rs:2223:9:2223:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | -| main.rs:2223:9:2223:15 | S(...) | T | main.rs:2107:5:2107:19 | S | -| main.rs:2223:9:2223:15 | S(...) | T.T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 | | main.rs:2223:9:2223:31 | ... .my_add(...) | T | main.rs:2107:5:2107:19 | S | @@ -11611,8 +11608,6 @@ inferType | main.rs:2224:24:2224:27 | 3i64 | | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:15 | S(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | i64 | -| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | & | -| main.rs:2225:9:2225:15 | S(...) | T.TRef | {EXTERNAL LOCATION} | i64 | | main.rs:2225:9:2225:29 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S | | main.rs:2225:9:2225:29 | ... 
.my_add(...) | T | {EXTERNAL LOCATION} | i64 | | main.rs:2225:11:2225:14 | 1i64 | | {EXTERNAL LOCATION} | i64 | diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index b8a1e1a806ec..c36f63c75fc5 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2352,16 +2352,41 @@ module Make1 Input1> { * in for example Rust, the inferred type of the receiver of a method call needs * to take the call context into account, in order to use the correct candidate * receiver type. + * + * The type information provided by this predicate is used to derive type information + * about the call via the call target, such as the return type. */ bindingset[ctx] - default Type inferCallTypeIn( + default Type inferCallTypeBottomUp( Call call, CallResolutionContext ctx, TypePosition pos, TypePath path ) { result = inferType(call.getNodeAt(pos), path) and exists(ctx) } - Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path); + /** + * Gets the inferred return type of `call` at `path`. + * + * When no post-processing is needed, simply implement this predicate as + * `result = inferCallReturnType(_, _, n, path)`. + */ + Type inferCallReturnType(AstNode n, TypePath path); + + /** + * Gets the top-down inferred type of `call` at `path` and argument position + * `pos`. + * + * This predicate is used to propagate type information from the call target + * into call arguments, for example when an implicitly typed lambda is passed + * as an argument. + * + * Type information is only propagated into arguments with an explicitly unknown + * type. + * + * When no call-context based post-processing is needed, simply implement this + * predicate as `result = inferCallArgumentTypeTopDown(_, _, _, n, path)`. 
+ */ + Type inferCallArgumentTypeTopDown(AstNode n, TypePath path); /** * Holds if `n1` having certain type `t` at `path1` implies that `n2` has @@ -2386,7 +2411,7 @@ module Make1 Input1> { * * This predicate will be included directly in the exposed `inferTypeCertain` predicate. */ - default Type inferTypeCertainInput(AstNode n, TypePath path) { none() } + default Type inferTypeCertainSpecific(AstNode n, TypePath path) { none() } /** * Holds if `n1` having type `t` at `path1` implies that `n2` has type `t` at `path2`, @@ -2401,21 +2426,41 @@ module Make1 Input1> { predicate inferStepSymmetric(AstNode n1, TypePath path1, AstNode n2, TypePath path2); /** - * Holds if `child` is a child of `parent` and the type of `parent` at `prefix` can be - * inferred from the type of `child`. + * Holds if `n1` having type `t` at `path1` implies that `n2` has a type `lub` at + * `path2`, where `lub` is a least-upper-bound of the types of all the nodes that + * have lub steps into `n2`. * - * When `child` is unique, we also allow type information to flow from `parent` to `child`. + * For example, for a ternary conditional expression, there are lub steps from each + * of the branches into the conditional expression itself. + * + * We don't actually model the least-upper-bound computation, instead we interpret + * `inferLubStep(n1, path1, n2, path2)` as + * + * - `inferStep(n1, path1, n2, path2)`, that is type information flows directly into + * the lub, and + * - `inferStep(n2, path2, n1, path1)`, provided that `n1` is unique, that is, type + * type information flows from the lub back into the unique input `n1`, and + * - type information is allowed to flow from the lub into any of its inputs, provided + * that they have an explicitly unknown type. */ default predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { none() } + /** + * Gets the top-down inferred type of `n` at `path`. 
+ * + * Type information is only propagated into nodes with an explicitly unknown + * type. + */ + default Type inferTypeTopDown(AstNode n, TypePath path) { none() } + /** * Gets the inferred type of `n` at `path`. * * This predicate will be included directly in the exposed `inferType` predicate. */ - Type inferTypeInput(AstNode n, TypePath path); + Type inferTypeSpecific(AstNode n, TypePath path); } module Make3 { @@ -2512,7 +2557,7 @@ module Make1 Input1> { or result = inferTypeFromStepCertain(n, path) or - result = inferTypeCertainInput(n, path) + result = inferTypeCertainSpecific(n, path) or result = inferLogicalOperationType(n, path) or @@ -2601,30 +2646,9 @@ module Make1 Input1> { ) or inferStepSymmetric(n1, path1, n2, path2) - // or - // n2 = unique(AstNode child | parentChildType(n1, child, path1) | child) and - // path2.isEmpty() } - // private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - // parentChildType(parent, child, prefix) and - // strictcount(AstNode child0 | parentChildType(parent, child0, prefix) | child0) > 1 - // or - // inferStep(child, TypePath::nil(), parent, prefix) - // } private predicate step(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { - // lubCoercion(n2, n1, path2) and - // path1.isEmpty() - // or - // exists(AstNode mid, TypePath pathMid, TypePath suffix | - // typeEquality(n1, pathMid, mid, path2) or - // typeEquality(mid, path2, n1, pathMid) - // | - // lubCoercion(mid, n2, suffix) and - // not lubCoercion(mid, n1, _) and - // path1 = pathMid.append(suffix) - // ) - // or inferStep(n1, path1, n2, path2) or stepSymmetric(n1, path1, n2, path2) @@ -2681,9 +2705,13 @@ module Make1 Input1> { or result = inferTypeFromReverseLub(n, path) or - result = CheckContextTyping::check(n, path) + result = inferCallReturnType(n, path) or - result = inferTypeInput(n, path) + result = TopDownTyping::inferType(n, path) + or + result = TopDownTyping::inferType(n, path) + or + result = inferTypeSpecific(n, path) 
) } @@ -2720,20 +2748,34 @@ module Make1 Input1> { class Access extends CallFinal { bindingset[e] Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) { - result = inferCallTypeIn(this, e, apos, path) + result = inferCallTypeBottomUp(this, e, apos, path) } } } private module CallMatching = MatchingWithEnvironment; - Type inferCallTypeOut( + private Type inferCallTypeOut( Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path ) { n = call.getNodeAt(pos) and result = CallMatching::inferAccessType(call, ctx, pos, path) } + Type inferCallReturnType(Call call, CallResolutionContext ctx, AstNode n, TypePath path) { + exists(TypePosition pos | + result = inferCallTypeOut(call, ctx, pos, n, path) and + pos.isReturn() + ) + } + + Type inferCallArgumentTypeTopDown( + Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path + ) { + result = inferCallTypeOut(call, ctx, pos, n, path) and + not pos.isReturn() + } + pragma[nomagic] private predicate hasUnknownTypeAt(AstNode n, TypePath path) { inferType(n, path) instanceof UnknownType @@ -2742,18 +2784,18 @@ module Make1 Input1> { pragma[nomagic] private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } - signature Type inferTypeTopDownSig(AstNode n, TypePath path); + private signature Type inferTypeTopDownSig(AstNode n, TypePath path); /** - * Given a predicate `inferTypeTopDown` for inferring the type of an AST node `n` + * Given a predicate `infer` for inferring the type of an AST node `n` * top-down from a context, this module exposes the predicate `inferType`, which * restricts type information to only flow top-down into `n` when `n` has an * explicit unknown type. 
*/ - module TopDownTyping { + private module TopDownTyping { pragma[nomagic] private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) { - result = inferTypeTopDown(n, path) and + result = infer(n, path) and hasUnknownType(n) and prefix = path.getAPrefix() } From ad4a5b7fdc57d4666d265cc8c85749d0cf9722c3 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Thu, 7 May 2026 19:58:20 +0200 Subject: [PATCH 12/12] wip12 --- .../internal/typeinference/TypeInference.qll | 93 ++++++++----------- .../typeinference/internal/TypeInference.qll | 52 ++--------- .../type-inference/type-inference.expected | 11 --- 3 files changed, 51 insertions(+), 105 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 910e988d9be4..d06f9d8aa71d 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -14,7 +14,6 @@ private import FunctionType private import FunctionOverloading as FunctionOverloading private import BlanketImplementation as BlanketImplementation private import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl -private import codeql.rust.internal.CachedStages private import codeql.typeinference.internal.TypeInference private import codeql.rust.frameworks.stdlib.Stdlib private import codeql.rust.frameworks.stdlib.Builtins as Builtins @@ -394,13 +393,7 @@ private module Input3 implements InputSig3 { } } - class Call extends Expr instanceof FunctionCallMatchingInput::Access { - Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { - result = super.getTypeArgument(apos, path) - } - - AstNode getNodeAt(TypePosition pos) { result = super.getNodeAt(pos) } - + class Call extends FunctionCallMatchingInput::Access { /** Gets the target of this call. 
*/ Callable getTargetCertain() { exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p | @@ -421,7 +414,7 @@ private module Input3 implements InputSig3 { Type inferCallReturnType(AstNode n, TypePath path) { exists(Call call, TypePath path0 | - result = inferCallReturnType(call, _, n, path0) and + result = M3::inferCallReturnType(call, _, n, path0) and if // index expression `x[i]` desugars to `*x.index(i)`, so we must account for // the implicit deref @@ -598,11 +591,15 @@ private module Input3 implements InputSig3 { predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { path1.isEmpty() and ( - n2 = any(ArrayListExpr ale | n1 = ale.getAnExpr()) and + n1 = n2.(ArrayListExpr).getAnExpr() and path2 = TypePath::singleton(getArrayTypeParameter()) or - bodyReturns(n2, n1) and - path2.isEmpty() + exists(ReturnExpr re, Rust::Callable c | + n1 = re.getExpr() and + c = re.getEnclosingCallable() and + n2 = c.getBody() and + path2.isEmpty() + ) or exists(Struct s | n1 = [n2.(RangeExpr).getStart(), n2.(RangeExpr).getEnd()] and @@ -617,14 +614,22 @@ private module Input3 implements InputSig3 { result = inferTypeFromAnnotationTopDown(n, path) or result = inferClosureExprBodyTypeTopDown(n, path) + or + exists(FunctionPosition pos | not pos.isReturn() | + result = inferConstructionType(n, pos, path) + or + result = inferOperationType(n, pos, path) + ) } Type inferTypeSpecific(AstNode n, TypePath path) { result = inferAssignmentOperationType(n, path) or - result = inferConstructionType(n, path) - or - result = inferOperationType(n, path) + exists(FunctionPosition pos | pos.isReturn() | + result = inferConstructionType(n, pos, path) + or + result = inferOperationType(n, pos, path) + ) or result = inferFieldExprType(n, path) or @@ -650,7 +655,12 @@ private module Input3 implements InputSig3 { private module M3 = Make3; -import M3 +// import M3 +predicate inferType = M3::inferType/1; + +predicate inferType = M3::inferType/2; + +predicate 
inferTypeCertain = M3::inferTypeCertain/2; module Consistency { import M2::Consistency @@ -917,14 +927,6 @@ private Struct getRangeType(RangeExpr re) { result instanceof RangeToInclusiveStruct } -private predicate bodyReturns(Expr body, Expr e) { - exists(ReturnExpr re, Callable c | - e = re.getExpr() and - c = re.getEnclosingCallable() and - body = c.getBody() - ) -} - pragma[nomagic] private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) { // Normally, these are coercion sites, but in case a type is unknown we @@ -1082,7 +1084,7 @@ private module ContextTyping { * context in which the call appears, for example a call like * `Default::default()`. */ - abstract class ContextTypedCallCand extends AstNode { + abstract class ContextTypedCallCand extends Expr { abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path); predicate hasTypeArgument(TypeArgumentPosition apos) { exists(this.getTypeArgument(apos, _)) } @@ -2653,7 +2655,9 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput ) } - abstract class Access extends ContextTyping::ContextTypedCallCand { + final class Access = AccessImpl; + + abstract private class AccessImpl extends ContextTyping::ContextTypedCallCand { abstract AstNode getNodeAt(FunctionPosition pos); bindingset[derefChainBorrow] @@ -2668,7 +2672,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput abstract predicate hasUnknownTypeAt(string derefChainBorrow, FunctionPosition pos, TypePath path); } - private class AssocFunctionCallAccess extends Access instanceof AssocFunctionResolution::AssocFunctionCall + private class AssocFunctionCallAccess extends AccessImpl instanceof AssocFunctionResolution::AssocFunctionCall { AssocFunctionCallAccess() { // handled in the `OperationMatchingInput` module @@ -2755,7 +2759,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } } - private class NonAssocFunctionCallAccess extends Access 
instanceof NonAssocCallExpr, + private class NonAssocFunctionCallAccess extends AccessImpl instanceof NonAssocCallExpr, CallExprImpl::CallExprCall { pragma[nomagic] @@ -2815,7 +2819,7 @@ private Type inferCallArgumentTypeTopDown( ) { exists(string derefChainBorrow | FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) and - result = inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path) + result = M3::inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path) ) } @@ -3024,16 +3028,13 @@ private module ConstructionMatchingInput implements MatchingInputSig { private module ConstructionMatching = Matching; pragma[nomagic] -private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { +private Type inferConstructionType(AstNode n, FunctionPosition pos, TypePath path) { exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) and result = ConstructionMatching::inferAccessType(a, pos, path) ) } -private predicate inferConstructionType = - CheckContextTyping::check/2; - pragma[nomagic] private Type inferUnknownType(AstNode n, TypePath path) { result = TUnknownType() and @@ -3119,7 +3120,7 @@ private module OperationMatchingInput implements MatchingInputSig { private module OperationMatching = Matching; pragma[nomagic] -private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) { +private Type inferOperationType(AstNode n, FunctionPosition pos, TypePath path) { exists(OperationMatchingInput::Access a | n = a.getNodeAt(pos) and result = OperationMatching::inferAccessType(a, pos, path) and @@ -3127,8 +3128,6 @@ private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePat ) } -private predicate inferOperationType = CheckContextTyping::check/2; - pragma[nomagic] private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) { exists(TypePath path | @@ -3153,6 +3152,7 @@ private Type 
getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC */ cached StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { + M3::CachedStage::ref() and exists(string name, DataType ty | ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) | @@ -3174,6 +3174,7 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref */ cached TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { + M3::CachedStage::ref() and exists(int i | result = getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) @@ -3664,7 +3665,7 @@ private Type inferCastExprType(CastExpr ce, TypePath path) { /** Holds if `n` is implicitly dereferenced and/or borrowed. */ cached predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { - CachedStage::ref() and + M3::CachedStage::ref() and exists(BorrowKind bk | any(AssocFunctionResolution::AssocFunctionCall afc) .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and @@ -3691,6 +3692,7 @@ predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow */ cached Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { + M3::CachedStage::ref() and dispatch = false and result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() or @@ -3711,7 +3713,7 @@ private module Debug { exists(string filepath, int startline, int startcolumn, int endline, int endcolumn | result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and filepath.matches("%/main.rs") and - startline = 1102 + startline = 103 ) } @@ -3737,24 +3739,11 @@ private module Debug { t = self.getTypeAt(path) } - // predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { - // n = getRelevantLocatable() and - // t = inferFunctionCallType(n, path) - // } - predicate debugInferConstructionType(AstNode n, TypePath path, Type t) { - n = getRelevantLocatable() and - t = 
inferConstructionType(n, path) - } - predicate debugTypeMention(TypeMention tm, TypePath path, Type type) { tm = getRelevantLocatable() and tm.getTypeAt(path) = type } - // Type debugInferAnnotatedType(AstNode n, TypePath path) { - // n = getRelevantLocatable() and - // result = inferAnnotatedType(n, path) - // } pragma[nomagic] private int countTypesAtPath(AstNode n, TypePath path, Type t) { t = inferType(n, path) and @@ -3803,7 +3792,7 @@ private module Debug { c = max(countTypePaths(_, _, _)) } - Type debuginferTypeCertain(AstNode n, TypePath path) { + Type debugInferTypeCertain(AstNode n, TypePath path) { n = getRelevantLocatable() and result = inferTypeCertain(n, path) } diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index c36f63c75fc5..0f646ab3552c 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -2521,7 +2521,7 @@ module Make1 Input1> { } pragma[nomagic] - private Type getCertainCallExprType(Call call, TypePath path) { + private Type getCertainCallExprReturnType(Call call, TypePath path) { exists(TypePosition ret | ret.isReturn() and forex(Callable target | target = call.getTargetCertain() | @@ -2531,8 +2531,8 @@ module Make1 Input1> { } pragma[nomagic] - private Type inferCertainCallExprType(Call call, TypePath path) { - exists(Type ty, TypePath prefix | ty = getCertainCallExprType(call, prefix) | + private Type inferCertainCallExprReturnType(Call call, TypePath path) { + exists(Type ty, TypePath prefix | ty = getCertainCallExprReturnType(call, prefix) | exists( Callable target, TypePath suffix, TypeParameterPosition tppos, TypeArgumentPosition tapos @@ -2561,7 +2561,7 @@ module Make1 Input1> { or result = inferLogicalOperationType(n, path) or - result = inferCertainCallExprType(n, path) + result = inferCertainCallExprReturnType(n, path) or 
infersCertainTypeAt(n, path, result.getATypeParameter()) } @@ -2681,9 +2681,6 @@ module Make1 Input1> { ) } - private predicate inferTypeFromReverseLub = - TopDownTyping::inferType/2; - /** * Gets the inferred type of `n` at `path`. */ @@ -2703,7 +2700,7 @@ module Make1 Input1> { ( result = inferTypeFromStep(n, path) or - result = inferTypeFromReverseLub(n, path) + result = TopDownTyping::inferType(n, path) or result = inferCallReturnType(n, path) or @@ -2755,7 +2752,7 @@ module Make1 Input1> { private module CallMatching = MatchingWithEnvironment; - private Type inferCallTypeOut( + private Type inferCallType( Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path ) { n = call.getNodeAt(pos) and @@ -2764,7 +2761,7 @@ module Make1 Input1> { Type inferCallReturnType(Call call, CallResolutionContext ctx, AstNode n, TypePath path) { exists(TypePosition pos | - result = inferCallTypeOut(call, ctx, pos, n, path) and + result = inferCallType(call, ctx, pos, n, path) and pos.isReturn() ) } @@ -2772,8 +2769,9 @@ module Make1 Input1> { Type inferCallArgumentTypeTopDown( Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path ) { - result = inferCallTypeOut(call, ctx, pos, n, path) and - not pos.isReturn() + result = inferCallType(call, ctx, pos, n, path) and + not pos.isReturn() and + hasUnknownType(n) } pragma[nomagic] @@ -2809,36 +2807,6 @@ module Make1 Input1> { } } - signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path); - - /** - * Given a predicate `inferCallType` for inferring the type of a call at a given - * position, this module exposes the predicate `check`, which wraps the input - * predicate and checks that types are only propagated into arguments when they - * are context-typed. 
- */ - module CheckContextTyping { - pragma[nomagic] - private Type inferCallNonReturnType(AstNode n, TypePath prefix, TypePath path) { - exists(TypePosition pos | - result = inferCallType(n, pos, path) and - hasUnknownType(n) and - not pos.isReturn() and - prefix = path.getAPrefix() - ) - } - - pragma[nomagic] - Type check(AstNode n, TypePath path) { - result = inferCallType(n, any(TypePosition pos | pos.isReturn()), path) - or - exists(TypePath prefix | - result = inferCallNonReturnType(n, prefix, path) and - hasUnknownTypeAt(n, prefix) - ) - } - } - /** * Gets the inferred root type of `n`, if any. */ diff --git a/swift/ql/test/library-tests/type-inference/type-inference.expected b/swift/ql/test/library-tests/type-inference/type-inference.expected index 79a3599aa1f3..e69de29bb2d1 100644 --- a/swift/ql/test/library-tests/type-inference/type-inference.expected +++ b/swift/ql/test/library-tests/type-inference/type-inference.expected @@ -1,11 +0,0 @@ -| context.swift:17:10:17:10 | C.init() | Unexpected result: target=init() | -| context.swift:17:10:17:12 | call to C.init() | Unexpected result: target=init() | -| context.swift:17:14:18:1 | // $ type=C\n | Missing result: type=C | -| context.swift:25:11:25:11 | A.init() | Unexpected result: target=init() | -| context.swift:25:11:25:13 | call to A.init() | Unexpected result: target=init() | -| context.swift:26:19:26:19 | D.init() | Unexpected result: target=init() | -| context.swift:26:19:26:21 | call to D.init() | Unexpected result: target=init() | -| context.swift:26:25:26:25 | B.init() | Unexpected result: target=init() | -| context.swift:26:25:26:27 | call to B.init() | Unexpected result: target=init() | -| file://:0:0:0:0 | A.init() | Unexpected result: target=init() | -| file://:0:0:0:0 | call to A.init() | Unexpected result: target=init() |