diff --git a/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs b/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs index ba1204e504c..1d6c4cc7fb6 100644 --- a/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs +++ b/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs @@ -1750,9 +1750,10 @@ extractUnsupported markerName = \case compileExprWithDefs :: CompilingDefault uni fun m ann - => GHC.CoreExpr + => Maybe GHC.Type + -> GHC.CoreExpr -> m (PIRTerm uni fun) -compileExprWithDefs e = do +compileExprWithDefs mTargetTy e = do -- Order matters here. Generlly, Once that define types should go before anything that defines -- terms. Otherwise, type definitions might get ignored if they appear in types of term definitions. defineBoolType @@ -1761,7 +1762,38 @@ compileExprWithDefs e = do defineBuiltinTerms defineIntegerNegate defineFix - compileExpr Nothing e + case mTargetTy of + Just targetTy | Just e' <- retargetLeadingAnyBinders targetTy e -> compileExpr Nothing e' + _ -> compileExpr Nothing e + +retargetLeadingAnyBinders :: GHC.Type -> GHC.CoreExpr -> Maybe GHC.CoreExpr +retargetLeadingAnyBinders = go False . stripForAllTyCoVars + where + stripForAllTyCoVars t = + case GHC.splitForAllTyCoVar_maybe t of + Just (_tv, body) -> stripForAllTyCoVars body + Nothing -> t + + go sawErasedLam targetTy e = case e of + GHC.Tick tick body -> GHC.Tick tick <$> go sawErasedLam targetTy body + GHC.Cast body co -> (`GHC.Cast` co) <$> go sawErasedLam targetTy body + GHC.Lam b body + | GHC.isTyVar b -> GHC.Lam b <$> go sawErasedLam targetTy body + | isAnyBinder b + , Just (_t, _m, domTy, codTy) <- GHC.splitFunTy_maybe targetTy -> + let b' = GHC.setVarType b domTy + body' = retagBinderUses b b' body + in GHC.Lam b' <$> go True codTy body' + | otherwise -> Nothing + _ -> if sawErasedLam then Just e else Nothing + + isAnyBinder b = + case GHC.splitTyConApp_maybe (GHC.varType b) of + Just (tc, _) -> tc == GHC.anyTyCon + Nothing -> False + + retagBinderUses old new body = + GHC.substExpr (GHC.extendIdSubst (GHC.mkEmptySubst (GHC.mkInScopeSet (GHC.exprFreeVars body))) old (GHC.Var new)) body {- Note [We always need DEFAULT] GHC can be clever and omit case alternatives sometimes, typically when the typechecker says a case diff --git a/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs-boot b/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs-boot index a1064b7d220..82707676b35 100644 --- a/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs-boot +++ b/plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs-boot @@ -14,4 +14,4 @@ compileExpr => Maybe GHC.RealSrcSpan -> GHC.CoreExpr -> m (PIRTerm uni fun) compileExprWithDefs :: CompilingDefault uni fun m ann - => GHC.CoreExpr -> m (PIRTerm uni fun) + => Maybe GHC.Type -> GHC.CoreExpr -> m (PIRTerm uni fun) diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin/Common.hs b/plutus-tx-plugin/src/PlutusTx/Plugin/Common.hs index c233a3a1cc8..bbcde26f89b 100644 --- a/plutus-tx-plugin/src/PlutusTx/Plugin/Common.hs +++ b/plutus-tx-plugin/src/PlutusTx/Plugin/Common.hs @@ -647,7 +647,7 @@ compileMarkedExpr _locStr codeTy origE = do ((pirP, uplcP), covIdx) <- runWriterT . runQuoteT . flip runReaderT ctx . flip evalStateT st $ - runCompiler moduleNameStr opts origE' + runCompiler moduleNameStr opts (targetTypeForExpr codeTy origE') origE' -- serialize the PIR, PLC, and coverageindex outputs into a bytestring. bsPir <- makeByteStringLiteral $ flat pirP @@ -664,6 +664,33 @@ compileMarkedExpr _locStr codeTy origE = do `GHC.App` bsPir `GHC.App` covIdxFlat +targetTypeForExpr :: GHC.Type -> GHC.CoreExpr -> Maybe GHC.Type +targetTypeForExpr ty expr + | isFunctionReturningUnit ty' = if hasLeadingTermLambda expr then Just ty' else Nothing + | otherwise = Nothing + where + ty' = stripForAllTyCoVars ty + + stripForAllTyCoVars t = + case GHC.splitForAllTyCoVar_maybe t of + Just (_tv, body) -> stripForAllTyCoVars body + Nothing -> t + + isFunctionReturningUnit t = + case GHC.splitFunTy_maybe t of + Just (_mult, _vis, _dom, cod) -> isFunctionReturningUnit cod + Nothing -> case GHC.splitTyConApp_maybe t of + Just (tc, _) -> tc == GHC.unitTyCon + Nothing -> False + + hasLeadingTermLambda = \case + GHC.Tick _ body -> hasLeadingTermLambda body + GHC.Cast body _ -> hasLeadingTermLambda body + GHC.Lam b body + | GHC.isTyVar b -> hasLeadingTermLambda body + | otherwise -> True + _ -> False + {-| The GHC.Core to PIR to PLC compiler pipeline. Returns both the PIR and PLC output. It invokes the whole compiler chain: Core expr -> PIR expr -> PLC expr -> UPLC expr. -} runCompiler @@ -679,9 +706,10 @@ runCompiler ) => String -> PluginOptions + -> Maybe GHC.Type -> GHC.CoreExpr -> m (PIRProgram uni fun, UPLCProgram uni fun) -runCompiler moduleName opts expr = do +runCompiler moduleName opts mTargetTy expr = do GHC.DynFlags {GHC.extensions = extensions} <- asks ccFlags let enabledExtensions = @@ -818,7 +846,7 @@ runCompiler moduleName opts expr = do (opts ^. posCertifiedOptsOnly) -- GHC.Core -> Pir translation. - pirT <- original <$> (PIR.runDefT annMayInline $ compileExprWithDefs expr) + pirT <- original <$> (PIR.runDefT annMayInline $ compileExprWithDefs mTargetTy expr) let pirP = PIR.Program noProvenance plcVersion pirT when (opts ^. posDumpPir) . liftIO $ dumpFlat (void pirP) "initial PIR program" (moduleName ++ "_initial.pir-flat") diff --git a/plutus-tx-plugin/test/TH/9.12/ignoredUntypedLambdaArguments.golden.pir b/plutus-tx-plugin/test/TH/9.12/ignoredUntypedLambdaArguments.golden.pir new file mode 100644 index 00000000000..8130f5fe4fd --- /dev/null +++ b/plutus-tx-plugin/test/TH/9.12/ignoredUntypedLambdaArguments.golden.pir @@ -0,0 +1,7 @@ +(let + (nonrec) + (datatypebind + (datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit)) + ) + (lam ds (con data) (lam ds (con data) Unit)) +) \ No newline at end of file diff --git a/plutus-tx-plugin/test/TH/9.12/ignoredUntypedThreeLambdaArguments.golden.pir b/plutus-tx-plugin/test/TH/9.12/ignoredUntypedThreeLambdaArguments.golden.pir new file mode 100644 index 00000000000..6e23b66b641 --- /dev/null +++ b/plutus-tx-plugin/test/TH/9.12/ignoredUntypedThreeLambdaArguments.golden.pir @@ -0,0 +1,7 @@ +(let + (nonrec) + (datatypebind + (datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit)) + ) + (lam ds (con data) (lam ds (con data) (lam ds (con data) Unit))) +) \ No newline at end of file diff --git a/plutus-tx-plugin/test/TH/9.6/ignoredUntypedLambdaArguments.golden.pir b/plutus-tx-plugin/test/TH/9.6/ignoredUntypedLambdaArguments.golden.pir new file mode 100644 index 00000000000..8130f5fe4fd --- /dev/null +++ b/plutus-tx-plugin/test/TH/9.6/ignoredUntypedLambdaArguments.golden.pir @@ -0,0 +1,7 @@ +(let + (nonrec) + (datatypebind + (datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit)) + ) + (lam ds (con data) (lam ds (con data) Unit)) +) \ No newline at end of file diff --git a/plutus-tx-plugin/test/TH/9.6/ignoredUntypedThreeLambdaArguments.golden.pir b/plutus-tx-plugin/test/TH/9.6/ignoredUntypedThreeLambdaArguments.golden.pir new file mode 100644 index 00000000000..6e23b66b641 --- /dev/null +++ b/plutus-tx-plugin/test/TH/9.6/ignoredUntypedThreeLambdaArguments.golden.pir @@ -0,0 +1,7 @@ +(let + (nonrec) + (datatypebind + (datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit)) + ) + (lam ds (con data) (lam ds (con data) (lam ds (con data) Unit))) +) \ No newline at end of file diff --git a/plutus-tx-plugin/test/TH/Spec.hs b/plutus-tx-plugin/test/TH/Spec.hs index 3879e8377bb..1edd7187c84 100644 --- a/plutus-tx-plugin/test/TH/Spec.hs +++ b/plutus-tx-plugin/test/TH/Spec.hs @@ -49,6 +49,8 @@ tests = , goldenEvalCekLog "traceDirect" traceDirect , goldenEvalCekLog "tracePrelude" tracePrelude , goldenEvalCekLog "traceRepeatedly" traceRepeatedly + , goldenPir "ignoredUntypedLambdaArguments" ignoredUntypedLambdaArguments + , goldenPir "ignoredUntypedThreeLambdaArguments" ignoredUntypedThreeLambdaArguments , -- want to see the raw structure, so using Show nestedGoldenVsDoc "someData" "" (pretty $ Haskell.show someData) ] @@ -85,3 +87,9 @@ traceRepeatedly = in i3 ||] ) + +ignoredUntypedLambdaArguments :: CompiledCode (BuiltinData -> BuiltinData -> ()) +ignoredUntypedLambdaArguments = $$(compile [||\_ _ -> ()||]) + +ignoredUntypedThreeLambdaArguments :: CompiledCode (BuiltinData -> BuiltinData -> BuiltinData -> ()) +ignoredUntypedThreeLambdaArguments = $$(compile [||\_ _ _ -> ()||])