openvm_cuda_builder/
lib.rs

1use std::{env, path::Path, process::Command};
2
/// Configuration for compiling CUDA sources into a library from a Cargo build script.
///
/// Construct it with [`CudaBuilder::new`] or `Default`, chain the configuration
/// methods, and finish with `build()`.
#[derive(Debug, Clone)]
pub struct CudaBuilder {
    /// Directories handed to nvcc as include paths.
    include_paths: Vec<String>,
    /// Source files handed to `cc::Build::file`.
    source_files: Vec<String>,
    /// Paths emitted as `cargo:rerun-if-changed`.
    watch_paths: Vec<String>,
    /// Glob patterns whose matching files are emitted as `cargo:rerun-if-changed`.
    watch_globs: Vec<String>,
    /// Name passed to `cc::Build::compile`; `build()` panics while it is empty.
    library_name: String,
    /// Target SM architectures such as `"80"`; empty means "use `CUDA_ARCH` or auto-detect".
    cuda_arch: Vec<String>,
    /// Explicit optimization level; `None` defers to `CUDA_OPT_LEVEL` (default `"3"`).
    cuda_opt_level: Option<String>,
    /// Extra flags forwarded verbatim to nvcc.
    custom_flags: Vec<String>,
    /// Libraries emitted as `cargo:rustc-link-lib`.
    link_libraries: Vec<String>,
    /// Directories emitted as `cargo:rustc-link-search=native=`.
    link_search_paths: Vec<String>,
}

