use std::{fs::read, path::PathBuf, str::FromStr};

use eyre::Result;
use openvm_sdk::{StdIn, F};
use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
6
/// Source of program input bytes.
///
/// Produced by this type's `FromStr` implementation. A string that passes hex validation becomes
/// [`Input::HexBytes`]; otherwise, a string naming an existing filesystem path becomes
/// [`Input::FilePath`].
#[derive(Clone, Debug)]
pub enum Input {
    /// Path to a JSON file whose `"input"` key is expected to hold an array of hex strings.
    FilePath(PathBuf),
    /// Decoded bytes; the leading byte is a format tag (`0x01` or `0x02`).
    HexBytes(Vec<u8>),
}
18
19impl FromStr for Input {
20 type Err = String;
21
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 if let Ok(bytes) = decode_hex_string(s) {
24 Ok(Input::HexBytes(bytes))
25 } else if PathBuf::from(s).exists() {
26 Ok(Input::FilePath(PathBuf::from(s)))
27 } else {
28 Err("Input must be a valid file path or a hex string of even length.".to_string())
29 }
30 }
31}
32
33pub fn decode_hex_string(s: &str) -> Result<Vec<u8>> {
34 let s = s.strip_prefix("0x").unwrap_or(s);
36 if s.len() % 2 != 0 {
37 return Err(eyre::eyre!("The hex string must be of even length"));
38 }
39 if !s.chars().all(|c| c.is_ascii_hexdigit()) {
40 return Err(eyre::eyre!("The hex string must consist of hex digits"));
41 }
42 if s.starts_with("02") {
43 if s.len() % 8 != 2 {
44 return Err(eyre::eyre!(
45 "If the hex value starts with 02, a whole number of 32-bit elements must follow"
46 ));
47 }
48 } else if !s.starts_with("01") {
49 return Err(eyre::eyre!("The hex value must start with 01 or 02"));
50 }
51 hex::decode(s).map_err(|e| eyre::eyre!("Invalid hex: {}", e))
52}
53
54pub fn read_bytes_into_stdin(stdin: &mut StdIn, bytes: &[u8]) -> Result<()> {
55 match bytes.first() {
57 Some(0x01) => {
58 stdin.write_bytes(&bytes[1..]);
59 Ok(())
60 }
61 Some(0x02) => {
62 let data = &bytes[1..];
63 if data.len() % 4 != 0 {
64 return Err(eyre::eyre!(
65 "Invalid input format: incorrect number of bytes"
66 ));
67 }
68 let mut fields = Vec::with_capacity(data.len() / 4);
69 for chunk in data.chunks_exact(4) {
70 let value = u32::from_le_bytes(chunk.try_into().unwrap());
71 fields.push(F::from_canonical_u32(value));
72 }
73 stdin.write_field(&fields);
74 Ok(())
75 }
76 _ => Err(eyre::eyre!(
77 "Invalid input format: the first byte must be 0x01 or 0x02"
78 )),
79 }
80}
81
82pub fn read_to_stdin(input: &Option<Input>) -> Result<StdIn> {
83 match input {
84 Some(Input::FilePath(path)) => {
85 let mut stdin = StdIn::default();
86 let bytes = read(path)?;
88 let json: serde_json::Value = serde_json::from_slice(&bytes)?;
89 json["input"]
90 .as_array()
91 .ok_or_else(|| eyre::eyre!("Input must be an array under 'input' key"))?
92 .iter()
93 .try_for_each(|inner| {
94 inner
95 .as_str()
96 .ok_or_else(|| eyre::eyre!("Each value must be a hex string"))
97 .and_then(|s| match decode_hex_string(s) {
98 Err(msg) => Err(eyre::eyre!("Invalid hex string: {}", msg)),
99 Ok(bytes) => {
100 read_bytes_into_stdin(&mut stdin, &bytes).expect("Fail: input validation accepted an input, but the deserialization rejected it");
101 Ok(())
102 }
103 })
104 })?;
105
106 Ok(stdin)
107 }
108 Some(Input::HexBytes(bytes)) => {
109 let mut stdin = StdIn::default();
110 read_bytes_into_stdin(&mut stdin, bytes)?;
111 Ok(stdin)
112 }
113 None => Ok(StdIn::default()),
114 }
115}