You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
3.4 KiB

use std::io::{self, BufRead, Write};
use anyhow::Context;
use vulkano::{
buffer::{BufferUsage, CpuAccessibleBuffer},
command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, PrimaryCommandBuffer},
descriptor_set::PersistentDescriptorSet,
device::{physical::PhysicalDevice, Device, DeviceExtensions, Features},
instance::{Instance, InstanceExtensions},
pipeline::{ComputePipeline, Pipeline, PipelineBindPoint},
sync::GpuFuture,
Version,
};
fn main() -> anyhow::Result<()> {
let instance = Instance::new(None, Version::V1_5, &InstanceExtensions::none(), None)?;
let dev_extensions = DeviceExtensions {
khr_storage_buffer_storage_class: true,
..DeviceExtensions::none()
};
let phy_dev = PhysicalDevice::enumerate(&instance)
.find(|&p| p.supported_extensions().is_superset_of(&dev_extensions))
.context("Couldn't find a suitable GPU!")?;
let (dev, mut queues) = Device::new(
phy_dev,
&Features::none(),
&phy_dev.required_extensions().union(&dev_extensions),
std::iter::once((
phy_dev
.queue_families()
.next()
.context("GPU has no queue families!")?,
0.5,
)),
)?;
let queue = queues
.next()
.context("Queue family has 0 queues?! Where'd you get that GPU from?")?;
#[allow(deprecated, clippy::needless_question_mark)]
mod cs {
vulkano_shaders::shader! {
ty: "compute",
src: "
#version 450
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) buffer Data {
int data[];
} data;
void main() {
data.data[2] = data.data[0] + data.data[1];
}
"
}
}
let shader = cs::load(dev.clone())?;
let pipeline = ComputePipeline::new(
dev.clone(),
shader
.entry_point("main")
.context("Couldn't get shader entry point")?,
&(),
None,
|_| (),
)?;
let stdin = io::stdin();
let mut stdin = stdin.lock();
let mut buf = String::new();
print!("Enter number 1 > ");
io::stdout().flush()?;
stdin.read_line(&mut buf)?;
let n1 = buf.trim().parse::<i32>()?;
buf.clear();
print!("Enter number 2 > ");
io::stdout().flush()?;
stdin.read_line(&mut buf)?;
let n2 = buf.trim().parse::<i32>()?;
drop(stdin);
let buf = CpuAccessibleBuffer::from_data(
dev.clone(),
BufferUsage::storage_buffer(),
false,
[n1, n2, 0],
)?;
let mut builder =
PersistentDescriptorSet::start(pipeline.layout().descriptor_set_layouts()[0].clone());
builder.add_buffer(buf.clone())?;
let set = builder.build()?;
let mut builder =
AutoCommandBufferBuilder::primary(dev, queue.family(), CommandBufferUsage::OneTimeSubmit)?;
builder
.bind_pipeline_compute(pipeline.clone())
.bind_descriptor_sets(
PipelineBindPoint::Compute,
pipeline.layout().clone(),
0,
set,
)
.dispatch([1, 1, 1])?;
let cmd_buf = builder.build()?;
cmd_buf
.execute(queue)?
.then_signal_fence_and_flush()?
.wait(None)?;
let sum = buf.read()?[2];
println!("Sum: {}", sum);
Ok(())
}