#![allow(irrefutable_let_patterns)]

use blade_graphics as gpu;
use std::{num::NonZeroU32, slice};

/// Bind-group data for the `main` compute pipeline.
///
/// Each field name matches a binding name declared in
/// `<Globals as gpu::ShaderData>::layout`, and `fill` binds the fields in
/// declaration order at indices 0..=3, so the two must be kept in sync.
struct Globals {
    /// Four floats passed as plain data (`ShaderBinding::Plain`, 16 bytes).
    /// `main` sets `[0.2, 0.4, 0.3, 0.0]` for the first generated mip level
    /// and `[1.0; 4]` for every later one.
    modulator: [f32; 4],
    /// Buffer binding; `main` points it at offset 0 of a 4-byte shared buffer.
    demodulator: gpu::BufferPiece,
    /// View of the source mip level (mip `i - 1` in `main`).
    input: gpu::TextureView,
    /// View of the destination mip level (mip `i` in `main`).
    output: gpu::TextureView,
}

// `gpu::ShaderData` is implemented by hand here, rather than derived, so the
// example shows exactly what the derive macro would otherwise generate.
impl gpu::ShaderData for Globals {
    /// Declares one binding per field, in field order. The position of each
    /// entry in this list is the binding index that `fill` must use.
    fn layout() -> gpu::ShaderDataLayout {
        // `modulator` is a `[f32; 4]`: 4 floats * 4 bytes = 16 bytes of plain data.
        let bindings = vec![
            ("modulator", gpu::ShaderBinding::Plain { size: 16 }),
            ("demodulator", gpu::ShaderBinding::Buffer),
            ("input", gpu::ShaderBinding::Texture),
            ("output", gpu::ShaderBinding::Texture),
        ];
        gpu::ShaderDataLayout { bindings }
    }

    /// Binds every field at the index of its matching entry in `layout`.
    fn fill(&self, mut pc: gpu::PipelineContext) {
        use gpu::ShaderBindable;

        self.modulator.bind_to(&mut pc, 0);
        self.demodulator.bind_to(&mut pc, 1);
        self.input.bind_to(&mut pc, 2);
        self.output.bind_to(&mut pc, 3);
    }
}

