openvm_cuda_builder/
lib.rs1use std::{env, path::Path, process::Command};
2
/// Builder that compiles CUDA sources into a static library from a Cargo build script,
/// driving the `cc` crate in CUDA (nvcc) mode.
///
/// Configure it with the chained setters and finish with `build()`. Every setter consumes
/// and returns the builder, so dropping the value before `build()` is almost certainly a
/// mistake — hence `#[must_use]`.
#[must_use = "a CudaBuilder does nothing until `.build()` is called"]
#[derive(Debug, Clone)]
pub struct CudaBuilder {
    /// Directories passed to nvcc as include paths.
    include_paths: Vec<String>,
    /// CUDA source files to compile.
    source_files: Vec<String>,
    /// Paths emitted as `cargo:rerun-if-changed`.
    watch_paths: Vec<String>,
    /// Glob patterns whose matching files are emitted as `cargo:rerun-if-changed`.
    watch_globs: Vec<String>,
    /// Name of the static library to produce; must be non-empty before building.
    library_name: String,
    /// Explicit compute capabilities (e.g. `"89"`); empty means "use `CUDA_ARCH` or detect".
    cuda_arch: Vec<String>,
    /// Explicit ptxas optimization level; `None` means "use `CUDA_OPT_LEVEL`, else 3".
    cuda_opt_level: Option<String>,
    /// Extra flags passed verbatim to nvcc.
    custom_flags: Vec<String>,
    /// Libraries emitted as `cargo:rustc-link-lib` by `emit_link_directives`.
    link_libraries: Vec<String>,
    /// Directories emitted as `cargo:rustc-link-search=native=` by `emit_link_directives`.
    link_search_paths: Vec<String>,
}
17
18impl Default for CudaBuilder {
19 fn default() -> Self {
20 let mut link_search_paths = Vec::new();
21 if let Ok(cuda_lib_dir) = env::var("CUDA_LIB_DIR") {
22 link_search_paths.push(cuda_lib_dir);
23 } else {
24 link_search_paths.push("/usr/local/cuda/lib64".to_string());
25 }
26
27 Self {
28 include_paths: Vec::new(),
29 source_files: Vec::new(),
30 watch_paths: vec!["build.rs".to_string()],
31 watch_globs: Vec::new(),
32 library_name: String::new(),
33 cuda_arch: Vec::new(),
34 cuda_opt_level: None,
35 custom_flags: vec![
36 "--std=c++17".to_string(),
37 "--expt-relaxed-constexpr".to_string(),
38 "-Xfatbin=-compress-all".to_string(),
39 "--default-stream=per-thread".to_string(),
40 ],
41 link_libraries: vec!["cudart".to_string(), "cuda".to_string()],
42 link_search_paths,
43 }
44 }
45}
46
47impl CudaBuilder {
48 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn library_name(mut self, name: &str) -> Self {
55 self.library_name = name.to_string();
56 self
57 }
58
59 pub fn include<P: AsRef<Path>>(mut self, path: P) -> Self {
61 let path_str = path.as_ref().to_string_lossy().to_string();
62 self.include_paths.push(path_str.clone());
63 self.watch_paths.push(path_str);
64 self
65 }
66
67 pub fn include_from_dep(mut self, dep_env_var: &str) -> Self {
69 if let Ok(path) = env::var(dep_env_var) {
70 self.include_paths.push(path);
71 }
72 self
73 }
74
75 pub fn file<P: AsRef<Path>>(mut self, path: P) -> Self {
77 let path_str = path.as_ref().to_string_lossy().to_string();
78 self.source_files.push(path_str.clone());
79 self.watch_paths.push(path_str);
80 self
81 }
82
83 pub fn files<P: AsRef<Path>, I: IntoIterator<Item = P>>(mut self, paths: I) -> Self {
85 for path in paths {
86 let path_str = path.as_ref().to_string_lossy().to_string();
87 self.source_files.push(path_str.clone());
88 self.watch_paths.push(path_str);
89 }
90 self
91 }
92
93 pub fn files_from_glob(mut self, pattern: &str) -> Self {
95 self.watch_globs.push(pattern.to_string());
96 for path in glob::glob(pattern).expect("Invalid glob pattern").flatten() {
97 if path.is_file() && path.extension().is_some_and(|ext| ext == "cu") {
98 self.source_files.push(path.to_string_lossy().to_string());
99 }
100 }
101 self
102 }
103
104 pub fn watch<P: AsRef<Path>>(mut self, path: P) -> Self {
106 self.watch_paths
107 .push(path.as_ref().to_string_lossy().to_string());
108 self
109 }
110
111 pub fn watch_glob(mut self, pattern: &str) -> Self {
113 self.watch_globs.push(pattern.to_string());
114 self
115 }
116
117 pub fn cuda_arch(mut self, arch: &str) -> Self {
119 self.cuda_arch = vec![arch.to_string()];
120 self
121 }
122
123 pub fn cuda_archs(mut self, archs: Vec<&str>) -> Self {
125 self.cuda_arch = archs.iter().map(|s| s.to_string()).collect();
126 self
127 }
128
129 pub fn cuda_opt_level(mut self, level: u8) -> Self {
131 self.cuda_opt_level = Some(level.to_string());
132 self
133 }
134
135 pub fn flag(mut self, flag: &str) -> Self {
137 self.custom_flags.push(flag.to_string());
138 self
139 }
140
141 pub fn link_lib(mut self, lib: &str) -> Self {
143 self.link_libraries.push(lib.to_string());
144 self
145 }
146
147 pub fn link_search<P: AsRef<Path>>(mut self, path: P) -> Self {
149 self.link_search_paths
150 .push(path.as_ref().to_string_lossy().to_string());
151 self
152 }
153
154 pub fn build(self) {
156 self.validate();
158
159 self.setup_rerun_conditions();
161
162 let cuda_archs = self.get_cuda_arch();
164
165 let mut builder = cc::Build::new();
167 builder.cuda(true);
168
169 self.handle_debug_shortcuts(&mut builder);
171
172 let cuda_opt_level = self.get_cuda_opt_level();
174
175 for include in &self.include_paths {
177 builder.include(include);
178 }
179
180 if let Ok(cuda_path) = env::var("CUDA_PATH") {
182 builder.include(format!("{}/include", cuda_path));
183 }
184
185 for flag in &self.custom_flags {
187 builder.flag(flag);
188 }
189
190 for arch in &cuda_archs {
192 builder
193 .flag("-gencode")
194 .flag(format!("arch=compute_{},code=sm_{}", arch, arch));
195 }
196
197 if let Some(max_arch) = cuda_archs.iter().max() {
200 builder.flag("-gencode").flag(format!(
201 "arch=compute_{},code=compute_{}",
202 max_arch, max_arch
203 ));
204 }
205
206 builder.flag(nvcc_parallel_jobs());
208
209 if cuda_opt_level == "0" {
211 builder.debug(true).flag("-O0");
212 } else {
213 builder
214 .debug(false)
215 .flag(format!("--ptxas-options=-O{}", cuda_opt_level));
216 }
217
218 for file in &self.source_files {
220 builder.file(file);
221 }
222
223 builder.compile(&self.library_name);
225 }
226
227 fn validate(&self) {
229 if self.library_name.is_empty() {
230 panic!(
231 "Library name must be set using .library_name(\"name\") before calling .build()"
232 );
233 }
234
235 if self.source_files.is_empty() {
236 panic!("At least one source file must be added using .file() or .files() before calling .build()");
237 }
238
239 for file in &self.source_files {
241 if !Path::new(file).exists() {
242 eprintln!("cargo:warning=CUDA source file does not exist: {}", file);
243 }
244 }
245
246 for include in &self.include_paths {
248 if !Path::new(include).exists() {
249 eprintln!("cargo:warning=Include path does not exist: {}", include);
250 }
251 }
252 }
253
254 pub fn emit_link_directives(&self) {
255 for path in &self.link_search_paths {
256 println!("cargo:rustc-link-search=native={}", path);
257 }
258 for lib in &self.link_libraries {
259 println!("cargo:rustc-link-lib={}", lib);
260 }
261 }
262
263 fn setup_rerun_conditions(&self) {
264 println!("cargo:rerun-if-env-changed=CUDA_ARCH");
266 println!("cargo:rerun-if-env-changed=CUDA_OPT_LEVEL");
267 println!("cargo:rerun-if-env-changed=CUDA_DEBUG");
268 println!("cargo:rerun-if-env-changed=NVCC_THREADS");
269
270 for path in &self.watch_paths {
272 println!("cargo:rerun-if-changed={}", path);
273 }
274
275 for pattern in &self.watch_globs {
277 watch_glob(pattern);
278 }
279 }
280
281 fn get_cuda_arch(&self) -> Vec<String> {
282 if !self.cuda_arch.is_empty() {
283 return self.cuda_arch.clone();
284 }
285
286 if let Ok(env_archs) = env::var("CUDA_ARCH") {
288 return env_archs
289 .split(',')
290 .map(|s| s.trim().to_string())
291 .filter(|s| !s.is_empty())
292 .collect();
293 }
294
295 vec![detect_cuda_arch()]
297 }
298
299 fn get_cuda_opt_level(&self) -> String {
300 if let Some(level) = &self.cuda_opt_level {
301 return level.clone();
302 }
303
304 env::var("CUDA_OPT_LEVEL").unwrap_or_else(|_| "3".to_string())
305 }
306
307 fn handle_debug_shortcuts(&self, builder: &mut cc::Build) {
308 if env::var("CUDA_DEBUG").map(|v| v == "1").unwrap_or(false) {
309 env::set_var("CUDA_OPT_LEVEL", "0");
310 env::set_var("CUDA_LAUNCH_BLOCKING", "1");
311 env::set_var("RUST_BACKTRACE", "full");
312 env::set_var("CUDA_ENABLE_COREDUMP_ON_EXCEPTION", "1");
313 env::set_var("CUDA_DEVICE_WAITS_ON_EXCEPTION", "1");
314
315 println!("cargo:warning=CUDA_DEBUG=1 → Enabling comprehensive debugging:");
316 println!("cargo:warning= → CUDA_OPT_LEVEL=0 (no optimization)");
317 println!("cargo:warning= → CUDA_LAUNCH_BLOCKING=1 (synchronous kernels)");
318 println!("cargo:warning= → Line info and device debug symbols enabled");
319 println!("cargo:warning= → CUDA_DEBUG macro defined for preprocessor");
320
321 builder.flag("-G"); builder.flag("-Xcompiler=-fno-omit-frame-pointer"); builder.flag("-Xptxas=-v"); builder.define("CUDA_DEBUG", "1"); }
326 }
327}
328
/// Reports whether a working `nvcc` can be run from `PATH`.
///
/// Runs `nvcc --version`. Returns `true` only if the process was spawned *and* exited
/// successfully; a missing binary or a failing invocation both yield `false`.
pub fn cuda_available() -> bool {
    Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|output| output.status.success())
        .unwrap_or(false)
}
333
/// Detects the compute capability of the first GPU reported by `nvidia-smi`.
///
/// Returns the capability with the dot removed (`"8.9"` becomes `"89"`). As side effects,
/// it prints `cargo:rustc-env=CUDA_ARCH=<arch>` and sets `CUDA_ARCH` in the current
/// process environment.
///
/// # Panics
///
/// Panics if `nvidia-smi` cannot be run, exits unsuccessfully, emits non-UTF-8 output, or
/// reports no compute capability.
pub fn detect_cuda_arch() -> String {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=compute_cap", "--format=csv,noheader"])
        .output()
        .expect("Failed to run nvidia-smi - make sure NVIDIA drivers are installed");

    // Without this check, whatever nvidia-smi printed on failure would be parsed as if it
    // were a compute capability.
    if !output.status.success() {
        panic!(
            "nvidia-smi exited with {} - make sure NVIDIA drivers are installed\nstdout: {}\nstderr: {}",
            output.status,
            String::from_utf8_lossy(&output.stdout).trim(),
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }

    let full_output =
        String::from_utf8(output.stdout).expect("nvidia-smi output is not valid UTF-8");

    let arch = parse_compute_capability(&full_output)
        .expect("nvidia-smi failed to return compute capability");

    println!("cargo:rustc-env=CUDA_ARCH={}", arch);
    env::set_var("CUDA_ARCH", &arch);
    arch
}

/// Extracts the first non-empty line of `--query-gpu=compute_cap` output, trimmed and with
/// dots removed (`"8.9"` -> `"89"`). Returns `None` if there is no such line.
fn parse_compute_capability(output: &str) -> Option<String> {
    output
        .lines()
        .map(str::trim)
        .find(|line| !line.is_empty())
        .map(|line| line.replace('.', ""))
}
356
/// Builds nvcc's `-t<N>` (thread count) flag.
///
/// `N` comes from `NVCC_THREADS` when it parses as a `usize`. Otherwise it falls back to
/// the machine's available parallelism, or `1` if that cannot be determined.
pub fn nvcc_parallel_jobs() -> String {
    let requested = env::var("NVCC_THREADS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());

    let jobs = match requested {
        Some(count) => count,
        None => std::thread::available_parallelism().map_or(1, |count| count.get()),
    };

    format!("-t{}", jobs)
}
370
371fn watch_glob(pattern: &str) {
373 for path in glob::glob(pattern).expect("Invalid glob pattern").flatten() {
374 if path.is_file() {
375 println!("cargo:rerun-if-changed={}", path.display());
376 }
377 }
378}