WebGPU-based successor for GPUjs #822

Open
joshbrew opened this issue Jul 4, 2023 · 7 comments

@joshbrew

joshbrew commented Jul 4, 2023

Okay fellas, we now have native compute support on the web, and running compute kernels and reading back their output is possible, e.g. this tutorial: https://jott.live/markdown/webgpu_safari

Has anyone started on this, or where do we want to start the discussion of a WebGPU-based spiritual successor for GPUjs? GPUjs's JavaScript-based kernel generation is genius for learning GPU coding and helped me immensely, and the performance is fine too. It's not really that complicated under the hood, since it's mostly transposing text for you. The WebGPU compute pipeline adds some boilerplate for setting up command/storage buffers and data structures, though from what I've seen it could definitely be macro'd.
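To make the "just transposing text" point concrete, here's a minimal sketch in plain JavaScript. `transposeToWGSL` is a hypothetical helper, and its handful of regex rules are assumptions for illustration (they only cover the exact patterns shown); it is not GPUjs internals or a complete transpiler:

```javascript
// Hypothetical helper: rewrite a few common JavaScript patterns into WGSL
// with regular expressions. Only the exact patterns below are handled.
function transposeToWGSL(jsBody) {
  let code = jsBody;
  // for (let i = 0; i < n; i++)  ->  for (var i = 0u; i < n; i = i + 1u)
  code = code.replace(
    /for \(let (\w+) = (\d+); \1 < (\w+); \1\+\+\)/g,
    'for (var $1 = $2u; $1 < $3; $1 = $1 + 1u)'
  );
  // const n = arr.length;  ->  let n = arrayLength(&arr.values);
  code = code.replace(/const (\w+) = (\w+)\.length;/g, 'let $1 = arrayLength(&$2.values);');
  // Remaining `const` declarations become WGSL `let` (immutable) declarations
  code = code.replace(/\bconst\b/g, 'let');
  // arr[i]  ->  arr.values[i]  (assumes each array binding is a struct with a `values` array)
  code = code.replace(/(\w+)\[(\w+)\]/g, '$1.values[$2]');
  // Math.sqrt(x) -> sqrt(x): these WGSL builtins share their names with Math.*
  code = code.replace(/Math\.(sqrt|sin|cos|abs|floor|ceil|exp|log|pow|min|max)\b/g, '$1');
  return code;
}

console.log(transposeToWGSL(
  'const n = a.length;\nfor (let i = 0; i < n; i++) { out[i] = Math.sqrt(a[i]); }'
));
```

A real converter also has to emit the binding declarations and infer types, which is where most of the remaining work is.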

Another thing to think about is chaining compute and fragment shaders: a compute shader can handle rasterization, possibly faster, while the fragment shader simply dumps the resulting image matrix to the screen. E.g. https://github.com/OmarShehata/webgpu-compute-rasterizer/blob/main/how-to-build-a-compute-rasterizer.md
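For the compute-rasterizer idea, the core per-pixel work is a coverage test. Here's a CPU reference sketch in plain JavaScript using the standard edge-function (barycentric) test; `edge` and `rasterizeTriangle` are hypothetical helpers, not code from the linked article. On the GPU, each iteration of the pixel loops would become one invocation writing into a storage buffer that a fragment shader then draws:

```javascript
// Hypothetical CPU reference for the per-pixel coverage test a compute
// rasterizer performs. On the GPU, the body of the two loops would run once
// per invocation (one pixel each), writing into a storage buffer.

// Edge function: twice the signed area of triangle (a, b, p).
function edge(ax, ay, bx, by, px, py) {
  return (bx - ax) * (py - ay) - (by - ay) * (px - ax);
}

// Fill `value` into every pixel whose centre lies inside the triangle.
// Returns a row-major Float32Array of width * height pixels.
function rasterizeTriangle(width, height, [x0, y0], [x1, y1], [x2, y2], value = 1) {
  const image = new Float32Array(width * height);
  const area = edge(x0, y0, x1, y1, x2, y2);
  if (area === 0) return image; // degenerate triangle covers nothing
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const px = x + 0.5;
      const py = y + 0.5;
      // Barycentric weights; dividing by `area` makes the test winding-independent
      const w0 = edge(x1, y1, x2, y2, px, py) / area;
      const w1 = edge(x2, y2, x0, y0, px, py) / area;
      const w2 = edge(x0, y0, x1, y1, px, py) / area;
      if (w0 >= 0 && w1 >= 0 && w2 >= 0) image[y * width + x] = value;
    }
  }
  return image;
}

// A right triangle in a 4x4 image covers the 10 pixels with x + y <= 3
console.log(rasterizeTriangle(4, 4, [0, 0], [4, 0], [0, 4]));
```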

@joshbrew

joshbrew commented Aug 24, 2023

Hey all, so GPT-4 and I got to talking, and I got a start on this. Here is a sample that converts a DFT function. It's a little jank and needs a lot of work, e.g. inferring inputs/outputs, but here's to a start :-). Give me a couple more days to think, then I'm parking this.
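Inferring inputs/outputs can start from the function's own source text via `Function.prototype.toString()`. A minimal sketch, assuming parenthesised parameter lists whose defaults contain no commas, and outputs written as `return x;` or `return [x, y];`; `inferInputsOutputs` and the `scale` kernel are hypothetical names:

```javascript
// Hypothetical helper: infer a kernel's inputs (parameter names) and outputs
// (identifiers it returns) from its source text. Assumes a parenthesised
// parameter list whose defaults contain no commas, and no nested functions.
function inferInputsOutputs(fn) {
  const src = fn.toString();
  // Parameters: the text between the first "(" and its matching ")"
  const open = src.indexOf('(');
  let depth = 0;
  let close = open;
  for (; close < src.length; close++) {
    if (src[close] === '(') depth++;
    else if (src[close] === ')' && --depth === 0) break;
  }
  const inputs = src
    .slice(open + 1, close)
    .split(',')
    .map((p) => p.split('=')[0].trim()) // drop default values
    .filter(Boolean);

  // Outputs: identifiers in `return x;` or `return [x, y];` statements
  const outputs = [];
  for (const [, expr] of src.matchAll(/\breturn\s+([^;]+);/g)) {
    const names = expr.trim().replace(/^\[|\]$/g, '').split(',').map((s) => s.trim());
    for (const name of names) {
      if (/^\w+$/.test(name) && !outputs.includes(name)) outputs.push(name);
    }
  }
  return { inputs, outputs };
}

// Example: a hypothetical kernel where `out` is both a parameter and the output
function scale(values, factor = 2.0, out = []) {
  for (let i = 0; i < values.length; i++) out[i] = values[i] * factor;
  return out;
}
console.log(inferInputsOutputs(scale));
```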

Just save this as an HTML file, open it in a WebGPU-capable browser, and check the console for the readout. Scroll to the bottom of the code to see the JavaScript function being converted.

Turns out transposing to WebGPU is really not that challenging. There are some challenges in inferring types and so on, but I'll have something in place shortly, just out of my own curiosity.
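On the type-inference side, here's a minimal sketch (hypothetical `inferWGSLType`, plain JavaScript) of mapping runtime arguments to WGSL binding types. It assumes typed arrays map one-to-one to WGSL element types, and that plain numbers and arrays default to `i32` when every value is an integer, else `f32`; that default is a guess, not a WebGPU rule:

```javascript
// Hypothetical helper: map a JavaScript argument to the WGSL type of its binding.
// Typed arrays map to runtime-sized arrays of the matching element type; plain
// numbers and arrays fall back to i32 when every value is an integer, else f32
// (an assumed default: JavaScript has a single number type, so 1.0 === 1).
function inferWGSLType(value) {
  if (value instanceof Float32Array) return 'array<f32>';
  if (value instanceof Int32Array) return 'array<i32>';
  if (value instanceof Uint32Array) return 'array<u32>';
  if (ArrayBuffer.isView(value)) {
    // WGSL has no f64, i8, u8, i16 or u16 types (f16 needs the optional "shader-f16" feature)
    throw new TypeError(`${value.constructor.name} has no direct WGSL equivalent`);
  }
  if (Array.isArray(value)) {
    const flat = value.flat(Infinity); // nested arrays get flattened into one buffer
    return flat.length > 0 && flat.every(Number.isInteger) ? 'array<i32>' : 'array<f32>';
  }
  if (typeof value === 'number') return Number.isInteger(value) ? 'i32' : 'f32';
  throw new TypeError(`Cannot infer a WGSL type for a ${typeof value}`);
}

console.log(inferWGSLType(new Float32Array(8)), inferWGSLType([[1, 2.5], [3, 4]]), inferWGSLType(42));
```

Note that `[1.0, 2.0]` comes out as `array<i32>`: at runtime JavaScript can't tell `1.0` from `1`, which is one reason the converter below also inspects the source text (the `.` check in `inferTypeFromValue`).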

<html>
    <head></head>
    <body>
        <script>
class WebGPUjs {
    constructor() {
        this.bindings = [];
    }

    static async createPipeline(computeFunction, device = null) {
        if (!device) {
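            const gpu = navigator.gpu;
            if (!gpu) throw new Error('WebGPU is not supported in this browser');
            const adapter = await gpu.requestAdapter();
            if (!adapter) throw new Error('No suitable GPU adapter found');
            device = await adapter.requestDevice();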
        }
        const processor = new WebGPUjs();

        let shader = computeFunction; let ast;
        if(typeof computeFunction === 'function') {
            let result = processor.convertToWebGPU(computeFunction); 
            shader = result.shader; ast = result.ast;
        }
        await processor.init(shader, ast, undefined, device);
        return processor;
    }

    async init(computeShader, ast, bindGroupLayoutSettings, device=this.device) {
        this.device = device;


        // Extract all returned variables from the function string
        // (this.fstr may be undefined when a raw WGSL string is passed instead of a JS function)
        const returnMatches = this.fstr?.match(/^(?![ \t]*\/\/).*\breturn .*;/gm);
        let returnedVars = returnMatches ? returnMatches.map(match => match.replace(/^[ \t]*return /, '').replace(';', '')) : undefined;

        if (returnedVars) returnedVars = this.flattenStrings(returnedVars);

        if(ast) {
            let bufferIncr = 0;
            let uniformBufferIdx;
            let filtered = ast.filter((v) => v.isInput || returnedVars?.includes(v.name));
            const entries = filtered.map((node, i) => {
                let isReturned = (returnedVars === undefined || returnedVars?.includes(node.name));
                if (node.isUniform) {
                    if(typeof uniformBufferIdx === 'undefined') {
                        uniformBufferIdx = i; 
                        bufferIncr++; 
                        return {
                            binding: uniformBufferIdx,
                            visibility: GPUShaderStage.COMPUTE,
                            buffer: {
                                type: 'uniform'
                            }
                        };
                    }
                    return undefined;
                }
                else {
                    const buffer = {
                        binding: bufferIncr,
                        visibility: GPUShaderStage.COMPUTE,
                        buffer: {
                            type: (isReturned || node.isModified) ? 'storage' : 'read-only-storage' // writable if the kernel returns or modifies it
                        }
                    };
                    bufferIncr++;
                    return buffer;
                }
            }).filter(v => v);

            this.bindGroupLayout = this.device.createBindGroupLayout({
                entries
            });
        }
        else if (bindGroupLayoutSettings) this.bindGroupLayout = this.device.createBindGroupLayout(bindGroupLayoutSettings);


        this.pipelineLayout = this.device.createPipelineLayout({
            bindGroupLayouts: [this.bindGroupLayout]
        });

        this.shader = computeShader;
        this.shaderModule = this.device.createShaderModule({
            code:computeShader
        });

        this.computePipeline = this.device.createComputePipeline({
            layout: this.pipelineLayout,
            compute: {
                module: this.shaderModule,
                entryPoint: 'main'
            }
        });

        return this.computePipeline;
    }

    // Helper function to determine the type and size of the input
    getInputTypeInfo(input, idx) {

        const typeName = input.constructor.name;
        
        // Check the variable registry for the type
        const param = this.params[idx];
        if (param) {
            // Vectors and matrices: use the declared WGSL type and its size from the type table
            if (param.type.startsWith('mat') || param.type.startsWith('vec')) {
                return { type: param.type, byteSize: wgslTypeSizes[param.type].size };
            }
            // Scalars: use the type declared in the generated UniformsStruct, so the bytes
            // written below match what the shader reads (a JS integer may be declared f32)
            if (param.type === 'f32' || param.type === 'i32' || param.type === 'u32') {
                return { type: param.type, byteSize: 4 };
            }
        }


        switch (typeName) {
            case 'Float32Array':
                return { type: 'f32', byteSize: 4 };
            case 'Int32Array':
                return { type: 'i32', byteSize: 4 };
            case 'Uint32Array':
                return { type: 'u32', byteSize: 4 };
                
            // None of these are valid WGSL types (f16 needs the optional 'shader-f16' feature)
            case 'Float64Array':
                return { type: 'f64', byteSize: 8 };
            case 'Float16Array': //does not exist in javascript
                return { type: 'f16', byteSize: 2 };
            case 'Int16Array':
                return { type: 'i16', byteSize: 2 };
            case 'Uint16Array':
                return { type: 'u16', byteSize: 2 };
            case 'Int8Array':
                return { type: 'i8', byteSize: 1 };
            case 'Uint8Array':
                return { type: 'u8', byteSize: 1 };
        }

        if (typeof input === 'number') {
            if (Number.isInteger(input)) {
                return { type: 'i32', byteSize: 4 }; //u32??
            } else {
                return { type: 'f32', byteSize: 4 };
            }
        } 

        // Add more conditions for matrices and other types if needed
        return { type: 'unknown', byteSize: 0 };
    }

    flattenArray(arr) {
        let result = [];
        for (let i = 0; i < arr.length; i++) {
            if (Array.isArray(arr[i])) {
                result = result.concat(this.flattenArray(arr[i]));
            } else {
                result.push(arr[i]);
            }
        }
        return result;
    }

    process(...inputs) {

        const inputTypes = [];
        inputs.forEach((input, idx) => {
            inputTypes.push(this.getInputTypeInfo(input, idx))
        })

        const allSameSize = this.inputBuffers && inputs.every((inputArray, index) => 
            this.inputBuffers[index].byteLength === inputArray.length * inputTypes[index].byteSize
        );

        if (!allSameSize) {
            // First run or input sizes changed: recreate the input, uniform and output buffers

            this.inputBuffers = [];
            this.uniformBuffer = undefined;
            this.outputBuffers = [];

        }

        let uBufferPushed = false;
        let inputBufferIndex = 0;
        let hasUniformBuffer = 0;
        this.params.forEach((node, i) => {
            if(node.isUniform) {
                // Create the uniform buffer once, sized to hold every uniform input
                if (!this.uniformBuffer) {

                    let totalUniformBufferSize = 0;
                    this.ast.forEach((node,j) => {
                        if(node.isInput && node.isUniform){
                            const { size, alignment } = wgslTypeSizes[inputTypes[j].type];
                            // Align this member's offset to its WGSL alignment, then add its size
                            totalUniformBufferSize = Math.ceil(totalUniformBufferSize / alignment) * alignment + size;
                        }
                    });

                    // Round up to a multiple of 16 bytes so the buffer covers the struct's trailing
                    // padding (rounding down could leave it smaller than the struct)
                    totalUniformBufferSize = Math.ceil(totalUniformBufferSize / 16) * 16;
                    
                    this.uniformBuffer = this.device.createBuffer({
                        size: totalUniformBufferSize, // This should be the sum of byte sizes of all uniforms
                        usage: GPUBufferUsage.UNIFORM  | GPUBufferUsage.COPY_SRC,
                        mappedAtCreation: true
                    });
                    this.inputBuffers.push(this.uniformBuffer);
                }
                if(!hasUniformBuffer) {
                    hasUniformBuffer = 1;
                    inputBufferIndex++;
                }
            }
            // Create or recreate input buffers
            else {
                if (!allSameSize) {
                    if (inputs[i] == null) {
                        if (i >= inputs.length) {
                            // Temporary: we don't have a way to estimate the size of data structures
                            // generated inside the kernel, so pad them with zeros
                            if (node.type.includes('vec') || node.type.includes('mat'))
                                inputs[i] = new Float32Array(16); // e.g. a mat4x4<f32>
                            else if (node.type.includes('array'))
                                inputs[i] = new Float32Array(65536); // e.g. a dynamically sized f32 array
                            else inputs[i] = new Float32Array(1); // a single number
                        } else {
                            throw new Error("Missing input at argument " + i + ". Type: " + this.params[i].type);
                        }
                    }
                    if(!inputs[i].byteLength && Array.isArray(inputs[i][0])) inputs[i] = this.flattenArray(inputs[i]);
                    this.inputBuffers.push(
                        this.device.createBuffer({
                            size: inputs[i].byteLength ? inputs[i].byteLength : inputs[i].length*4,
                            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
                            mappedAtCreation: true
                        })  
                    );
                }

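                // Write with the input's own typed-array type so i32/u32 data isn't converted to f32
                const ArrayType = (inputs[i] instanceof Int32Array || inputs[i] instanceof Uint32Array) ? inputs[i].constructor : Float32Array;
                new ArrayType(this.inputBuffers[inputBufferIndex].getMappedRange()).set(inputs[i]);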
                this.inputBuffers[inputBufferIndex].unmap();
                inputBufferIndex++;
            }

            if(node.isReturned && (!node.isUniform || (node.isUniform && !uBufferPushed))) {
                // Create or recreate the output buffers for all returned variables
                if(!node.isUniform) {
                    this.outputBuffers.push(this.inputBuffers[this.inputBuffers.length - 1]);
                } else if(!uBufferPushed) {
                    uBufferPushed = true;
                    this.outputBuffers.push(this.uniformBuffer);
                }
            }
        
        });

        if(this.uniformBuffer) {

            // Use a DataView to set values at specific byte offsets
            const dataView = new DataView(this.uniformBuffer.getMappedRange());

            let offset = 0; // Initialize the offset

            this.ast.forEach((node, i) => {
                // Write every uniform declared in UniformsStruct, in declaration order
                if(node.isInput && node.isUniform) {
                    const typeInfo = wgslTypeSizes[inputTypes[i].type];

                    // Ensure the offset is aligned correctly
                    offset = Math.ceil(offset / typeInfo.alignment) * typeInfo.alignment;
                    if (inputTypes[i].type.startsWith('vec')) {
                        const vecSize = typeInfo.size / 4;
                        for (let j = 0; j < vecSize; j++) {
                            //console.log(dataView,offset + j * 4)
                            dataView.setFloat32(offset + j * 4, inputs[i][j], true);
                        }
                    } else if (inputTypes[i].type.startsWith('mat')) {
                        const flatMatrix = this.flattenArray(inputs[i]);
                        for (let j = 0; j < flatMatrix.length; j++) {
                            dataView.setFloat32(offset + j * 4, flatMatrix[j], true); //we don't have Float16 in javascript :-\
                        }
                    } else{
                        switch (inputTypes[i].type) {
                            case 'f32':
                                dataView.setFloat32(offset, inputs[i], true); // true for little-endian
                                break;
                            case 'i32':
                                dataView.setInt32(offset, inputs[i], true); // true for little-endian
                                break;
                            case 'u32':
                                dataView.setUint32(offset, inputs[i], true); // true for little-endian
                                break;
                            
                        }
                    }

                    offset += typeInfo.size; // Increment the offset by the size of the type
                }
                
            });


            this.uniformBuffer.unmap();
        }

        if(!allSameSize) {
            // Update bind group creation to include both input and output buffers
            const bindGroupEntries = [...this.inputBuffers].map((buffer, index) => ({
                binding: index,
                resource: { buffer }
            })); //we are inferring outputBuffers from inputBuffers

            this.bindGroup = this.device.createBindGroup({
                layout: this.bindGroupLayout,
                entries: bindGroupEntries
            });
        }
        

        const commandEncoder = this.device.createCommandEncoder();
        const passEncoder = commandEncoder.beginComputePass();

        passEncoder.setPipeline(this.computePipeline);
        passEncoder.setBindGroup(0, this.bindGroup);
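        // The divisor must match the shader's @workgroup_size
        // (generateMainFunctionWorkGroup defaults to 256)
        passEncoder.dispatchWorkgroups(Math.ceil(inputs[0].length / 64)); // Assuming all inputs are of the same size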

        passEncoder.end();

        // Create staging buffers for all output buffers
        const stagingBuffers = this.outputBuffers.map(outputBuffer => {
            return this.device.createBuffer({
                size: outputBuffer.size,
                usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
            });
        });

        // Copy data from each output buffer to its corresponding staging buffer
        this.outputBuffers.forEach((outputBuffer, index) => {
            commandEncoder.copyBufferToBuffer(
                outputBuffer, 0,
                stagingBuffers[index], 0,
                outputBuffer.size
            );
        });

        this.device.queue.submit([commandEncoder.finish()]);

        const promises = stagingBuffers.map(async (buffer) => {
            await buffer.mapAsync(GPUMapMode.READ);
            // Copy out of the mapped range before unmap(), which detaches it
            const results = new Float32Array(buffer.getMappedRange().slice(0));
            buffer.unmap();
            buffer.destroy(); // staging buffers are recreated on every call
            return results;
        });

        return promises.length === 1 ? promises[0] : Promise.all(promises);
    }

    getFunctionHead = (methodString) => {
        let startindex = methodString.indexOf('=>')+1;
        if(startindex <= 0) {
            startindex = methodString.indexOf('){');
        }
        if(startindex <= 0) {
            startindex = methodString.indexOf(') {');
        }
        return methodString.slice(0, methodString.indexOf('{',startindex) + 1);
    }

    splitIgnoringBrackets = (str) => {
        const result = [];
        let depth = 0; // depth of nested structures
        let currentToken = '';

        for (let i = 0; i < str.length; i++) {
            const char = str[i];

            if (char === ',' && depth === 0) {
                result.push(currentToken);
                currentToken = '';
            } else {
                currentToken += char;
                if (char === '(' || char === '[' || char === '{') {
                    depth++;
                } else if (char === ')' || char === ']' || char === '}') {
                    depth--;
                }
            }
        }

        // Push the final token (the last argument has no trailing comma)
        if (currentToken) {
            result.push(currentToken);
        }

        return result;
    }

    tokenize(funcStr) {
        // Capture function parameters
        let head = this.getFunctionHead(funcStr);
        let paramString = head.substring(head.indexOf('(') + 1, head.lastIndexOf(')'));
        let params = this.splitIgnoringBrackets(paramString).map(param => ({
            token: param,
            isInput: true
        }));
        
        // Capture variables, arrays, and their assignments
        const assignmentTokens = (funcStr.match(/(const|let|var)\s+(\w+)\s*=\s*([^;]+)/g) || []).map(token => ({
            token,
            isInput: false
        }));


        // Combine both sets of tokens
        return params.concat(assignmentTokens);
    }

    parse = (tokens) => {
        const ast = [];

        // Extract all returned variables from the tokens
        const returnMatches = this.fstr.match(/^(?![ \t]*\/\/).*\breturn .*;/gm);
        let returnedVars = returnMatches ? returnMatches.map(match => match.replace(/^[ \t]*return /, '').replace(';', '')) : undefined;

        returnedVars = this.flattenStrings(returnedVars);


        const functionBody = this.fstr.substring(this.fstr.indexOf('{')); 
        // Basic function splitting; we don't support object inputs right now anyway (e.g. we could add {x,y,z} objects to define vectors)

        tokens.forEach(({ token, isInput }) => {
            let isReturned = returnedVars?.find((v) => token.includes(v));
            let isModified = new RegExp(`\\b${token.split('=')[0]}\\b(\\[\\w+\\])?\\s*=`).test(functionBody);

            if (token.includes('=')) {
                const variableMatch = token.match(/(const|let|var)?\s*(\w+)\s*=\s*(.+)/);
                if (variableMatch[3].startsWith('new Array') || variableMatch[3].startsWith('[')) {
                    ast.push({
                        type: 'array',
                        name: variableMatch[2],
                        value: variableMatch[3],
                        isInput,
                        isReturned: returnedVars ? returnedVars?.includes(variableMatch[2]) : isInput ? true : false,
                        isModified
                    });
                } else if (token.startsWith('vec') || token.startsWith('mat')) {
                    const typeMatch = token.match(/(vec\d|mat\d+x\d+)\(([^)]+)\)/);
                    if (typeMatch) {
                        ast.push({
                            type: typeMatch[1],
                            name: token.split('=')[0],
                            value: typeMatch[2],
                            isInput,
                            isReturned: returnedVars ? returnedVars?.includes(token.split('=')[0]) : isInput ? true : false,
                            isModified
                        });
                    }
            }   else {
                    ast.push({
                        type: 'variable',
                        name: variableMatch[2],
                        value: variableMatch[3],
                        isUniform:true,
                        isInput,
                        isReturned: returnedVars ? returnedVars?.includes(variableMatch[2]) : isInput ? true : false,
                        isModified
                    });
                }
            } else if (token.includes('new Array') || token.includes('[')) {
                // This is a function parameter that is an array
                const paramName = token.split('=')[0];
                ast.push({
                    type: 'array',
                    name: paramName,
                    value: token,
                    isInput,
                    isReturned,
                    isModified
                });
            } else if (token.startsWith('vec') || token.startsWith('mat')) {
                const typeMatch = token.match(/(vec\d|mat\d+x\d+)\(([^)]+)\)/);
                if (typeMatch) {
                    ast.push({
                        type: typeMatch[1],
                        name: token.split('=')[0],
                        value: typeMatch[2],
                        isInput,
                        isReturned: returnedVars ? returnedVars?.includes(token.split('=')[0]) : isInput ? true : false,
                        isModified
                    });
                }
            }   else {
                // This is a function parameter without a default value
                ast.push({
                    type: 'variable',
                    name: token,
                    value: 'unknown',
                    isUniform:true,
                    isInput,
                    isReturned,
                    isModified
                });
            }
        });
        this.ast = ast;


        return ast;
    }

    inferTypeFromValue(value, funcStr, ast) {
        if (value.startsWith('vec')) {
            const type = value.includes('.') ? '<f32>' : '<i32>';
            return value.match(/vec(\d)/)[0] + type;
        } else if (value.startsWith('mat')) {
            const type = value.includes('.') ? '<f32>' : '<i32>';
            return value.match(/mat(\d)x(\d)/)[0] + type;
        } else if (value.startsWith('[')) {
            // Infer the type from the first element if the array is initialized with values
            const firstElement = value.split(',')[0].substring(1);
            if(firstElement === ']') return 'array<f32>';
            if (firstElement.startsWith('[') && !firstElement.endsWith(']')) {
                // Only recurse if the first element is another array and not a complete array by itself
                return this.inferTypeFromValue(firstElement, funcStr, ast);
            } else {
                // Check if the first element starts with vec or mat
                if (firstElement.startsWith('vec') || firstElement.startsWith('mat')) {
                    return `array<${this.inferTypeFromValue(firstElement, funcStr, ast)}>`;
                } else if (firstElement.includes('.')) {
                    return 'array<f32>';
                } else if (!isNaN(firstElement)) {
                    return 'array<i32>';
                }
            }
        } else if (value.startsWith('new Array')) {
            // If the array is initialized using the `new Array()` syntax, look for assignments in the function body
            const arrayNameMatch = value.match(/let\s+(\w+)\s*=/);
            if (arrayNameMatch) {
                const arrayName = arrayNameMatch[1];
                const assignmentMatch = funcStr.match(new RegExp(`${arrayName}\\[\\d+\\]\\s*=\\s*(.+?);`));
                if (assignmentMatch) {
                    return this.inferTypeFromValue(assignmentMatch[1], funcStr, ast);
                }
            } else return 'f32'
        } else if (value.includes('.')) {
            return 'f32';  // Float type for values with decimals
        } else if (!isNaN(value)) {
            return 'i32';  // Int type for whole numbers
        } else {
             // Check if the value is a variable name and infer its type from AST
            const astNode = ast.find(node => node.name === value);
            if (astNode) {
                if (astNode.type === 'array') {
                    return 'f32';  // Assuming all arrays are of type f32 for simplicity
                } else if (astNode.type === 'variable') {
                    return this.inferTypeFromValue(astNode.value, funcStr, ast);
                }
            }
        }
        
        return 'f32';  // For other types
    }

    flattenStrings(arr) {
        if (!arr) return arr; // no return statements found; callers treat undefined as "every input is an output"
        const callback = (item) => {
            if (item.startsWith('[') && item.endsWith(']')) {
                return item.slice(1, -1).split(',').map(s => s.trim());
            }
            return item;
        }
        return arr.reduce((acc, value, index, array) => {
            return acc.concat(callback(value, index, array));
        }, []);
    }

    generateDataStructures(funcStr, ast) {
        let code = '//Bindings (data passed to/from CPU) \n';
        // Extract all returned variables from the function string
        // const returnMatches = this.fstr.match(/^(?![ \t]*\/\/).*\breturn .*;/gm);
        // let returnedVars = returnMatches ? returnMatches.map(match => match.replace(/^[ \t]*return /, '').replace(';', '')) : undefined;

        // returnedVars = this.flattenStrings(returnedVars);

        // Capture all nested functions
        const functionRegex = /function (\w+)\(([^()]*|\((?:[^()]*|\([^()]*\))*\))*\) \{([\s\S]*?)\}/g;
        let modifiedStr = this.fstr;

        let match;
        while ((match = functionRegex.exec(this.fstr)) !== null) {
            // Replace the content of the nested function with a placeholder
            modifiedStr = modifiedStr.replace(match[3], 'PLACEHOLDER');
        }

        // Now, search for return statements in the modified string
        const returnMatches = modifiedStr.match(/^(?![ \t]*\/\/).*\breturn .*;/gm);
        let returnedVars = returnMatches ? returnMatches.map(match => match.replace(/^[ \t]*return /, '').replace(';', '')) : undefined;
        returnedVars = this.flattenStrings(returnedVars);


        let uniformsStruct = 'struct UniformsStruct {\n'; // Start the UniformsStruct
        let hasUniforms = false; // Flag to check if there are any uniforms

        this.params = [];

        let bindingIncr = 0;

        ast.forEach((node, i) => {
            if(returnedVars?.includes(node.name)) node.isInput = true; //catch extra returned variables not in the explicit input buffers (data structures generated by webgpu)
            if(node.isInput) {
                if (node.type === 'array') {
                    this.bindings.push(node.name);
                    const elementType = this.inferTypeFromValue(node.value.split(',')[0], funcStr, ast);
                    
                    node.type = elementType; // Use the inferred type directly
                    this.params.push(node);
                    code += `struct ${capitalizeFirstLetter(node.name)}Struct {\n    values: ${elementType}\n};\n\n`;
                    code += `@group(0) @binding(${bindingIncr})\n`;
                    
                    if (!returnedVars || returnedVars?.includes(node.name) || node.isModified) { // writable if returned or modified, matching the 'storage' layout entry
                        code += `var<storage, read_write> ${node.name}: ${capitalizeFirstLetter(node.name)}Struct;\n\n`;
                    } else {
                        code += `var<storage, read> ${node.name}: ${capitalizeFirstLetter(node.name)}Struct;\n\n`;
                    }
                    bindingIncr++;
                }
                else if (node.isUniform) {
                    if (hasUniforms === false) {
                        // Record the uniform struct's binding index; compare against `false`
                        // explicitly, since binding 0 is a valid (but falsy) index
                        hasUniforms = bindingIncr;
                        bindingIncr++;
                    }
                    this.bindings.push(node.name);
                    const uniformType = this.inferTypeFromValue(node.value, funcStr, ast);
                    node.type = uniformType;
                    this.params.push(node);
                    uniformsStruct += `    ${node.name}: ${uniformType},\n`; // Add the uniform to the UniformsStruct
                }
            }
        });

        uniformsStruct += '};\n\n'; // Close the UniformsStruct

        if (hasUniforms !== false) { // If there are any uniforms, add the UniformsStruct and its binding to the code
            code += uniformsStruct;
            code += `@group(0) @binding(${hasUniforms}) var<uniform> uniforms: UniformsStruct;\n\n`;
        }

        return code;
    }


    extractAndTransposeInnerFunctions = (body, extract=true) => {
        
        const functionRegex = /function (\w+)\(([^()]*|\((?:[^()]*|\([^()]*\))*\))*\) \{([\s\S]*?)\}/g;

        let match;
        let extractedFunctions = '';
        
        while ((match = functionRegex.exec(body)) !== null) {

            const functionHead = match[0];

            let paramString = functionHead.substring(functionHead.indexOf('(') + 1, functionHead.lastIndexOf(')'));

            let outputParam;
            let params = this.splitIgnoringBrackets(paramString).map((p) => { 
                let split = p.split('=');
                let vname = split[0];
        
                let inferredType = this.inferTypeFromValue(split[1], body, this.ast);
                if(!outputParam) outputParam = inferredType;
                return vname+': '+inferredType;
            });

            const funcName = match[1];
            const funcBody = match[3];

            // Transpose the function body
            const transposedBody = this.transposeBody(funcBody, funcBody, null, true); // helper bodies are transposed without the main function's AST

            extractedFunctions += `fn ${funcName}(${params}) -> ${outputParam} {\n${transposedBody}\n}\n\n`;
        }

        // Remove the inner functions from the main body
        if(extract) body = body.replace(functionRegex, '');

        return { body, extractedFunctions };
    }

    generateMainFunctionWorkGroup(funcStr, ast, size=256) {
        let code = '//Main function call\n//globalId tells us what x,y,z thread we are on\n';
        
        if(this.functions) {
            this.functions.forEach((f) => {
                let result = this.extractAndTransposeInnerFunctions(f.toString(), false);
                if(result.extractedFunctions) code += result.extractedFunctions;
            })
        }

        // Extract inner functions and transpose them
        const { body: mainBody, extractedFunctions } = this.extractAndTransposeInnerFunctions(funcStr.match(/{([\s\S]+)}/)[1], true);
        
        // Prepend the transposed inner functions to the main function
        code += extractedFunctions;

        // Generate function signature
        code += '@compute @workgroup_size('+size+')\n';
        code += `fn main(\n  @builtin(global_invocation_id) globalId: vec3<u32>`;

        code += '\n) {\n';

        // Transpose the main body
        code += this.transposeBody(mainBody, funcStr, ast);

        code += '}\n';
        return code;
    }


    transposeBody = (body, funcStr, ast, returns = false) => {
        let code = '';

        // Capture commented lines and replace with a placeholder
        const commentPlaceholders = {};
        let placeholderIndex = 0;
        body = body.replace(/\/\/.*$/gm, (match) => {
            const placeholder = `__COMMENT_PLACEHOLDER_${placeholderIndex}__`;
            commentPlaceholders[placeholder] = match;
            placeholderIndex++;
            return placeholder;
        });

        // Replace common patterns
        code = body.replace(/for \(let (\w+) = (\w+); \1 < (\w+); \1\+\+\)/g, 'for (var $1 = $2u; $1 < $3; $1 = $1 + 1u)');
        code = code.replace(/const (\w+) = (\w+)\.length;/g, 'let $1 = arrayLength(&$2.values);');
        code = code.replace(/const (\w+) = globalId\.(\w+);/g, 'let $1 = globalId.$2;');
        code = code.replace(/\bconst\b/g, 'let'); // word boundaries so identifiers like "constant" are left alone

        const vecMatDeclarationRegex = /let (\w+) = (vec\d+|mat\d+)/g;
        code = code.replace(vecMatDeclarationRegex, 'var $1 = $2');
        // Handle array access
        code = code.replace(/(\w+)\[([\w\s+\-*\/]+)\]/g, '$1.values[$2]');
        
        // Handle array length
        code = code.replace(/(\w+)\.length/g, 'arrayLength(&$1.values)');

    
        // Handle mathematical operations
        code = replaceJSFunctions(code, replacements);

        // Handle vector and matrix creation
        const vecMatCreationRegex = /(vec(\d+)|mat(\d+))\(([^)]+)\)/g;
        code = code.replace(vecMatCreationRegex, (match, type, vecSize, matSize, args) => {
            // Split the arguments and check if any of them contain a decimal point
            const argArray = args.split(',').map(arg => arg.trim());
            const hasDecimal = argArray.some(arg => arg.includes('.'));
            
            // If any argument has a decimal, it's a float, otherwise it's an integer
            const inferredType = hasDecimal ? 'f32' : 'i32';
            
            return `${type}<${inferredType}>(${argArray.join(', ')})`;
        });

        
        this.params.forEach((param) => {
            if(param.isUniform) {
                const regex = new RegExp(`(?<![a-zA-Z0-9])${param.name}(?![a-zA-Z0-9])`, 'g');
                code = code.replace(regex, `uniforms.${param.name}`);
            }
        });

        // Replace placeholders with their corresponding comments
        for (const [placeholder, comment] of Object.entries(commentPlaceholders)) {
            code = code.replace(placeholder, comment);
        }
        
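        // Append a semicolon to lines whose last non-whitespace character isn't ';', '{', '[', '(', ',', '>' or '}'.
        // This is a heuristic and can misplace semicolons; the cleanup passes below undo the known bad cases.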
        code = code.replace(/^(.*[^;\s\{\[\(\,\>\}])(\s*\/\/.*)$/gm, '$1;$2');
        code = code.replace(/^(.*[^;\s\{\[\(\,\>\}])(?!\s*\/\/)(?=\s*$)/gm, '$1;');
        //trim off some cases for inserting semicolons wrong
        code = code.replace(/(\/\/[^\n]*);/gm, '$1'); //trim off semicolons after comments
        code = code.replace(/\);\s*(\n\s*)\)/gm, ')$1)'); //trim off semicolons between end parentheses 

        if(!returns) code = code.replace(/(return [^;]+;)/g, '//$1');
        this.mainBody = code;

        return code;
    }

    addFunction = (func) => {
        if(!this.functions) this.functions = [];
        this.functions.push(func);
        let result = this.convertToWebGPU(); 
        return this.init(result.shader, result.ast);
    }

    convertToWebGPU(func=this.fstr) {
        const funcStr = typeof func === 'string' ? func : func.toString();
        this.fstr = funcStr;
        const tokens = this.tokenize(funcStr);
        const ast = this.parse(tokens);
        let webGPUCode = this.generateDataStructures(funcStr, ast);
        webGPUCode += '\n' + this.generateMainFunctionWorkGroup(funcStr, ast); // Pass funcStr as the first argument
        return {shader:webGPUCode, ast};
    }
}

function capitalizeFirstLetter(string) {
    return string.charAt(0).toUpperCase() + string.slice(1);
}

function replaceJSFunctions(code, replacements) {
    for (let [jsFunc, shaderFunc] of Object.entries(replacements)) {
        const regex = new RegExp(jsFunc.replace('.', '\\.') + '\\b', 'g'); // escape the dot; \b keeps 'Math.atan' from matching inside 'Math.atan2'
        code = code.replace(regex, shaderFunc);
    }
    return code;
}

// JavaScript Math members mapped to their WGSL built-in equivalents
const replacements = {
    'Math.PI': `${Math.PI}`,
    'Math.E':  `${Math.E}`,
    'Math.abs': 'abs',
    'Math.acos': 'acos',
    'Math.asin': 'asin',
    'Math.atan': 'atan',
    'Math.atan2': 'atan2', // WGSL atan2(y, x) takes its arguments in the same order as Math.atan2
    'Math.ceil': 'ceil',
    'Math.cos': 'cos',
    'Math.exp': 'exp',
    'Math.floor': 'floor',
    'Math.log': 'log',
    'Math.max': 'max',
    'Math.min': 'min',
    'Math.pow': 'pow',
    'Math.round': 'round', // Caution: WGSL round() rounds halfway cases to even, whereas Math.round rounds them up
    'Math.sin': 'sin',
    'Math.sqrt': 'sqrt',
    'Math.tan': 'tan',
    // ... add more replacements as needed
};

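// WGSL AlignOf/SizeOf values for 32-bit types. WGSL only defines matrices over f32 (and f16),
// so the mat<i32>/mat<u32> entries below have no WGSL equivalent.
const wgslTypeSizes32 = {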
    'i32': { alignment: 4, size: 4 },
    'u32': { alignment: 4, size: 4 },
    'f32': { alignment: 4, size: 4 },
    'atomic': { alignment: 4, size: 4 },
    'vec2<f32>': { alignment: 8, size: 8 },
    'vec2<i32>': { alignment: 8, size: 8 },
    'vec2<u32>': { alignment: 8, size: 8 },
    'vec3<f32>': { alignment: 16, size: 12 },
    'vec3<i32>': { alignment: 16, size: 12 },
    'vec3<u32>': { alignment: 16, size: 12 },
    'vec4<f32>': { alignment: 16, size: 16 },
    'vec4<i32>': { alignment: 16, size: 16 },
    'vec4<u32>': { alignment: 16, size: 16 },
    'mat2x2<f32>': { alignment: 8, size: 16 },
    'mat2x2<i32>': { alignment: 8, size: 16 },
    'mat2x2<u32>': { alignment: 8, size: 16 },
    'mat3x2<f32>': { alignment: 8, size: 24 },
    'mat3x2<i32>': { alignment: 8, size: 24 },
    'mat3x2<u32>': { alignment: 8, size: 24 },
    'mat4x2<f32>': { alignment: 8, size: 32 },
    'mat4x2<i32>': { alignment: 8, size: 32 },
    'mat4x2<u32>': { alignment: 8, size: 32 },
    'mat2x3<f32>': { alignment: 16, size: 32 },
    'mat2x3<i32>': { alignment: 16, size: 32 },
    'mat2x3<u32>': { alignment: 16, size: 32 },
    'mat3x3<f32>': { alignment: 16, size: 48 },
    'mat3x3<i32>': { alignment: 16, size: 48 },
    'mat3x3<u32>': { alignment: 16, size: 48 },
    'mat4x3<f32>': { alignment: 16, size: 64 },
    'mat4x3<i32>': { alignment: 16, size: 64 },
    'mat4x3<u32>': { alignment: 16, size: 64 },
    'mat2x4<f32>': { alignment: 16, size: 32 },
    'mat2x4<i32>': { alignment: 16, size: 32 },
    'mat2x4<u32>': { alignment: 16, size: 32 },
    'mat3x4<f32>': { alignment: 16, size: 48 },
    'mat3x4<i32>': { alignment: 16, size: 48 },
    'mat3x4<u32>': { alignment: 16, size: 48 },
    'mat4x4<f32>': { alignment: 16, size: 64 },
    'mat4x4<i32>': { alignment: 16, size: 64 },
    'mat4x4<u32>': { alignment: 16, size: 64 }
};

// 16-bit variants. WGSL only has f16 (enabled with `enable f16;`); it has no i16/u16 types.
const wgslTypeSizes16 = {
    'i16': { alignment: 2, size: 2 },
    'u16': { alignment: 2, size: 2 },
    'f16': { alignment: 2, size: 2 },
    'vec2<f16>': { alignment: 4, size: 4 },
    'vec2<i16>': { alignment: 4, size: 4 },
    'vec2<u16>': { alignment: 4, size: 4 },
    'vec3<f16>': { alignment: 8, size: 6 },
    'vec3<i16>': { alignment: 8, size: 6 },
    'vec3<u16>': { alignment: 8, size: 6 },
    'vec4<f16>': { alignment: 8, size: 8 },
    'vec4<i16>': { alignment: 8, size: 8 },
    'vec4<u16>': { alignment: 8, size: 8 },
    'mat2x2<f16>': { alignment: 4, size: 8 },
    'mat2x2<i16>': { alignment: 4, size: 8 },
    'mat2x2<u16>': { alignment: 4, size: 8 },
    'mat3x2<f16>': { alignment: 4, size: 12 },
    'mat3x2<i16>': { alignment: 4, size: 12 },
    'mat3x2<u16>': { alignment: 4, size: 12 },
    'mat4x2<f16>': { alignment: 4, size: 16 },
    'mat4x2<i16>': { alignment: 4, size: 16 },
    'mat4x2<u16>': { alignment: 4, size: 16 },
    'mat2x3<f16>': { alignment: 8, size: 16 },
    'mat2x3<i16>': { alignment: 8, size: 16 },
    'mat2x3<u16>': { alignment: 8, size: 16 },
    'mat3x3<f16>': { alignment: 8, size: 24 },
    'mat3x3<i16>': { alignment: 8, size: 24 },
    'mat3x3<u16>': { alignment: 8, size: 24 },
    'mat4x3<f16>': { alignment: 8, size: 32 },
    'mat4x3<i16>': { alignment: 8, size: 32 },
    'mat4x3<u16>': { alignment: 8, size: 32 },
    'mat2x4<f16>': { alignment: 8, size: 16 },
    'mat2x4<i16>': { alignment: 8, size: 16 },
    'mat2x4<u16>': { alignment: 8, size: 16 },
    'mat3x4<f16>': { alignment: 8, size: 24 },
    'mat3x4<i16>': { alignment: 8, size: 24 },
    'mat3x4<u16>': { alignment: 8, size: 24 },
    'mat4x4<f16>': { alignment: 8, size: 32 },
    'mat4x4<i16>': { alignment: 8, size: 32 },
    'mat4x4<u16>': { alignment: 8, size: 32 }
};


const wgslTypeSizes = Object.assign({}, wgslTypeSizes16, wgslTypeSizes32);




function dft(
    inputData = [], 
    outputData = [], 
    //dummy inputs
    outp3 = mat2x2(vec2(1.0,1.0),vec2(1.0,1.0)), 
    outp4 = 4,
    outp5 = vec3(1,2,3),
    outp6 = [vec2(1.0,1.0)]
) {

    function add(a=vec2(0.0,0.0),b=vec2(0.0,0.0)) { //transpiled out of main body
        return a + b;
    }

    const N = inputData.length;
    const k = globalId.x;
    var sum = vec2(0.0, 0.0);

    var sum2 = add(sum,sum);

    const b = 3 + outp4;

    var M = mat4x4(
        vec4(1.0,0.0,0.0,0.0),
        vec4(0.0,1.0,0.0,0.0),
        vec4(0.0,0.0,1.0,0.0),
        vec4(0.0,0.0,0.0,1.0)
    ); //identity matrix

    let D = M + M;

    var Z = outp3 * mat2x2(vec2(4.0,-1.0),vec2(3.0,2.0));

    var Zz = outp5 + vec3(4,5,6);

    for (let n = 0; n < N; n++) {
        const phase = 2.0 * Math.PI * f32(k) * f32(n) / f32(N);
        sum = sum + vec2(
            inputData[n] * Math.cos(phase),
            -inputData[n] * Math.sin(phase)
        );
    }

    let v = 2

    const outputIndex = k * 2 //use strict
    if (outputIndex + 1 < outputData.length) {
        outputData[outputIndex] = sum.x;
        outputData[outputIndex + 1] = sum.y;
    }

    
    return [inputData, outputData]; //returning an array of inputs lets us return several buffer promises
    //return outputData;
    //return outp4; //we can also return the uniform buffer though it is immutable so it's pointless
}

//explicit return statements will define only that variable as an output (i.e. a mutable read_write buffer)


const parser = new WebGPUjs();
const webGPUCode = parser.convertToWebGPU(dft);


//console.log(webGPUCode);
document.body.style.backgroundColor = 'black';
document.body.style.color = 'white';

document.body.insertAdjacentHTML('afterbegin', `
    
    <span style="position:absolute; left:0px;">
        Before (edit me!):<br>
        <textarea id="t2" style="width:50vw; background-color:#303000; color:lightblue; height:100vh;">${dft.toString()}</textarea>
    </span>
    <span style="position:absolute; left:50vw;">
        After:<br>
        <textarea id="t1" style="width:50vw; background-color:#000020; color:lightblue; height:100vh;">${webGPUCode.shader}</textarea>
    </span>
`);

function parseFunction() {
    const fstr = document.getElementById('t2').value;
    const webGPUCode = parser.convertToWebGPU(fstr);
    document.getElementById('t1').value = webGPUCode.shader;
}

document.getElementById('t2').oninput = () => {
    parseFunction();
}

WebGPUjs.createPipeline(dft).then(pipeline => {
    // Create some sample input data
    const len = 256;
    const inputData = new Float32Array(len).fill(1.0); // Example data
    const outputData = new Float32Array(len*2).fill(0);

    // Run the process method to execute the shader
    pipeline.process(inputData,outputData,[1.0,2.0,3.0,4.0], 1, [1,2,3], [1,2]).then(result => {
        console.log(result); // Log the output
        if(result[2]?.buffer) { // if the uniform buffer is returned too, check its alignment: mat2x2<f32> fills bytes 0-15, so the i32 starts at byte 16
            console.log(
                new DataView(result[2].buffer).getInt32(16,true) // the int32 is still correctly encoded
            )
        }

        pipeline.addFunction(function mul(a=vec2(2.0),b=vec2(2.0)) { return a * b; }).then((p) => {
            document.getElementById('t1').value = pipeline.shader;
        });

    });
});


const dftReference = `
                
struct InputData {
    values : array<f32>
}

struct OutputData {
    values: array<f32>
}

@group(0) @binding(0)
var<storage, read> inputData: InputData;

@group(0) @binding(1)
var<storage, read_write> outputData: OutputData;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) globalId: vec3<u32>
) {
    let N = arrayLength(&inputData.values);
    let k = globalId.x;
    var sum = vec2<f32>(0.0, 0.0);

    for (var n = 0u; n < N; n = n + 1u) {
        let phase = 2.0 * 3.14159265359 * f32(k) * f32(n) / f32(N);
        sum = sum + vec2<f32>(
            inputData.values[n] * cos(phase),
            -inputData.values[n] * sin(phase)
        );
    }

    let outputIndex = k * 2;
    if (outputIndex + 1 < arrayLength(&outputData.values)) {
        outputData.values[outputIndex] = sum.x;
        outputData.values[outputIndex + 1] = sum.y;
    }
}

`
        </script>


<script>

</script>
    </body>
</html>
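The `wgslTypeSizes` tables above hold WGSL's AlignOf/SizeOf values. The rule they feed is that each struct member starts at the next multiple of its own alignment, and the struct's total size is rounded up to its largest member alignment. Below is a minimal sketch of that rule. `structLayout` and the trimmed `TYPE_LAYOUT` table are hypothetical helpers, not part of webgpujs, and the example assumes the three non-array dummy inputs of `dft()` (`outp3`, `outp4`, `outp5`) are packed into one uniform struct in parameter order as `mat2x2<f32>`, `i32` and `vec3<i32>` (the code that generates the structs isn't in this excerpt). It also ignores the uniform address space's extra 16-byte rules for arrays and nested structs.

```javascript
// Trimmed copy of the WGSL alignment/size table (32-bit types only).
const TYPE_LAYOUT = {
  "i32": { alignment: 4, size: 4 },
  "u32": { alignment: 4, size: 4 },
  "f32": { alignment: 4, size: 4 },
  "vec2<f32>": { alignment: 8, size: 8 },
  "vec3<i32>": { alignment: 16, size: 12 },
  "vec3<f32>": { alignment: 16, size: 12 },
  "vec4<f32>": { alignment: 16, size: 16 },
  "mat2x2<f32>": { alignment: 8, size: 16 },
  "mat4x4<f32>": { alignment: 16, size: 64 },
};

// Round n up to the next multiple of k.
const roundUp = (k, n) => Math.ceil(n / k) * k;

// Lay out struct members in declaration order:
//   offset(member) = roundUp(AlignOf(member), end of previous member)
//   size(struct)   = roundUp(max member alignment, end of last member)
function structLayout(members) {
  let end = 0;
  let maxAlign = 1;
  const offsets = {};
  for (const [name, type] of members) {
    const t = TYPE_LAYOUT[type];
    if (!t) throw new Error(`unknown type ${type}`);
    const offset = roundUp(t.alignment, end);
    offsets[name] = offset;
    end = offset + t.size;
    maxAlign = Math.max(maxAlign, t.alignment);
  }
  return { offsets, size: roundUp(maxAlign, end) };
}
```

Under those assumptions, `structLayout([["outp3", "mat2x2<f32>"], ["outp4", "i32"], ["outp5", "vec3<i32>"]])` puts `outp4` at byte 16, which is consistent with the demo's `getInt32(16, true)` read. It puts `outp5` at byte 32 because a `vec3` needs 16-byte alignment, which gives a 48-byte struct.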

@joshbrew
Author

joshbrew commented Aug 28, 2023

Here's a repo. Compute shaders work great; I'm just getting rendering and pipeline chaining working now (I could never get that to work in GPUjs). It's pretty much a stock renderer; it just requires the special JavaScript-ish shader function input format.

https://github.com/joshbrew/webgpujs
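Because the compute path is the part that already works, a plain-JavaScript version of the `dft()` kernel from the demo is an inexpensive way to check its output. This is a sketch. `dftCPU` and `allClose` are hypothetical helpers, and the code assumes the kernel's interleaved `[re0, im0, re1, im1, ...]` output layout and its phase formula `2 * PI * k * n / N`.

```javascript
// CPU reference for the dft() kernel: for each bin k, sum
// input[n] * (cos(phase), -sin(phase)) with phase = 2*PI*k*n/N,
// stored interleaved as [re0, im0, re1, im1, ...].
function dftCPU(input) {
  const N = input.length;
  const out = new Float32Array(N * 2);
  for (let k = 0; k < N; k++) {
    let re = 0;
    let im = 0;
    for (let n = 0; n < N; n++) {
      const phase = (2 * Math.PI * k * n) / N;
      re += input[n] * Math.cos(phase);
      im -= input[n] * Math.sin(phase);
    }
    out[2 * k] = re;
    out[2 * k + 1] = im;
  }
  return out;
}

// Element-wise comparison within an absolute tolerance; false on length mismatch.
function allClose(a, b, tol = 1e-3) {
  if (a.length !== b.length) return false;
  for (let i = 0; i < a.length; i++) {
    if (Math.abs(a[i] - b[i]) > tol) return false;
  }
  return true;
}
```

For example, once the demo's `pipeline.process(...)` resolves, `allClose(dftCPU(inputData), result[1])` should hold if `result[1]` is the output `Float32Array` and all 256 bins were dispatched. Both conditions are assumptions about the library's return value and dispatch size. The tolerance is loose because the GPU accumulates in f32 while this reference accumulates in f64.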

@jacobbogers

@joshbrew Yes, you have far, far more control over WebGPU than WebGL. WebGPU is fully async, so you can keep feeding it data and execution instructions, keeping it fully occupied at all times.
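One way to act on this from JavaScript is to avoid awaiting each readback before queuing the next batch of work. Below is a minimal, generic sketch of that scheduling pattern. `runInFlight` is a hypothetical helper, not a WebGPU or webgpujs API. It keeps up to `limit` promise-returning jobs running and returns their results in the original order. This thread doesn't establish whether overlapping calls such as the demo's `pipeline.process(...)` are safe when they share GPU buffers, so what each job does is an assumption.

```javascript
// Keep at most `limit` async jobs running at once and return their results in
// the original order. Each job is a zero-argument function returning a Promise.
async function runInFlight(jobs, limit = 2) {
  const results = new Array(jobs.length);
  let next = 0; // index of the next job that hasn't been started yet

  // A worker starts a new job as soon as its previous one settles, so work
  // keeps being queued while jobs remain.
  async function worker() {
    while (next < jobs.length) {
      const i = next++;
      results[i] = await jobs[i]();
    }
  }

  const workers = [];
  for (let w = 0; w < Math.min(limit, jobs.length); w++) workers.push(worker());
  await Promise.all(workers);
  return results;
}
```

Each job could wrap one dispatch-and-readback, for example `() => pipeline.process(...)` with its own input and output arrays, so that concurrent jobs never write to the same buffers.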

@joshbrew
Author

joshbrew commented Dec 22, 2023

@jacobbogers I've made a ton of progress on my webgpujs thing. There is another project much closer to gpujs here: https://github.com/AmesingFlank/taichi.js

Otherwise, I'm stuck getting textures to render correctly before I can finally polish up my library. (I'd pay someone at this point.)

@jacobbogers

Just use normal buffers for compute shaders; why are you using textures?
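With the buffer-only approach suggested here, a 2D image can live in a plain storage buffer (for example `array<u32>` in WGSL) that holds one packed RGBA8 value per pixel, addressed row-major as `y * width + x`. A shader can then read a pixel back with WGSL's `unpack4x8unorm`. The sketch below builds such a buffer on the CPU. `pixelIndex`, `packRGBA8` and `makeImageBuffer` are hypothetical helpers. `packRGBA8` mirrors the packing rule of WGSL's `pack4x8unorm`: clamp each channel to [0, 1], scale it by 255, round it, and place channel i in bits 8i to 8i+7.

```javascript
// Row-major index of pixel (x, y) in a flat buffer of the given width.
function pixelIndex(x, y, width) {
  return y * width + x;
}

// CPU-side mirror of WGSL pack4x8unorm: clamp each channel to [0, 1],
// convert with floor(0.5 + 255 * v), and put channel i at bits 8*i..8*i+7.
function packRGBA8(r, g, b, a) {
  const to8 = (v) => Math.floor(0.5 + 255 * Math.min(1, Math.max(0, v)));
  // >>> 0 turns the signed 32-bit result of the shifts back into an unsigned value.
  return (to8(r) | (to8(g) << 8) | (to8(b) << 16) | (to8(a) << 24)) >>> 0;
}

// Build a width x height image as one u32 per pixel; colorAt(x, y) returns [r, g, b, a].
function makeImageBuffer(width, height, colorAt) {
  const pixels = new Uint32Array(width * height);
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      pixels[pixelIndex(x, y, width)] = packRGBA8(...colorAt(x, y));
    }
  }
  return pixels;
}
```

The resulting `Uint32Array` could be uploaded with `device.queue.writeBuffer`. Displaying it still requires a render pass or a `copyBufferToTexture` into a texture, and that copy requires `bytesPerRow` to be a multiple of 256. That display step is where textures come back in.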

@jacobbogers

That taichi stuff looks crazy good.

@joshbrew
Author

I have an entire process to transpile compute, vertex, and fragment shaders, plus a WIP method to combine bindings as you chain shaders, so I can create as many shaders as I want that share data structures. That includes textures, especially storage textures, once I solve the basic texture issue I'm having. This isn't really documented yet; that will wait until I fix the bugs and finalize the workflow.
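Here is a rough sketch of the binding-merging idea. It is a hypothetical illustration, not the actual webgpujs implementation. The first time any shader in the chain uses a resource name, that name gets one `@binding` index, and every later shader that uses the same name reuses that index. The shaders can then agree on a single shared bind group layout. `mergeBindings`, `emitDeclarations` and the `{ name, resources }` and `decl` input formats are all assumptions. Sharing one bind group across pipelines generally requires an explicit `GPUBindGroupLayout` and `GPUPipelineLayout`, because layouts produced by `layout: "auto"` only work with the pipeline that generated them.

```javascript
// Assign one @binding index per distinct resource name across a chain of shaders.
// Each shader is described as { name, resources: [resourceName, ...] } (hypothetical format).
function mergeBindings(shaders) {
  const bindingOf = new Map(); // resource name -> shared binding index
  for (const shader of shaders) {
    for (const res of shader.resources) {
      if (!bindingOf.has(res)) bindingOf.set(res, bindingOf.size);
    }
  }
  // Per-shader view: the shared binding index of each resource that shader uses.
  const perShader = shaders.map((s) => ({
    name: s.name,
    bindings: s.resources.map((res) => ({ res, binding: bindingOf.get(res) })),
  }));
  return { bindingOf, perShader };
}

// Emit WGSL declarations for one shader using the shared indices.
// `decl` maps a resource name to its WGSL variable declaration (hypothetical format).
function emitDeclarations(shaderView, decl, group = 0) {
  return shaderView.bindings
    .map(({ res, binding }) => `@group(${group}) @binding(${binding})\n${decl[res]}`)
    .join("\n\n");
}
```

For example, if a compute shader uses `inputData` and `outputData` and a fragment shader uses `outputData` and `uniforms`, `outputData` gets binding 1 in both shaders. The fragment shader then declares only bindings 1 and 2, which WGSL allows.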
