Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1750,9 +1750,10 @@ extractUnsupported markerName = \case

compileExprWithDefs
:: CompilingDefault uni fun m ann
=> GHC.CoreExpr
=> Maybe GHC.Type
-> GHC.CoreExpr
-> m (PIRTerm uni fun)
compileExprWithDefs e = do
compileExprWithDefs mTargetTy e = do
-- Order matters here. Generlly, Once that define types should go before anything that defines
-- terms. Otherwise, type definitions might get ignored if they appear in types of term definitions.
defineBoolType
Expand All @@ -1761,7 +1762,38 @@ compileExprWithDefs e = do
defineBuiltinTerms
defineIntegerNegate
defineFix
compileExpr Nothing e
case mTargetTy of
Just targetTy | Just e' <- retargetLeadingAnyBinders targetTy e -> compileExpr Nothing e'
_ -> compileExpr Nothing e

retargetLeadingAnyBinders :: GHC.Type -> GHC.CoreExpr -> Maybe GHC.CoreExpr
retargetLeadingAnyBinders = go False . stripForAllTyCoVars
where
stripForAllTyCoVars t =
case GHC.splitForAllTyCoVar_maybe t of
Just (_tv, body) -> stripForAllTyCoVars body
Nothing -> t

go sawErasedLam targetTy e = case e of
GHC.Tick tick body -> GHC.Tick tick <$> go sawErasedLam targetTy body
GHC.Cast body co -> (`GHC.Cast` co) <$> go sawErasedLam targetTy body
GHC.Lam b body
| GHC.isTyVar b -> GHC.Lam b <$> go sawErasedLam targetTy body
| isAnyBinder b
, Just (_t, _m, domTy, codTy) <- GHC.splitFunTy_maybe targetTy ->
let b' = GHC.setVarType b domTy
body' = retagBinderUses b b' body
in GHC.Lam b' <$> go True codTy body'
| otherwise -> Nothing
_ -> if sawErasedLam then Just e else Nothing

isAnyBinder b =
case GHC.splitTyConApp_maybe (GHC.varType b) of
Just (tc, _) -> tc == GHC.anyTyCon
Nothing -> False

retagBinderUses old new body =
GHC.substExpr (GHC.extendIdSubst (GHC.mkEmptySubst (GHC.mkInScopeSet (GHC.exprFreeVars body))) old (GHC.Var new)) body

