Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/typegpu/src/core/buffer/bufferUsage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class TgpuFixedBufferImpl<TData extends BaseData, TUsage extends BindableBufferU
{
[$internal]: true,
get [$ownSnippet]() {
return snip(this, dataType, usage);
return snip(this, dataType, usage, /* possibleSideEffects */ false);
},
[$resolve]: (ctx) => ctx.resolve(this),
toString: () => `${this.usage}:${getName(this) ?? '<unnamed>'}.$`,
Expand Down Expand Up @@ -262,7 +262,7 @@ export class TgpuLaidOutBufferImpl<TData extends BaseData, TUsage extends Bindab
{
[$internal]: true,
get [$ownSnippet]() {
return snip(this, schema, usage);
return snip(this, schema, usage, /* possibleSideEffects */ false);
},
[$resolve]: (ctx) => ctx.resolve(this),
toString: () => `${this.usage}:${getName(this) ?? '<unnamed>'}.$`,
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/core/constant/tgpuConstant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class TgpuConstImpl<TDataType extends BaseData> implements TgpuConst<TDataType>,
{
[$internal]: true,
get [$ownSnippet]() {
return snip(this, dataType, 'constant-immutable-def');
return snip(this, dataType, 'constant-immutable-def', /* possibleSideEffects */ false);
},
[$resolve]: (ctx) => ctx.resolve(this),
toString: () => `const:${getName(this) ?? '<unnamed>'}.$`,
Expand Down
15 changes: 13 additions & 2 deletions packages/typegpu/src/core/function/dualImpl.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
import { type MapValueToSnippet, noSideEffects, snip } from '../../data/snippet.ts';
import { setName } from '../../shared/meta.ts';
import { $gpuCallable } from '../../shared/symbols.ts';
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
Expand Down Expand Up @@ -29,6 +29,12 @@ interface DualImplOptions<T extends AnyFn> {
*/
readonly noComptime?: boolean | undefined;
readonly ignoreImplicitCastWarning?: boolean | undefined;
/**
* Whether the function always has side effects. If `true`, the result always
* has `possibleSideEffects: true` regardless of argument side-effects. If
* `false` (default), the result has side effects only when any argument does.
*/
readonly sideEffects?: boolean | undefined;
}

export class MissingCpuImplError extends Error {
Expand Down Expand Up @@ -101,12 +107,17 @@ export function dualImpl<T extends AnyFn>(options: DualImplOptions<T>): DualFn<T
}
}

return snip(
const result = snip(
options.codegenImpl(ctx, converted),
concretize(returnType),
// Functions give up ownership of their return value
/* origin */ 'runtime',
);

if (!options.sideEffects && !args.some((a) => a.possibleSideEffects)) {
return noSideEffects(result);
}
return result;
Comment thread
pullfrog[bot] marked this conversation as resolved.
Outdated
},
};

Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/resolutionCtx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ function createArgument(
name,
access: () => {
used = true;
return snip(name, type, origin);
return snip(name, type, origin, /* possibleSideEffects */ false);
},
decoratedType: type,
get used() {
Expand Down
11 changes: 11 additions & 0 deletions packages/typegpu/src/std/atomic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@ export const workgroupBarrier = dualImpl({
normalImpl: 'workgroupBarrier is a no-op outside of CODEGEN mode.',
signature: { argTypes: [], returnType: Void },
codegenImpl: () => 'workgroupBarrier()',
sideEffects: true,
Comment thread
pullfrog[bot] marked this conversation as resolved.
});

export const storageBarrier = dualImpl({
name: 'storageBarrier',
normalImpl: 'storageBarrier is a no-op outside of CODEGEN mode.',
signature: { argTypes: [], returnType: Void },
codegenImpl: () => 'storageBarrier()',
sideEffects: true,
});

export const textureBarrier = dualImpl({
name: 'textureBarrier',
normalImpl: 'textureBarrier is a no-op outside of CODEGEN mode.',
signature: { argTypes: [], returnType: Void },
codegenImpl: () => 'textureBarrier()',
sideEffects: true,
});

const atomicNormalError = 'Atomic operations are not supported outside of CODEGEN mode.';
Expand Down Expand Up @@ -72,53 +75,61 @@ export const atomicStore = dualImpl<<T extends AnyAtomic>(a: T, value: number) =
normalImpl: atomicNormalError,
signature: atomicActionSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicStore(&${a}, ${value})`,
sideEffects: true,
});

export const atomicAdd = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicAdd',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicAdd(&${a}, ${value})`,
sideEffects: true,
});

export const atomicSub = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicSub',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicSub(&${a}, ${value})`,
sideEffects: true,
});

export const atomicMax = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicMax',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicMax(&${a}, ${value})`,
sideEffects: true,
});

export const atomicMin = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicMin',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicMin(&${a}, ${value})`,
sideEffects: true,
});

export const atomicAnd = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicAnd',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicAnd(&${a}, ${value})`,
sideEffects: true,
});

export const atomicOr = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicOr',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicOr(&${a}, ${value})`,
sideEffects: true,
});

export const atomicXor = dualImpl<<T extends AnyAtomic>(a: T, value: number) => number>({
name: 'atomicXor',
normalImpl: atomicNormalError,
signature: atomicOpSignature,
codegenImpl: (_ctx, [a, value]) => stitch`atomicXor(&${a}, ${value})`,
sideEffects: true,
});
1 change: 1 addition & 0 deletions packages/typegpu/src/std/texture.ts
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ export const textureStore = dualImpl({
normalImpl: textureStoreCpu,
codegenImpl: (_ctx, args) => stitch`textureStore(${args})`,
signature: (...args) => ({ argTypes: args, returnType: Void }),
sideEffects: true,
});

