openvm_scripts/
lib.rs

1use std::{
2    collections::BTreeSet,
3    env,
4    path::{Path, PathBuf},
5};
6
7use walkdir::WalkDir;
8
/// Heuristically reports whether `path` refers to CUDA-related code.
///
/// The check is purely textual: the path is rendered lossily to UTF-8,
/// lowercased, and searched for the substring `"cuda"` anywhere in it
/// (any component, ancestors included).
pub fn is_cuda_path(path: &Path) -> bool {
    const NEEDLE: &str = "cuda";
    let lowered: String = path.to_string_lossy().to_lowercase();
    lowered.contains(NEEDLE)
}
12
13pub fn has_cuda_files(path: &Path) -> bool {
14    WalkDir::new(path)
15        .into_iter()
16        .filter_map(Result::ok)
17        .any(|e| {
18            e.file_type().is_file()
19                && matches!(
20                    e.path().extension().and_then(|s| s.to_str()),
21                    Some("cuh") | Some("cu")
22                )
23        })
24}
25
26pub fn find_cuda_include_dirs(workspace_root: &Path) -> Vec<PathBuf> {
27    let mut include_dirs: BTreeSet<PathBuf> = BTreeSet::new();
28
29    for entry in WalkDir::new(workspace_root)
30        .follow_links(false)
31        .into_iter()
32        .filter_map(Result::ok)
33        .filter(|e| e.file_type().is_dir() && e.file_name() == "include")
34    {
35        let include_dir = entry.path().to_path_buf();
36
37        if include_dir.components().any(|c| c.as_os_str() == "target") {
38            continue;
39        }
40
41        if is_cuda_path(&include_dir) || has_cuda_files(&include_dir) {
42            include_dirs.insert(include_dir);
43        }
44    }
45
46    include_dirs.into_iter().collect()
47}
48
49pub fn find_files_with_extension(root: &Path, extension: &str) -> Vec<PathBuf> {
50    WalkDir::new(root)
51        .into_iter()
52        .filter_map(Result::ok)
53        .filter(|e| {
54            e.file_type().is_file()
55                && e.path().extension().and_then(|s| s.to_str()) == Some(extension)
56        })
57        .map(|e| e.path().to_path_buf())
58        .collect()
59}
60
/// Returns the include directories listed in the `DEP_CUDA_COMMON_INCLUDE`
/// environment variable, split with the platform's path-list separator (see
/// [`env::split_paths`]); empty entries are dropped.
///
/// Presumably this is called from a build script of a crate depending on a
/// package that declares `links = "cuda_common"`. Cargo exposes such `DEP_*`
/// variables in the build script's *runtime* environment, so the variable is
/// read with [`env::var_os`]. The value captured when this crate itself was
/// compiled (`option_env!`) is used only as a fallback, preserving the previous
/// lookup. Returns an empty `Vec` when neither source provides a value.
pub fn get_cuda_dep_common_include_dirs() -> Vec<PathBuf> {
    // The previous implementation read only `option_env!`, which is expanded
    // when *this* crate is compiled; Cargo passes `DEP_*` variables only to
    // dependent build scripts at run time, so that lookup was effectively empty.
    let raw = env::var_os("DEP_CUDA_COMMON_INCLUDE")
        .or_else(|| option_env!("DEP_CUDA_COMMON_INCLUDE").map(Into::into));

    match raw {
        Some(val) => env::split_paths(&val)
            .filter(|p| !p.as_os_str().is_empty())
            .collect(),
        None => Vec::new(),
    }
}