{- Note [We always need DEFAULT]
GHC can be clever and omit case alternatives sometimes, typically when the typechecker says a case
Expand Down
2 changes: 1 addition & 1 deletion plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs-boot
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ compileExpr
=> Maybe GHC.RealSrcSpan -> GHC.CoreExpr -> m (PIRTerm uni fun)
compileExprWithDefs
:: CompilingDefault uni fun m ann
=> GHC.CoreExpr -> m (PIRTerm uni fun)
=> Maybe GHC.Type -> GHC.CoreExpr -> m (PIRTerm uni fun)
34 changes: 31 additions & 3 deletions plutus-tx-plugin/src/PlutusTx/Plugin/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ compileMarkedExpr _locStr codeTy origE = do

((pirP, uplcP), covIdx) <-
runWriterT . runQuoteT . flip runReaderT ctx . flip evalStateT st $
runCompiler moduleNameStr opts origE'
runCompiler moduleNameStr opts (targetTypeForExpr codeTy origE') origE'

-- serialize the PIR, PLC, and coverageindex outputs into a bytestring.
bsPir <- makeByteStringLiteral $ flat pirP
Expand All @@ -664,6 +664,33 @@ compileMarkedExpr _locStr codeTy origE = do
`GHC.App` bsPir
`GHC.App` covIdxFlat

targetTypeForExpr :: GHC.Type -> GHC.CoreExpr -> Maybe GHC.Type
targetTypeForExpr ty expr
| isFunctionReturningUnit ty' = if hasLeadingTermLambda expr then Just ty' else Nothing
| otherwise = Nothing
where
ty' = stripForAllTyCoVars ty

stripForAllTyCoVars t =
case GHC.splitForAllTyCoVar_maybe t of
Just (_tv, body) -> stripForAllTyCoVars body
Nothing -> t

isFunctionReturningUnit t =
case GHC.splitFunTy_maybe t of
Just (_mult, _vis, _dom, cod) -> isFunctionReturningUnit cod
Nothing -> case GHC.splitTyConApp_maybe t of
Just (tc, _) -> tc == GHC.unitTyCon
Nothing -> False

hasLeadingTermLambda = \case
GHC.Tick _ body -> hasLeadingTermLambda body
GHC.Cast body _ -> hasLeadingTermLambda body
GHC.Lam b body
| GHC.isTyVar b -> hasLeadingTermLambda body
| otherwise -> True
_ -> False

{-| The GHC.Core to PIR to PLC compiler pipeline. Returns both the PIR and PLC output.
It invokes the whole compiler chain: Core expr -> PIR expr -> PLC expr -> UPLC expr. -}
runCompiler
Expand All @@ -679,9 +706,10 @@ runCompiler
)
=> String
-> PluginOptions
-> Maybe GHC.Type
-> GHC.CoreExpr
-> m (PIRProgram uni fun, UPLCProgram uni fun)
runCompiler moduleName opts expr = do
runCompiler moduleName opts mTargetTy expr = do
GHC.DynFlags {GHC.extensions = extensions} <- asks ccFlags
let
enabledExtensions =
Expand Down Expand Up @@ -818,7 +846,7 @@ runCompiler moduleName opts expr = do
(opts ^. posCertifiedOptsOnly)

-- GHC.Core -> Pir translation.
pirT <- original <$> (PIR.runDefT annMayInline $ compileExprWithDefs expr)
pirT <- original <$> (PIR.runDefT annMayInline $ compileExprWithDefs mTargetTy expr)
let pirP = PIR.Program noProvenance plcVersion pirT
when (opts ^. posDumpPir) . liftIO $
dumpFlat (void pirP) "initial PIR program" (moduleName ++ "_initial.pir-flat")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(let
(nonrec)
(datatypebind
(datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit))
)
(lam ds (con data) (lam ds (con data) Unit))
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(let
(nonrec)
(datatypebind
(datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit))
)
(lam ds (con data) (lam ds (con data) (lam ds (con data) Unit)))
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(let
(nonrec)
(datatypebind
(datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit))
)
(lam ds (con data) (lam ds (con data) Unit))
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(let
(nonrec)
(datatypebind
(datatype (tyvardecl Unit (type)) Unit_match (vardecl Unit Unit))
)
(lam ds (con data) (lam ds (con data) (lam ds (con data) Unit)))
)
8 changes: 8 additions & 0 deletions plutus-tx-plugin/test/TH/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ tests =
, goldenEvalCekLog "traceDirect" traceDirect
, goldenEvalCekLog "tracePrelude" tracePrelude
, goldenEvalCekLog "traceRepeatedly" traceRepeatedly
, goldenPir "ignoredUntypedLambdaArguments" ignoredUntypedLambdaArguments
, goldenPir "ignoredUntypedThreeLambdaArguments" ignoredUntypedThreeLambdaArguments
, -- want to see the raw structure, so using Show
nestedGoldenVsDoc "someData" "" (pretty $ Haskell.show someData)
]
Expand Down Expand Up @@ -85,3 +87,9 @@ traceRepeatedly =
in i3
||]
)

ignoredUntypedLambdaArguments :: CompiledCode (BuiltinData -> BuiltinData -> ())
ignoredUntypedLambdaArguments = $$(compile [||\_ _ -> ()||])

ignoredUntypedThreeLambdaArguments :: CompiledCode (BuiltinData -> BuiltinData -> BuiltinData -> ())
ignoredUntypedThreeLambdaArguments = $$(compile [||\_ _ _ -> ()||])