-
Notifications
You must be signed in to change notification settings - Fork 34
Draft of Metal driver api support #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
87dda3a
ee435ba
6967d3a
48af07a
c8d279b
90de581
061b96b
954408d
58b8308
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| module dcompute.driver.metal.buffer; | ||
| import metal; | ||
| import dcompute.driver.metal.program; | ||
| import dcompute.driver.metal; | ||
| import core.stdc.string; | ||
|
|
||
/**
 * Pairs a Metal buffer object with the host array it was created from.
 *
 * Params:
 *   T = element type of the host array and of the pointer returned by `contents()`
 */
struct Buffer(T)
{
    /// The underlying Metal buffer object.
    MTLBuffer mtlBuffer;

    /// Host memory associated with this buffer.
    T[] hostMemory;

    /**
     * Stores the given Metal buffer together with the host slice.
     * Nothing is allocated or copied here.
     */
    this(MTLBuffer _mtlBuffer, T[] array)
    {
        this.hostMemory = array;
        this.mtlBuffer = _mtlBuffer;
    }

    /// Returns whatever `mtlBuffer.contents()` reports, reinterpreted as `T*`.
    T* contents()
    {
        auto rawContents = mtlBuffer.contents();
        return cast(T*) rawContents;
    }

    /**
     * Clears this struct's references to the Metal buffer and to the host slice.
     *
     * NOTE(review): only the fields are set to null; no Objective-C `release`
     * message is sent from here. Confirm the MTLBuffer's ownership is handled elsewhere.
     */
    void release()
    {
        hostMemory = null;
        mtlBuffer = null;
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| module dcompute.driver.metal.device; | ||
| import dcompute.driver.metal.buffer; | ||
| import core.stdc.string; | ||
| import metal; | ||
|
|
||
/**
 * Value wrapper around a Metal device handle, with helpers for creating buffers on it.
 */
struct Device
{
    /**
        A pointer to $(D MTLDevice). It is $(D void*) because, when an array of $(D Device)
        is stored, the linker looks for $(D MTLDevice) but fails to find it, as it is an
        Objective-C binding; hence the handle had to be wrapped this way.
    */
    void* raw;

    /// The stored handle cast back to $(D MTLDevice).
    @property MTLDevice mtlDevice()
    {
        return cast(MTLDevice) raw;
    }

    /// Wraps the given $(D MTLDevice) by storing it as an untyped pointer.
    this(MTLDevice device)
    {
        raw = cast(void*) device;
    }

    /**
     * Asks the device for a new buffer of `sizeInBytes` bytes, created with
     * `MTLResourceOptions.StorageModeShared`.
     *
     * Returns: whatever the binding's `newBuffer` returns; this function does not
     * check the result.
     */
    MTLBuffer newBuffer(size_t sizeInBytes)
    {
        return mtlDevice.newBuffer(sizeInBytes, MTLResourceOptions.StorageModeShared);
    }

    /**
     * Creates a device buffer sized for `hostMemory` and copies the host data into it.
     *
     * Params:
     *   hostMemory = source elements; the slice is also stored in the returned `Buffer`
     *
     * Returns: a `Buffer!T` holding the new Metal buffer and the host slice. When the
     * slice is empty (or its pointer is null) no copy is made and the Metal buffer is
     * returned exactly as `newBuffer` produced it.
     *
     * The process is halted with a diagnostic if data needs to be copied but the
     * Metal buffer could not be created.
     */
    Buffer!T makeBuffer(T)(T[] hostMemory)
    {
        import core.stdc.stdio : printf;

        immutable size_t sizeInBytes = hostMemory.length * T.sizeof;

        auto mtlBuffer = newBuffer(sizeInBytes);
        auto buffer = Buffer!T(mtlBuffer, hostMemory);

        if (buffer.hostMemory.ptr !is null && sizeInBytes > 0)
        {
            // Without this check a failed allocation would reach memcpy with a
            // null destination; fail explicitly instead.
            if (buffer.mtlBuffer is null)
            {
                printf("Error: Could not allocate a Metal buffer of %zu bytes.\n", sizeInBytes);
                assert(0);
            }

            memcpy(buffer.mtlBuffer.contents(), buffer.hostMemory.ptr, sizeInBytes);
        }

        return buffer;
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| module dcompute.driver.metal.kernel; | ||
| import metal.library; | ||
|
|
||
/**
 * Handle for a kernel function, tagged at compile time with the kernel's type.
 *
 * Params:
 *   F = a function type, or `void` for a handle that carries no signature
 */
struct Kernel(F)
if (is(F == void) || is(F == function))
{
    /// The wrapped Metal function object.
    MTLFunction kernelFunction;

    /// Wraps the given Metal function object.
    this(MTLFunction _kernelFunction)
    {
        this.kernelFunction = _kernelFunction;
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| module dcompute.driver.metal; | ||
| import ldc.dcompute; | ||
| import std.range; | ||
| import std.meta; | ||
| import std.traits; | ||
|
|
||
| public import dcompute.driver.metal.buffer; | ||
| public import dcompute.driver.metal.device; | ||
| public import dcompute.driver.metal.kernel; | ||
| public import dcompute.driver.metal.platform; | ||
| public import dcompute.driver.metal.program; | ||
| public import dcompute.driver.metal.queue; | ||
|
|
||
|
|
||
/**
 * Maps the parameter types of the function type `F` to host-side argument types:
 * every parameter matching `Pointer!(n, U)` becomes `Buffer!U`, and every other
 * parameter type is kept unchanged. The result is a type sequence in parameter order.
 */
template HostArgsOf(F)
{
    // Per-parameter mapping applied through staticMap below.
    template HostType(Param)
    {
        static if (is(Param : Pointer!(n, U), uint n, U))
        {
            alias HostType = Buffer!U;
        }
        else
        {
            alias HostType = Param;
        }
    }

    alias HostArgsOf = staticMap!(HostType, Parameters!F);
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| module dcompute.driver.metal.platform; | ||
|
|
||
| import dcompute.driver.metal.device; | ||
| import metal.device; | ||
|
|
||
/// Static entry points for obtaining `Device` wrappers.
struct Platform
{
    /**
     * Wraps each entry returned by `MTLCopyAllDevices()` in a `Device`.
     *
     * Returns: a newly allocated array with one `Device` per entry, in the same order.
     */
    static Device[] getDevices()
    {
        auto mtlDevices = MTLCopyAllDevices();
        auto devices = new Device[mtlDevices.length];

        foreach (i, ref slot; devices)
            slot = Device(mtlDevices[i]);

        return devices;
    }

    /// Wraps the result of `MTLCreateSystemDefaultDevice()` in a `Device`.
    static Device getDefaultDevice()
    {
        return Device(MTLCreateSystemDefaultDevice());
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| module dcompute.driver.metal.program; | ||
| import dcompute.driver.metal.device; | ||
| import dcompute.driver.metal.kernel; | ||
| import objc; | ||
| import foundation; | ||
| import core.stdc.stdio; | ||
| import std.string; | ||
| import std.path; | ||
| import metal.library; | ||
| import metal.device; | ||
|
|
||
/**
 * A loaded Metal library together with the device it was loaded on.
 */
struct Program
{
    /// The loaded Metal library that kernels are looked up in.
    MTLLibrary metalLibrary;

    /// The device passed to `fromFile` when the library was loaded.
    Device device;

    /**
     * Looks up a function in the library by name.
     *
     * Params:
     *   name = zero-terminated function name
     *
     * Returns: an untyped kernel handle. If the library has no such function,
     * an error is printed and the program halts.
     */
    Kernel!void getKernelByName(immutable(char)* name)
    {
        auto functionName = NSString.create(fromStringz(name));
        auto kernelFunction = metalLibrary.newFunctionWithName(functionName);

        if (kernelFunction !is null)
            return Kernel!void(kernelFunction);

        printf("Error: Could not find kernel function %s in library.\n", name);
        assert(0);
    }

    /**
     * Looks up the kernel for symbol `k` using its mangled name and returns the
     * handle typed with `typeof(k)`.
     */
    Kernel!(typeof(k)) getKernel(alias k)()
    {
        return cast(typeof(return)) getKernelByName(k.mangleof.ptr);
    }

    /**
     * Loads a library from a file on the given device.
     *
     * Params:
     *   device = device used to load the library; also stored in the result
     *   path   = file path; it is made absolute before being turned into a URL
     *
     * Returns: the new `Program`. If loading fails, the error description is
     * printed and the program halts.
     */
    static Program fromFile(Device device, string path)
    {
        NSError error;
        auto libraryUrl = NSURL.fromPath(NSString.create(absolutePath(path)));

        auto library = device.mtlDevice.newLibrary(libraryUrl, error);

        if (library is null)
        {
            printf("Error loading .metallib: %s\n", error.localizedDescription().ptr);
            assert(0);
        }

        return Program(library, device);
    }

    /// Process-wide program instance (declared `__gshared`).
    __gshared static Program globalProgram;

    /// Clears the reference to the Metal library.
    void unload()
    {
        metalLibrary = null;
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| module dcompute.driver.metal.queue; | ||
| import dcompute.driver.metal.buffer; | ||
|
|
||
| import dcompute.driver.metal; | ||
| import dcompute.driver.metal.device; | ||
| import dcompute.driver.metal.program; | ||
| import metal; | ||
| import metal.argument; | ||
| import metal.types; | ||
| import core.stdc.stdio; | ||
| import objc; | ||
| import foundation; | ||
|
|
||
| struct Queue | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A pipeline state seems much more associated with a Queue.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, agreed; I moved it to Queue as it fits better there. |
||
| { | ||
| Device device; | ||
| MTLCommandQueue commandQueue; | ||
| MTLCommandBuffer lastActiveBuffer; | ||
|
|
||
| // TODO(asadbek): explore options to make the use of async execution with events | ||
| this (Device _device /*bool async*/) | ||
| { | ||
| device = _device; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `device` is unused outside of the constructor; why cache a reference to it in the Queue?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't know how to encode it in any other way; I used |
||
| commandQueue = device.mtlDevice.newCommandQueue(); | ||
| } | ||
|
|
||
/**
 * Prepares a launch of kernel `k` on this queue.
 *
 * Params:
 *   _grid  = three extents, passed as the first size argument of `dispatchThreads`
 *   _block = three extents, passed as the second size argument of `dispatchThreads`
 *
 * Returns: a callable object. Invoking it with the kernel's host-side arguments
 * (see `HostArgsOf`) builds a pipeline state, encodes the arguments and the
 * dispatch, commits the command buffer and records it in `lastActiveBuffer`.
 */
auto enqueue(alias k)(uint[3] _grid, uint[3] _block)
{
    static struct Call
    {
        Queue* q;
        uint[3] grid, block;

        this(Queue* _q, uint[3] _grid, uint[3] _block)
        {
            this.q = _q;
            this.grid = _grid;
            this.block = _block;
        }

        void opCall(HostArgsOf!(typeof(k)) args)
        {
            NSError error;

            // The kernel is resolved from the process-wide Program.globalProgram.
            auto kernel = Program.globalProgram.getKernel!k();
            auto mtlDevice = q.device.mtlDevice;

            auto pipelineState = mtlDevice.newComputePipelineStateWithFunction(
                kernel.kernelFunction, MTLPipelineOption.None, null, error);

            if (pipelineState is null)
            {
                printf("Error: Backend compilation failed: %s\n", error.localizedDescription().ptr);
                assert(0);
            }

            auto commandBuffer = q.commandQueue.commandBuffer();
            auto encoder = commandBuffer.computeCommandEncoder();
            encoder.setComputePipelineState(pipelineState);

            // Bind each argument at the index matching its position in the call.
            foreach (slot, arg; args)
            {
                alias Arg = typeof(arg);

                static if (is(Arg : Buffer!U, U))
                {
                    // Buffers are bound by their Metal buffer object, at offset 0.
                    encoder.setBuffer(arg.mtlBuffer, 0, slot);
                }
                else static if (__traits(isScalar, Arg))
                {
                    // Scalars are handed over by address and size.
                    encoder.setBytes(&arg, Arg.sizeof, slot);
                }
                else
                {
                    static assert(0, "Unsupported argument type for Metal kernel dispatch!");
                }
            }

            // NOTE(review): `grid` is handed to dispatchThreads as-is. In Metal's API,
            // dispatchThreads takes the total number of threads in the grid, whereas
            // dispatchThreadgroups takes a count of threadgroups. Confirm which
            // meaning callers of enqueue expect.
            auto gridSize = MTLSize(grid[0], grid[1], grid[2]);
            auto blockSize = MTLSize(block[0], block[1], block[2]);
            encoder.dispatchThreads(gridSize, blockSize);

            encoder.endEncoding();
            commandBuffer.commit();

            // Remembered so that Queue.finish() can wait on it.
            q.lastActiveBuffer = commandBuffer;
        }
    }

    return Call(&this, _grid, _block);
}
|
|
||
/**
 * Waits for the most recently committed command buffer, if there is one,
 * then releases it and clears `lastActiveBuffer`. Does nothing otherwise.
 */
void finish()
{
    if (lastActiveBuffer is null)
        return;

    lastActiveBuffer.waitUntilCompleted();
    lastActiveBuffer.release();

    lastActiveBuffer = null;
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does the program contain a device?
The program represents the code.
It seems to me that you should split out the concept of a kernel from its binding to a pipeline state object.