function textureDimensionsCpu<T extends texture1d | textureStorage1d>(texture: T): number;
Expand Down
3 changes: 3 additions & 0 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ ${this.ctx.pre}}`;
type,
// Result of an operation, so not a reference to anything
/* origin */ 'runtime',
exprType === NODE.assignmentExpr ||
lhsExpr.possibleSideEffects ||
rhsExpr.possibleSideEffects,
);
}

Expand Down
10 changes: 5 additions & 5 deletions packages/typegpu/tests/tgsl/ternaryOperator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,18 @@ describe('ternary operator', () => {
`);
});

it('should throw when test is not comptime known', () => {
it('should generate select() for runtime condition with function params', () => {
const myFn = tgpu.fn(
[d.u32],
d.u32,
)((n) => {
return n > 0 ? n : -n;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Ternary operator '(n > 0) ? n : (-n)' is invalid. For more complex branching, please use 'std.select' or if/else statements.]
expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn(n: u32) -> u32 {
return select(-(n), n, (n > 0u));
}"
`);
});
});
165 changes: 165 additions & 0 deletions packages/typegpu/tests/tgsl/ternaryRuntime.test.ts
Comment thread
pullfrog[bot] marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import { describe, expect } from 'vitest';
import { it } from 'typegpu-testing-utility';
import tgpu, { d } from '../../src/index.js';

describe('runtime ternary operator', () => {
it('should handle subtraction in branches with function params', () => {
const myFn = tgpu.fn(
[d.u32, d.u32],
d.u32,
)((b, w) => {
return b > w ? b - w : w - b;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn(b: u32, w: u32) -> u32 {
return select((w - b), (b - w), (b > w));
}"
`);
});

it('should handle const array indexing in branches', () => {
const RotLut2Gpu = tgpu.const(d.arrayOf(d.u32, 2), [10, 20]);
const RotLut3Gpu = tgpu.const(d.arrayOf(d.u32, 3), [30, 40, 50]);

const myFn = tgpu.fn(
[d.u32, d.u32],
d.u32,
)((r, bitU) => {
return r === d.u32(2) ? (RotLut2Gpu.$[bitU] as number) : (RotLut3Gpu.$[bitU] as number);
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"const RotLut2Gpu: array<u32, 2> = array<u32, 2>(10u, 20u);

const RotLut3Gpu: array<u32, 3> = array<u32, 3>(30u, 40u, 50u);

fn myFn(r: u32, bitU: u32) -> u32 {
return select(RotLut3Gpu[bitU], RotLut2Gpu[bitU], (r == 2u));
}"
`);
});

it('should handle nested runtime ternaries', () => {
const myFn = tgpu.fn(
[d.u32, d.u32, d.u32, d.u32],
d.u32,
)((r, v1, v2, v3) => {
return r === d.u32(1) ? v1 : r === d.u32(2) ? v2 : v3;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn(r: u32, v1: u32, v2: u32, v3: u32) -> u32 {
return select(select(v3, v2, (r == 2u)), v1, (r == 1u));
}"
`);
});

it('should handle bit shift in branch with function param', () => {
const myFn = tgpu.fn(
[d.bool],
d.u32,
)((isCustom) => {
return isCustom ? d.u32(1) << d.u32(20) : d.u32(0);
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn(isCustom: bool) -> u32 {
return select(0u, (1u << 20u), isCustom);
}"
`);
});

it('should handle struct field access across ternaries', () => {
const Cw = d.struct({
low: d.u32,
high: d.u32,
});

const myFn = tgpu.fn(
[d.bool, Cw, Cw],
d.u32,
)((isCustom, customCw, stdCw) => {
return isCustom ? customCw.low : stdCw.low;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"struct Cw {
low: u32,
high: u32,
}

fn myFn(isCustom: bool, customCw: Cw, stdCw: Cw) -> u32 {
return select(stdCw.low, customCw.low, isCustom);
}"
`);
});

it('should handle buffer layout access in ternary branches', ({ root }) => {
const Cw = d.struct({
low: d.u32,
high: d.u32,
});

const Layout = d.struct({
codewords: d.arrayOf(Cw, 64),
});

const layout = root.createUniform(Layout, d.ref);

const myFn = tgpu.fn(
[d.bool, d.u32],
d.u32,
)((isCustom, cwIdx) => {
return isCustom ? layout.$.codewords[cwIdx]!.low : layout.$.codewords[cwIdx + d.u32(1)]!.low;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"struct Cw {
low: u32,
high: u32,
}

struct Layout {
codewords: array<Cw, 64>,
}

@group(0) @binding(0) var<uniform> layout_1: Layout;

fn myFn(isCustom: bool, cwIdx: u32) -> u32 {
return select(layout_1.codewords[(cwIdx + 1u)].low, layout_1.codewords[cwIdx].low, isCustom);
}"
`);
});

it('should handle ternary with comparison and unary negation in branches', () => {
const myFn = tgpu.fn(
[d.u32],
d.u32,
)((n) => {
return n > 0 ? n : -n;
});

expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn(n: u32) -> u32 {
return select(-(n), n, (n > 0u));
}"
`);
});

it('should throw when a ternary branch contains an assignment', () => {
const myFn = tgpu.fn(
[d.i32],
d.i32,
)((a) => {
let b = 0;
return a > 0 ? (b = a) : 0;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Ternary operator '(a > 0) ? (b = a) : 0' is invalid. For more complex branching, please use 'std.select' or if/else statements.]
`);
});
});
Loading