/// Entry point of the `mini` example: fills mip 0 of a small texture from a
/// staging buffer, generates every remaining mip level with a compute shader,
/// and prints the value read back from the last mip level.
///
/// Steps:
/// 1. Write `x * y` for every texel into a `Memory::Upload` staging buffer and
///    copy it into mip 0 of a 16x16 `Rgba8Unorm` texture.
/// 2. Run one compute pass per remaining mip level. Each pass reads mip `i - 1`
///    and writes mip `i` through the `Globals` bind group.
/// 3. Copy a 1x1 region of the last mip level into a 4-byte shared buffer,
///    wait for the GPU, and print that value in hex.
///
/// # Panics
/// Panics in any of these cases:
/// - the GPU context cannot be created;
/// - `examples/mini/shader.wgsl` cannot be read (the path is resolved against
///   the current working directory);
/// - `wait_for` reports that the submitted work did not finish in time.
fn main() {
    env_logger::init();
    // SAFETY: NOTE(review) — the contract of `Context::init` is defined by
    // blade_graphics and is not visible here; confirm it against the crate docs.
    let context = unsafe { gpu::Context::init(gpu::ContextDesc::default()) }
        .expect("failed to initialize the GPU context");

    let global_layout = <Globals as gpu::ShaderData>::layout();
    let shader_source = std::fs::read_to_string("examples/mini/shader.wgsl")
        .expect("failed to read examples/mini/shader.wgsl");
    let shader = context.create_shader(gpu::ShaderDesc {
        source: &shader_source,
    });

    let pipeline = context.create_compute_pipeline(gpu::ComputePipelineDesc {
        name: "main",
        data_layouts: &[&global_layout],
        compute: shader.at("main"),
    });

    let extent = gpu::Extent {
        width: 16,
        height: 16,
        depth: 1,
    };
    let mip_level_count = extent.max_mip_levels();
    let texture = context.create_texture(gpu::TextureDesc {
        name: "input",
        format: gpu::TextureFormat::Rgba8Unorm,
        size: extent,
        dimension: gpu::TextureDimension::D2,
        array_layer_count: 1,
        mip_level_count,
        usage: gpu::TextureUsage::RESOURCE | gpu::TextureUsage::STORAGE | gpu::TextureUsage::COPY,
        sample_count: 1,
        external: None,
    });
    // One single-level view per mip, so each compute pass can bind mip `i - 1`
    // as its input and mip `i` as its output.
    let views = (0..mip_level_count)
        .map(|i| {
            context.create_texture_view(
                texture,
                gpu::TextureViewDesc {
                    name: &format!("mip-{}", i),
                    format: gpu::TextureFormat::Rgba8Unorm,
                    dimension: gpu::ViewDimension::D2,
                    subresources: &gpu::TextureSubresources {
                        base_mip_level: i,
                        mip_level_count: NonZeroU32::new(1),
                        base_array_layer: 0,
                        array_layer_count: None,
                    },
                },
            )
        })
        .collect::<Vec<_>>();

    // Receives exactly one Rgba8Unorm texel (4 bytes) from the last mip level.
    let result_buffer = context.create_buffer(gpu::BufferDesc {
        name: "result",
        size: 4,
        memory: gpu::Memory::Shared,
    });

    // Number of texels in mip 0. Each Rgba8Unorm texel is 4 bytes and is
    // written below as one packed `u32`.
    let texel_count = extent.width * extent.height;
    let upload_buffer = context.create_buffer(gpu::BufferDesc {
        name: "staging",
        size: texel_count as u64 * 4,
        memory: gpu::Memory::Upload,
    });
    {
        // SAFETY: `upload_buffer` was just created with `texel_count * 4` bytes,
        // so the mapping returned by `data()` spans `texel_count` `u32`s. These
        // CPU writes finish before `submit` below hands the buffer to the GPU.
        // NOTE(review): assumes `data()` returns a CPU-visible pointer that is
        // aligned for `u32` for `Memory::Upload` buffers — confirm.
        let data = unsafe {
            slice::from_raw_parts_mut(upload_buffer.data() as *mut u32, texel_count as usize)
        };
        for (y, row) in data.chunks_exact_mut(extent.width as usize).enumerate() {
            for (x, texel) in row.iter_mut().enumerate() {
                *texel = (y as u32) * (x as u32);
            }
        }
    }
    let demodulator_buf = context.create_buffer(gpu::BufferDesc {
        name: "demodulator",
        size: 4,
        memory: gpu::Memory::Shared,
    });

    let mut command_encoder = context.create_command_encoder(gpu::CommandEncoderDesc {
        name: "main",
        buffer_count: 1,
    });
    command_encoder.start();
    command_encoder.init_texture(texture);

    // `if let mut x = ...` is an irrefutable pattern (allowed crate-wide). It
    // scopes each pass so the pass ends before the next one begins.
    if let mut transfer = command_encoder.transfer("upload mip 0") {
        transfer.copy_buffer_to_texture(
            upload_buffer.into(),
            extent.width * 4,
            texture.into(),
            extent,
        );
    }
    for i in 1..mip_level_count {
        if let mut compute = command_encoder.compute("generate mips") {
            if let mut pc = compute.with(&pipeline) {
                let groups = pipeline.get_dispatch_for(extent.at_mip_level(i));
                pc.bind(
                    0,
                    &Globals {
                        modulator: if i == 1 {
                            [0.2, 0.4, 0.3, 0.0]
                        } else {
                            [1.0; 4]
                        },
                        demodulator: demodulator_buf.at(0),
                        input: views[i as usize - 1],
                        output: views[i as usize],
                    },
                );
                pc.dispatch(groups);
            }
        }
    }
    if let mut transfer = command_encoder.transfer("read back last mip") {
        transfer.copy_texture_to_buffer(
            gpu::TexturePiece {
                texture,
                mip_level: mip_level_count - 1,
                array_layer: 0,
                origin: Default::default(),
            },
            result_buffer.into(),
            4,
            gpu::Extent {
                width: 1,
                height: 1,
                depth: 1,
            },
        );
    }
    let sync_point = context.submit(&mut command_encoder);

    let ok = context.wait_for(&sync_point, 1000);
    assert!(ok, "timed out waiting for the GPU to finish");
    // SAFETY: `wait_for` returned true, so the copy into the 4-byte
    // `result_buffer` has completed, and the buffer holds exactly one `u32`.
    // NOTE(review): assumes `Memory::Shared` buffers are CPU-readable through
    // `data()` with `u32` alignment — confirm.
    let answer = unsafe { *(result_buffer.data() as *const u32) };
    println!("Output: 0x{:x}", answer);

    context.destroy_command_encoder(&mut command_encoder);
    context.destroy_buffer(result_buffer);
    context.destroy_buffer(upload_buffer);
    context.destroy_buffer(demodulator_buf);
    for view in views {
        context.destroy_texture_view(view);
    }
    context.destroy_texture(texture);
}
