Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions packages/typegpu/src/core/declare/tgpuDeclare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { type ResolvedSnippet, snip } from '../../data/snippet.ts';
import { Void } from '../../data/wgslTypes.ts';
import { $internal, $resolve } from '../../shared/symbols.ts';
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
import { mergeExternals, type ExternalMap, replaceExternalsInWgsl } from '../resolve/externals.ts';
import { type ExternalMap, replaceExternalsInWgsl } from '../resolve/externals.ts';

// ----------
// Public API
Expand Down Expand Up @@ -33,26 +33,29 @@ export function declare(declaration: string): TgpuDeclare {

class TgpuDeclareImpl implements TgpuDeclare, SelfResolvable {
readonly [$internal] = true;
#externalsToApply: ExternalMap[] = [];
#externals: ExternalMap | undefined;
#declaration: string;

constructor(declaration: string) {
this.#declaration = declaration;
}

$uses(dependencyMap: Record<string, unknown>): this {
this.#externalsToApply.push(dependencyMap);
if (this.#externals !== undefined) {
throw new Error(
"Cannot call '$uses' multiple times. If you wish to override dependencies, use slots or accessors instead.",
);
}
this.#externals = dependencyMap;
return this;
}

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
const externalMap: ExternalMap = {};

for (const externals of this.#externalsToApply) {
mergeExternals(externalMap, externals);
}

const replacedDeclaration = replaceExternalsInWgsl(ctx, externalMap, this.#declaration);
const replacedDeclaration = replaceExternalsInWgsl(
ctx,
this.#externals ?? {},
this.#declaration,
);

ctx.addDeclaration(replacedDeclaration);
return snip('', Void, /* origin */ 'constant');
Expand Down
75 changes: 55 additions & 20 deletions packages/typegpu/src/core/function/fnCore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,39 @@ import { validateIdentifier } from '../../nameUtils.ts';
import { getFunctionMetadata, getName } from '../../shared/meta.ts';
import { $getNameForward } from '../../shared/symbols.ts';
import type { ResolutionCtx, TgpuShaderStage } from '../../types.ts';
import { mergeExternals, type ExternalMap, replaceExternalsInWgsl } from '../resolve/externals.ts';
import {
type ExternalMap,
replaceExternalsInWgsl,
mergeFunctionExternals,
} from '../resolve/externals.ts';
import { extractArgs } from './extractArgs.ts';
import type { Implementation, SeparatedEntryArgs } from './fnTypes.ts';

export type FnExternals = {
/**
* Externals provided by calling `$uses()`.
* May be nested.
*/
userProvided?: ExternalMap;
/**
* Externals provided by unplugin-typegpu via function metadata.
* May be nested.
*/
pluginProvided?: ExternalMap;
/**
* Function arguments, for example `{ S: Schema }` in `tgpu.fn([Schema])('(arg: S) => {}')`.
* Must be flat (every value must be resolvable).
*/
args?: ExternalMap;
/**
* Function return type, for example `{ Out: ... }` in both rawWgsl entrypoint functions and `vertexFnShell(in, Out)`.
* Must be flat (every value must be resolvable).
*/
out?: ExternalMap;
};

export interface FnCore {
applyExternals: (newExternals: ExternalMap) => void;
setExternals: (key: keyof FnExternals, newExternal: ExternalMap) => void;
resolve(
ctx: ResolutionCtx,
/**
Expand Down Expand Up @@ -43,14 +70,28 @@ export function createFnCore(
* initialized yet (like when accessing the Output struct of a vertex
* entry fn).
*/
const externalsToApply: ExternalMap[] = [];
const externals: FnExternals = {};

const core = {
// Making the implementation the holder of the name, as long as it's
// a function (and not a string implementation)
[$getNameForward]: typeof implementation === 'function' ? implementation : undefined,
applyExternals(newExternals: ExternalMap): void {
externalsToApply.push(newExternals);

setExternals(key: keyof FnExternals, newExternal: ExternalMap): void {
if (key === 'userProvided') {
if ('userProvided' in externals) {
// other external keys may be set multiple times by multiple resolves
throw new Error(
"Cannot call '$uses' multiple times. If you wish to override dependencies, use slots or accessors instead.",
);
}
if ('pluginProvided' in externals) {
throw new Error(
"Cannot call '$uses' on functions whose metadata was provided by unplugin-typegpu.",
);
}
}
externals[key] = newExternal;
},

resolve(
Expand All @@ -59,8 +100,6 @@ export function createFnCore(
returnType: BaseData | undefined,
entryInput?: SeparatedEntryArgs,
): ResolvedSnippet {
const externalMap: ExternalMap = {};

let attributes = '';
if (functionType === 'compute') {
attributes = `@compute @workgroup_size(${workgroupSize?.join(', ')}) `;
Expand All @@ -70,10 +109,6 @@ export function createFnCore(
attributes = `@fragment `;
}

for (const externals of externalsToApply) {
mergeExternals(externalMap, externals);
}

const id = ctx.makeUniqueIdentifier(getName(this), 'global');

if (typeof implementation === 'string') {
Expand All @@ -96,14 +131,18 @@ export function createFnCore(
}
}

mergeExternals(externalMap, {
this.setExternals('args', {
in: Object.fromEntries(
entryInput.positionalArgs.map((a) => [a.schemaKey, a.schemaKey]),
),
});
}

const replacedImpl = replaceExternalsInWgsl(ctx, externalMap, implementation);
const replacedImpl = replaceExternalsInWgsl(
ctx,
mergeFunctionExternals(externals),
implementation,
);

let header = '';
let body = '';
Expand Down Expand Up @@ -175,11 +214,7 @@ export function createFnCore(

const pluginExternals = pluginData?.externals();
if (pluginExternals) {
const missing = Object.fromEntries(
Object.entries(pluginExternals).filter(([name]) => !(name in externalMap)),
);

mergeExternals(externalMap, missing);
this.setExternals('pluginProvided', pluginExternals);
}

const ast = pluginData?.ast;
Expand All @@ -193,7 +228,7 @@ export function createFnCore(
// We look at the identifier chosen by the user and add it to externals.
const maybeSecondArg = ast.params[1];
if (maybeSecondArg && maybeSecondArg.type === 'i' && functionType !== 'normal') {
mergeExternals(externalMap, {
this.setExternals('out', {
// oxlint-disable-next-line typescript/no-non-null-assertion -- entry functions cannot be shellless
[maybeSecondArg.name]: undecorate(returnType!),
});
Expand All @@ -210,7 +245,7 @@ export function createFnCore(
params: ast.params,
returnType,
body: ast.body,
externalMap,
externalMap: mergeFunctionExternals(externals),
});

ctx.addDeclaration(code);
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/core/function/tgpuComputeFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function createComputeFn<ComputeIn extends IORecord<AnyComputeBuiltin>>(
shell,

$uses(newExternals) {
core.applyExternals(newExternals);
core.setExternals('userProvided', newExternals);
return this;
},

Expand Down
6 changes: 3 additions & 3 deletions packages/typegpu/src/core/function/tgpuFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function createFn<ImplSchema extends AnyFn>(
[$internal]: { implementation },

$uses(newExternals: Record<string, unknown>) {
core.applyExternals(newExternals);
core.setExternals('userProvided', newExternals);
return this;
},

Expand All @@ -203,8 +203,8 @@ function createFn<ImplSchema extends AnyFn>(

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
if (typeof implementation === 'string') {
addArgTypesToExternals(implementation, shell.argTypes, core.applyExternals);
addReturnTypeToExternals(implementation, shell.returnType, core.applyExternals);
addArgTypesToExternals(implementation, shell.argTypes, core);
addReturnTypeToExternals(implementation, shell.returnType, core);
}

return core.resolve(ctx, shell.argTypes, shell.returnType);
Expand Down
10 changes: 5 additions & 5 deletions packages/typegpu/src/core/function/tgpuFragmentFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,15 @@ function createFragmentFn(
const core = createFnCore(implementation, 'fragment');
const outputType = shell.returnType;
if (typeof implementation === 'string') {
addReturnTypeToExternals(implementation, outputType, (externals) =>
core.applyExternals(externals),
);
addReturnTypeToExternals(implementation, outputType, core);
}

const result: This = {
shell,
outputType,

$uses(newExternals) {
core.applyExternals(newExternals);
core.setExternals('userProvided', newExternals);
return this;
},

Expand All @@ -216,7 +214,9 @@ function createFragmentFn(
if (entryInput.dataSchema && isNamable(entryInput.dataSchema)) {
entryInput.dataSchema.$name(`${getName(this) ?? ''}_Input`);
}
core.applyExternals({ Out: outputType });
if (typeof implementation === 'string') {
core.setExternals('out', { Out: outputType });
}

return ctx.withSlots([[shaderStageSlot, 'fragment']], () =>
core.resolve(ctx, [], shell.returnType, entryInput),
Expand Down
4 changes: 2 additions & 2 deletions packages/typegpu/src/core/function/tgpuVertexFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function createVertexFn(
shell,

$uses(newExternals) {
core.applyExternals(newExternals);
core.setExternals('userProvided', newExternals);
return this;
},

Expand All @@ -182,7 +182,7 @@ function createVertexFn(
);

if (typeof implementation === 'string') {
core.applyExternals({ Out: outputWithLocation });
core.setExternals('out', { Out: outputWithLocation });
}

return ctx.withSlots([[shaderStageSlot, 'vertex']], () =>
Expand Down
20 changes: 9 additions & 11 deletions packages/typegpu/src/core/rawCodeSnippet/tgpuRawCodeSnippet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { inCodegenMode } from '../../execMode.ts';
import type { InferGPU } from '../../shared/repr.ts';
import { $gpuValueOf, $internal, $ownSnippet, $resolve } from '../../shared/symbols.ts';
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
import { mergeExternals, type ExternalMap, replaceExternalsInWgsl } from '../resolve/externals.ts';
import { type ExternalMap, replaceExternalsInWgsl } from '../resolve/externals.ts';
import { valueProxyHandler } from '../valueProxyUtils.ts';

// ----------
Expand Down Expand Up @@ -92,30 +92,28 @@ class TgpuRawCodeSnippetImpl<TDataType extends BaseData>
readonly origin: RawCodeSnippetOrigin;

#expression: string;
#externalsToApply: ExternalMap[];
#externals: ExternalMap | undefined;

constructor(expression: string, type: TDataType, origin: RawCodeSnippetOrigin) {
this[$internal] = true;
this.dataType = type;
this.origin = origin;

this.#expression = expression;
this.#externalsToApply = [];
}

$uses(dependencyMap: Record<string, unknown>): this {
this.#externalsToApply.push(dependencyMap);
if (this.#externals !== undefined) {
throw new Error(
"Cannot call '$uses' multiple times. If you wish to override dependencies, use slots or accessors instead.",
);
}
this.#externals = dependencyMap;
return this;
}

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
const externalMap: ExternalMap = {};

for (const externals of this.#externalsToApply) {
mergeExternals(externalMap, externals);
}

const replacedExpression = replaceExternalsInWgsl(ctx, externalMap, this.#expression);
const replacedExpression = replaceExternalsInWgsl(ctx, this.#externals ?? {}, this.#expression);

return snip(replacedExpression, this.dataType, this.origin);
}
Expand Down
Loading
Loading