impl Default for CudaBuilder {
    /// Baseline configuration: no sources, C++17 with relaxed constexpr, compressed
    /// fatbins, per-thread default stream, and linking against `cudart` + `cuda`.
    ///
    /// The library search path comes from `CUDA_LIB_DIR` when it is set (and valid
    /// Unicode), otherwise `/usr/local/cuda/lib64`.
    fn default() -> Self {
        let cuda_lib_dir = env::var("CUDA_LIB_DIR")
            .unwrap_or_else(|_| String::from("/usr/local/cuda/lib64"));

        let nvcc_flags = [
            "--std=c++17",
            "--expt-relaxed-constexpr",
            "-Xfatbin=-compress-all",
            "--default-stream=per-thread",
        ];
        let libraries = ["cudart", "cuda"];

        Self {
            library_name: String::new(),
            source_files: Vec::new(),
            include_paths: Vec::new(),
            // The build script itself is always a rerun trigger.
            watch_paths: vec![String::from("build.rs")],
            watch_globs: Vec::new(),
            cuda_arch: Vec::new(),
            cuda_opt_level: None,
            custom_flags: nvcc_flags.iter().map(|&f| String::from(f)).collect(),
            link_libraries: libraries.iter().map(|&l| String::from(l)).collect(),
            link_search_paths: vec![cuda_lib_dir],
        }
    }
}
46
47impl CudaBuilder {
48    /// Create a new CudaBuilder
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Set the library name (useful when cloning from a template)
54    pub fn library_name(mut self, name: &str) -> Self {
55        self.library_name = name.to_string();
56        self
57    }
58
59    /// Add include path
60    pub fn include<P: AsRef<Path>>(mut self, path: P) -> Self {
61        let path_str = path.as_ref().to_string_lossy().to_string();
62        self.include_paths.push(path_str.clone());
63        self.watch_paths.push(path_str);
64        self
65    }
66
67    /// Add include path from another crate's exported include
68    pub fn include_from_dep(mut self, dep_env_var: &str) -> Self {
69        if let Ok(path) = env::var(dep_env_var) {
70            self.include_paths.push(path);
71        }
72        self
73    }
74
75    /// Add source file
76    pub fn file<P: AsRef<Path>>(mut self, path: P) -> Self {
77        let path_str = path.as_ref().to_string_lossy().to_string();
78        self.source_files.push(path_str.clone());
79        self.watch_paths.push(path_str);
80        self
81    }
82
83    /// Add multiple source files
84    pub fn files<P: AsRef<Path>, I: IntoIterator<Item = P>>(mut self, paths: I) -> Self {
85        for path in paths {
86            let path_str = path.as_ref().to_string_lossy().to_string();
87            self.source_files.push(path_str.clone());
88            self.watch_paths.push(path_str);
89        }
90        self
91    }
92
93    /// Add multiple source files matching a glob pattern
94    pub fn files_from_glob(mut self, pattern: &str) -> Self {
95        self.watch_globs.push(pattern.to_string());
96        for path in glob::glob(pattern).expect("Invalid glob pattern").flatten() {
97            if path.is_file() && path.extension().is_some_and(|ext| ext == "cu") {
98                self.source_files.push(path.to_string_lossy().to_string());
99            }
100        }
101        self
102    }
103
104    /// Watch a specific path for changes
105    pub fn watch<P: AsRef<Path>>(mut self, path: P) -> Self {
106        self.watch_paths
107            .push(path.as_ref().to_string_lossy().to_string());
108        self
109    }
110
111    /// Watch paths matching a glob pattern
112    pub fn watch_glob(mut self, pattern: &str) -> Self {
113        self.watch_globs.push(pattern.to_string());
114        self
115    }
116
117    /// Set CUDA architecture (e.g., "75", "80")
118    pub fn cuda_arch(mut self, arch: &str) -> Self {
119        self.cuda_arch = vec![arch.to_string()];
120        self
121    }
122
123    /// Set multiple CUDA architectures  
124    pub fn cuda_archs(mut self, archs: Vec<&str>) -> Self {
125        self.cuda_arch = archs.iter().map(|s| s.to_string()).collect();
126        self
127    }
128
129    /// Set CUDA optimization level (0-3)
130    pub fn cuda_opt_level(mut self, level: u8) -> Self {
131        self.cuda_opt_level = Some(level.to_string());
132        self
133    }
134
135    /// Add custom compiler flag
136    pub fn flag(mut self, flag: &str) -> Self {
137        self.custom_flags.push(flag.to_string());
138        self
139    }
140
141    /// Add library to link
142    pub fn link_lib(mut self, lib: &str) -> Self {
143        self.link_libraries.push(lib.to_string());
144        self
145    }
146
147    /// Add library search path
148    pub fn link_search<P: AsRef<Path>>(mut self, path: P) -> Self {
149        self.link_search_paths
150            .push(path.as_ref().to_string_lossy().to_string());
151        self
152    }
153
154    /// Build the CUDA library
155    pub fn build(self) {
156        // Validation
157        self.validate();
158
159        // Set up rerun conditions
160        self.setup_rerun_conditions();
161
162        // Get or detect CUDA architecture
163        let cuda_archs = self.get_cuda_arch();
164
165        // Create cc::Build
166        let mut builder = cc::Build::new();
167        builder.cuda(true);
168
169        // Handle CUDA_DEBUG=1
170        self.handle_debug_shortcuts(&mut builder);
171
172        // Get optimization level
173        let cuda_opt_level = self.get_cuda_opt_level();
174
175        // Add include paths
176        for include in &self.include_paths {
177            builder.include(include);
178        }
179
180        // Add CUDA_PATH include if available
181        if let Ok(cuda_path) = env::var("CUDA_PATH") {
182            builder.include(format!("{}/include", cuda_path));
183        }
184
185        // Add custom flags
186        for flag in &self.custom_flags {
187            builder.flag(flag);
188        }
189
190        // Add SASS code for each architecture
191        for arch in &cuda_archs {
192            builder
193                .flag("-gencode")
194                .flag(format!("arch=compute_{},code=sm_{}", arch, arch));
195        }
196
197        // Add PTX for the highest architecture (forward compatibility)
198        // This allows the code to run on future GPUs
199        if let Some(max_arch) = cuda_archs.iter().max() {
200            builder.flag("-gencode").flag(format!(
201                "arch=compute_{},code=compute_{}",
202                max_arch, max_arch
203            ));
204        }
205
206        // Add parallel jobs flag
207        builder.flag(nvcc_parallel_jobs());
208
209        // Set optimization and debug flags
210        if cuda_opt_level == "0" {
211            builder.debug(true).flag("-O0");
212        } else {
213            builder
214                .debug(false)
215                .flag(format!("--ptxas-options=-O{}", cuda_opt_level));
216        }
217
218        // Add source files
219        for file in &self.source_files {
220            builder.file(file);
221        }
222
223        // Compile
224        builder.compile(&self.library_name);
225    }
226
227    /// Validate the builder configuration
228    fn validate(&self) {
229        if self.library_name.is_empty() {
230            panic!(
231                "Library name must be set using .library_name(\"name\") before calling .build()"
232            );
233        }
234
235        if self.source_files.is_empty() {
236            panic!("At least one source file must be added using .file() or .files() before calling .build()");
237        }
238
239        // Validate that source files exist (optional, but helpful)
240        for file in &self.source_files {
241            if !Path::new(file).exists() {
242                eprintln!("cargo:warning=CUDA source file does not exist: {}", file);
243            }
244        }
245
246        // Validate include paths exist (optional warning)
247        for include in &self.include_paths {
248            if !Path::new(include).exists() {
249                eprintln!("cargo:warning=Include path does not exist: {}", include);
250            }
251        }
252    }
253
254    pub fn emit_link_directives(&self) {
255        for path in &self.link_search_paths {
256            println!("cargo:rustc-link-search=native={}", path);
257        }
258        for lib in &self.link_libraries {
259            println!("cargo:rustc-link-lib={}", lib);
260        }
261    }
262
263    fn setup_rerun_conditions(&self) {
264        // Standard rerun conditions
265        println!("cargo:rerun-if-env-changed=CUDA_ARCH");
266        println!("cargo:rerun-if-env-changed=CUDA_OPT_LEVEL");
267        println!("cargo:rerun-if-env-changed=CUDA_DEBUG");
268        println!("cargo:rerun-if-env-changed=NVCC_THREADS");
269
270        // Watch specific paths
271        for path in &self.watch_paths {
272            println!("cargo:rerun-if-changed={}", path);
273        }
274
275        // Watch glob patterns
276        for pattern in &self.watch_globs {
277            watch_glob(pattern);
278        }
279    }
280
281    fn get_cuda_arch(&self) -> Vec<String> {
282        if !self.cuda_arch.is_empty() {
283            return self.cuda_arch.clone();
284        }
285
286        // Check environment variable
287        if let Ok(env_archs) = env::var("CUDA_ARCH") {
288            return env_archs
289                .split(',')
290                .map(|s| s.trim().to_string())
291                .filter(|s| !s.is_empty())
292                .collect();
293        }
294
295        // Auto-detect current GPU
296        vec![detect_cuda_arch()]
297    }
298
299    fn get_cuda_opt_level(&self) -> String {
300        if let Some(level) = &self.cuda_opt_level {
301            return level.clone();
302        }
303
304        env::var("CUDA_OPT_LEVEL").unwrap_or_else(|_| "3".to_string())
305    }
306
307    fn handle_debug_shortcuts(&self, builder: &mut cc::Build) {
308        if env::var("CUDA_DEBUG").map(|v| v == "1").unwrap_or(false) {
309            env::set_var("CUDA_OPT_LEVEL", "0");
310            env::set_var("CUDA_LAUNCH_BLOCKING", "1");
311            env::set_var("RUST_BACKTRACE", "full");
312            env::set_var("CUDA_ENABLE_COREDUMP_ON_EXCEPTION", "1");
313            env::set_var("CUDA_DEVICE_WAITS_ON_EXCEPTION", "1");
314
315            println!("cargo:warning=CUDA_DEBUG=1 → Enabling comprehensive debugging:");
316            println!("cargo:warning=  → CUDA_OPT_LEVEL=0 (no optimization)");
317            println!("cargo:warning=  → CUDA_LAUNCH_BLOCKING=1 (synchronous kernels)");
318            println!("cargo:warning=  → Line info and device debug symbols enabled");
319            println!("cargo:warning=  → CUDA_DEBUG macro defined for preprocessor");
320
321            builder.flag("-G"); // Device debug symbols
322            builder.flag("-Xcompiler=-fno-omit-frame-pointer"); // Better stack traces
323            builder.flag("-Xptxas=-v"); // Verbose PTX compilation
324            builder.define("CUDA_DEBUG", "1"); // Define CUDA_DEBUG macro
325        }
326    }
327}
328
/// Check whether the CUDA compiler (`nvcc`) is usable on this system.
///
/// Returns `true` only if `nvcc --version` can be spawned from `PATH` **and** exits
/// successfully. A broken install that spawns but fails counts as unavailable.
pub fn cuda_available() -> bool {
    Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}
