diff --git a/niri-config/src/lib.rs b/niri-config/src/lib.rs index 038d7e5a..ac133c62 100644 --- a/niri-config/src/lib.rs +++ b/niri-config/src/lib.rs @@ -114,6 +114,10 @@ struct Recursion(u8); struct Includes(Vec); #[derive(Default)] struct IncludeErrors(Vec); +// Used for recursive include detection. +// +// We don't *need* it because we have a recursion limit, but it makes for nicer error messages. +struct IncludeStack(HashSet); // Rather than listing all fields and deriving knuffel::Decode, we implement // knuffel::DecodeChildren by hand, since we need custom logic for every field anyway: we want to @@ -294,6 +298,16 @@ where }; let base = path.parent().map(Path::to_path_buf).unwrap_or_default(); + // Check for recursive include for a nicer error message. + let mut include_stack = ctx.get::().unwrap().0.clone(); + if !include_stack.insert(path.to_path_buf()) { + ctx.emit_error(DecodeError::missing( + node, + "recursive include (file includes itself)", + )); + continue; + } + // Store even if the include fails to read or parse, so it gets watched. includes.borrow_mut().0.push(path.to_path_buf()); @@ -316,6 +330,7 @@ where ctx.set(Recursion(recursion)); ctx.set(includes.clone()); ctx.set(include_errors.clone()); + ctx.set(IncludeStack(include_stack)); ctx.set(config.clone()); }); @@ -397,6 +412,7 @@ impl Config { let config = Rc::new(RefCell::new(Config::default())); let includes = Rc::new(RefCell::new(Includes(Vec::new()))); let include_errors = Rc::new(RefCell::new(IncludeErrors(Vec::new()))); + let include_stack = HashSet::from([path.to_path_buf()]); let part = knuffel::parse_with_context::( filename, @@ -407,6 +423,7 @@ impl Config { ctx.set(Recursion(0)); ctx.set(includes.clone()); ctx.set(include_errors.clone()); + ctx.set(IncludeStack(include_stack)); ctx.set(config.clone()); }, );