Also check that the cmdline file and sysctl file exist
[kconfig-hardened-check.git] / kernel_hardening_checker / __init__.py
index a1c8e06b3a87c1b2c615bc66e93cacd9da0b930c..67e02690bee12aa8b6934f5ec3c7ee91cab044c0 100644 (file)
@@ -8,24 +8,31 @@ Author: Alexander Popov <alex.popov@linux.com>
 This module performs input/output.
 """
 
 This module performs input/output.
 """
 
-# pylint: disable=missing-function-docstring,line-too-long,invalid-name,too-many-branches,too-many-statements
+# pylint: disable=missing-function-docstring,line-too-long,too-many-branches,too-many-statements
 
 
+import os
 import gzip
 import sys
 from argparse import ArgumentParser
 from typing import List, Tuple, Dict, TextIO
 import re
 import json
 import gzip
 import sys
 from argparse import ArgumentParser
 from typing import List, Tuple, Dict, TextIO
 import re
 import json
-from .__about__ import __version__
 from .checks import add_kconfig_checks, add_cmdline_checks, normalize_cmdline_options, add_sysctl_checks
 from .engine import StrOrNone, TupleOrNone, ChecklistObjType
 from .engine import print_unknown_options, populate_with_data, perform_checks, override_expected_value
 
 
 from .checks import add_kconfig_checks, add_cmdline_checks, normalize_cmdline_options, add_sysctl_checks
 from .engine import StrOrNone, TupleOrNone, ChecklistObjType
 from .engine import print_unknown_options, populate_with_data, perform_checks, override_expected_value
 
 
+# kernel-hardening-checker version
+__version__ = '0.6.6'
+
+
 def _open(file: str) -> TextIO:
 def _open(file: str) -> TextIO:
-    if file.endswith('.gz'):
-        return gzip.open(file, 'rt', encoding='utf-8')
-    return open(file, 'rt', encoding='utf-8')
+    try:
+        if file.endswith('.gz'):
+            return gzip.open(file, 'rt', encoding='utf-8')
+        return open(file, 'rt', encoding='utf-8')
+    except FileNotFoundError:
+        sys.exit(f'[!] ERROR: unable to open {file}, are you sure it exists?')
 
 
 def detect_arch(fname: str, archs: List[str]) -> Tuple[StrOrNone, str]:
 
 
 def detect_arch(fname: str, archs: List[str]) -> Tuple[StrOrNone, str]:
@@ -55,7 +62,7 @@ def detect_kernel_version(fname: str) -> Tuple[TupleOrNone, str]:
                 ver_str = parts[2].split('-', 1)[0]
                 ver_numbers = ver_str.split('.')
                 if len(ver_numbers) >= 3:
                 ver_str = parts[2].split('-', 1)[0]
                 ver_numbers = ver_str.split('.')
                 if len(ver_numbers) >= 3:
-                    if all(map(lambda x: x.isdigit(), ver_numbers)):
+                    if all(map(lambda x: x.isdecimal(), ver_numbers)):
                         return tuple(map(int, ver_numbers)), 'OK'
                 msg = f'failed to parse the version "{parts[2]}"'
                 return None, msg
                         return tuple(map(int, ver_numbers)), 'OK'
                 msg = f'failed to parse the version "{parts[2]}"'
                 return None, msg
@@ -162,6 +169,9 @@ def parse_kconfig_file(_mode: StrOrNone, parsed_options: Dict[str, str], fname:
 
 
 def parse_cmdline_file(mode: StrOrNone, parsed_options: Dict[str, str], fname: str) -> None:
 
 
 def parse_cmdline_file(mode: StrOrNone, parsed_options: Dict[str, str], fname: str) -> None:
+    if not os.path.isfile(fname):
+        sys.exit(f'[!] ERROR: unable to open {fname}, are you sure it exists?')
+
     with open(fname, 'r', encoding='utf-8') as f:
         line = f.readline()
         opts = line.split()
     with open(fname, 'r', encoding='utf-8') as f:
         line = f.readline()
         opts = line.split()
@@ -184,6 +194,9 @@ def parse_cmdline_file(mode: StrOrNone, parsed_options: Dict[str, str], fname: s
 
 
 def parse_sysctl_file(mode: StrOrNone, parsed_options: Dict[str, str], fname: str) -> None:
 
 
 def parse_sysctl_file(mode: StrOrNone, parsed_options: Dict[str, str], fname: str) -> None:
+    if not os.path.isfile(fname):
+        sys.exit(f'[!] ERROR: unable to open {fname}, are you sure it exists?')
+
     with open(fname, 'r', encoding='utf-8') as f:
         sysctl_pattern = re.compile(r"[a-zA-Z0-9/\._-]+ =.*$")
         for line in f.readlines():
     with open(fname, 'r', encoding='utf-8') as f:
         sysctl_pattern = re.compile(r"[a-zA-Z0-9/\._-]+ =.*$")
         for line in f.readlines():