333
/// Detect the compute capability of the first GPU reported by `nvidia-smi`.
///
/// Returns the capability with the dot removed (e.g. `"7.5"` becomes `"75"`). It also
/// prints `cargo:rustc-env=CUDA_ARCH=<arch>`, so the crate being compiled sees it at
/// compile time, and sets `CUDA_ARCH` in the current build-script process.
///
/// # Panics
/// Panics if `nvidia-smi` cannot be spawned, exits unsuccessfully, prints non-UTF-8
/// output, reports no capability, or reports something that is not a dotted number.
pub fn detect_cuda_arch() -> String {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=compute_cap", "--format=csv,noheader"])
        .output()
        .expect("Failed to run nvidia-smi - make sure NVIDIA drivers are installed");

    // A failing nvidia-smi (no driver, no device, ...) may still print a human-readable
    // message on stdout; that text must never be handed to nvcc as an architecture.
    if !output.status.success() {
        panic!(
            "nvidia-smi exited with {}: {} {}",
            output.status,
            String::from_utf8_lossy(&output.stdout).trim(),
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }

    let full_output =
        String::from_utf8(output.stdout).expect("nvidia-smi output is not valid UTF-8");

    let arch = full_output
        .lines()
        .map(str::trim)
        .find(|line| !line.is_empty())
        .expect("nvidia-smi failed to return compute capability")
        .replace('.', ""); // Convert "7.5" to "75"

    // Guard against placeholders such as "[N/A]" or unexpected output formats.
    assert!(
        !arch.is_empty() && arch.chars().all(|c| c.is_ascii_digit()),
        "nvidia-smi returned an unexpected compute capability: {:?}",
        arch
    );

    // Set both cargo env and process env
    println!("cargo:rustc-env=CUDA_ARCH={}", arch);
    env::set_var("CUDA_ARCH", &arch);
    arch
}
356
/// Build the `-t<N>` flag that sets how many threads nvcc may use.
///
/// `NVCC_THREADS` wins when it parses as a `usize`. Otherwise the OS-reported
/// available parallelism is used, falling back to 1 if that cannot be determined.
pub fn nvcc_parallel_jobs() -> String {
    let requested = env::var("NVCC_THREADS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());

    let jobs = match requested {
        Some(count) => count,
        None => std::thread::available_parallelism().map_or(1, |count| count.get()),
    };

    format!("-t{}", jobs)
}
370
371/// Watch files matching a glob pattern
372fn watch_glob(pattern: &str) {
373    for path in glob::glob(pattern).expect("Invalid glob pattern").flatten() {
374        if path.is_file() {
375            println!("cargo:rerun-if-changed={}", path.display());
376        }
377    }
378}