#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2022 Greenbone AG
#
# SPDX-License-Identifier: AGPL-3.0-or-later
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Escalator method script: SMB.

import re
import sys
import subprocess


def smb_error_print(message, stdout, stderr):
    """Write an error message and any captured smbclient output to stderr.

    Args:
        message: Text printed first.
        stdout: Captured standard output; printed last, skipped if empty.
        stderr: Captured standard error; printed second, skipped if empty.
    """
    print(message, file=sys.stderr)
    # The captured error stream goes before the captured output stream.
    for captured in (stderr, stdout):
        if captured:
            print(captured, file=sys.stderr)


def smb_call(auth_path, share, command, extra_args):
    """Run a single smbclient command, retrying on connection failures.

    smbclient is invoked as
    ``smbclient <extra_args> -A <auth_path> <share> -c <command>``.
    The call is attempted up to 10 times; another attempt is made only
    when smbclient exits with code 1 and its standard output starts with
    "Connection to ... failed".

    Args:
        auth_path: Path passed to smbclient's ``-A`` option.
        share: Share argument passed to smbclient.
        command: Command string passed to smbclient's ``-c`` option.
        extra_args: List of additional arguments placed before ``-A``.

    Returns:
        A tuple ``(returncode, stdout, stderr)`` where returncode is 0 or 1
        and both outputs are decoded as UTF-8 text.

    Exits the process with status 1 if smbclient cannot be started, exits
    with a code other than 0 or 1, or if every attempt fails with a
    connection error.
    """
    args = ["smbclient"] + extra_args + ["-A", auth_path, share, "-c", command]

    stdout = ''
    stderr = ''
    retry_regex = re.compile("^Connection to .* failed")

    for _attempt in range(10):
        try:
            # subprocess.run() reads both pipes while waiting for the child.
            # Waiting for exit *before* reading (wait() then communicate())
            # deadlocks as soon as the child fills a pipe buffer.
            result = subprocess.run(args,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE)
        except FileNotFoundError:
            print("Error: Could not find smbclient", file=sys.stderr)
            sys.exit(1)
        except Exception as err:
            print("%s running smbclient: %s" % (type(err).__name__, err),
                  file=sys.stderr)
            sys.exit(1)

        # Undecodable bytes are replaced instead of raising, so unusual
        # output cannot abort the script with a traceback.
        stdout = result.stdout.decode("UTF-8", errors="replace")
        stderr = result.stderr.decode("UTF-8", errors="replace")
        ret = result.returncode

        # Unexpected exit code from smbclient
        if ret != 0 and ret != 1:
            smb_error_print("smbclient exited with code %d" % (ret),
                            stdout, stderr)
            sys.exit(1)

        # Success or an error that is not expected to occur
        #  just temporarily
        if ret == 0 or retry_regex.match(stdout) is None:
            return ret, stdout, stderr

    # Every attempt ended with a connection failure.
    smb_error_print("smbclient call failed:", stdout, stderr)
    sys.exit(1)


def smb_dir_exists(auth_path, share, check_dir, extra_args):
    """Check whether a directory exists on the share.

    Runs ``cd "<check_dir>"`` through smb_call().

    Args:
        auth_path: Path passed on to smb_call().
        share: Share passed on to smb_call().
        check_dir: Directory path to test.
        extra_args: Additional smbclient arguments passed on to smb_call().

    Returns:
        True if the command succeeded, False if its output ends with an
        "object name/path not found" status.

    Exits the process with status 1 on any other failure.
    """
    cd_command = "cd \"%s\"" % check_dir
    rc, stdout, stderr = smb_call(auth_path, share, cd_command, extra_args)

    if rc == 0:
        return True

    # Status lines smbclient prints last when the path does not exist.
    not_found_suffixes = (
        "NT_STATUS_OBJECT_NAME_NOT_FOUND\n",
        "NT_STATUS_OBJECT_PATH_NOT_FOUND\n",
    )
    if stdout.endswith(not_found_suffixes):
        return False

    smb_error_print("Error checking directory %s" % check_dir,
                    stdout, stderr)
    sys.exit(1)


def smb_mkdir(auth_path, share, check_dir, extra_args):
    """Create a directory on the share and verify that it exists afterwards.

    Runs ``mkdir "<check_dir>"`` through smb_call(), then confirms the
    result with smb_dir_exists(). Prints a confirmation on success.

    Args:
        auth_path: Path passed on to smb_call().
        share: Share passed on to smb_call().
        check_dir: Directory path to create.
        extra_args: Additional smbclient arguments passed on to smb_call().

    Exits the process with status 1 if the mkdir command fails or the
    directory still does not exist afterwards.
    """
    mkdir_command = "mkdir \"%s\"" % check_dir
    rc, stdout, stderr = smb_call(auth_path, share, mkdir_command, extra_args)

    if rc != 0:
        smb_error_print("Error creating directory %s" % check_dir,
                        stdout, stderr)
        sys.exit(1)

    # A zero exit code alone is not trusted; check the directory is there.
    if not smb_dir_exists(auth_path, share, check_dir, extra_args):
        print("Could not create directory %s" % check_dir,
              file=sys.stderr)
        sys.exit(1)

    print("Created directory %s" % check_dir)


def smb_put(auth_path, share, report_path, dest_path, extra_args):
    """Copy a local file to the share.

    Runs ``put "<report_path>" "<dest_path>"`` through smb_call() and
    prints a confirmation on success.

    Args:
        auth_path: Path passed on to smb_call().
        share: Share passed on to smb_call().
        report_path: Local path of the file to upload.
        dest_path: Destination path on the share.
        extra_args: Additional smbclient arguments passed on to smb_call().

    Exits the process with status 1 if the put command fails.
    """
    put_command = "put \"%s\" \"%s\"" % (report_path, dest_path)
    rc, stdout, stderr = smb_call(auth_path, share, put_command, extra_args)

    if rc != 0:
        smb_error_print("Error copying to %s:" % dest_path,
                        stdout, stderr)
        sys.exit(1)

    print("Report copied to directory %s" % dest_path)


def main():
    """Upload a report to an SMB share, creating missing directories.

    Expects exactly five command line arguments:
    <share> <dest_path> <max_protocol> <auth_path> <report_path>.
    A non-empty <max_protocol> is handed to smbclient via ``-m``.

    Exits the process with status 1 on a wrong argument count or when any
    component of the destination path ends with a dot.
    """
    if len(sys.argv) != 6:
        print("usage: %s <share> <dest_path> <max_protocol> <auth_path> <report_path>"
              % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    _, share, raw_dest, max_protocol, auth_path, report_path = sys.argv
    extra_args = ["-m", max_protocol] if max_protocol else []

    # Forward slashes become backslashes so both are always handled as
    #  path separator.
    dest_path = raw_dest.replace('/', '\\')
    components = dest_path.split('\\')

    if any(name.endswith('.') for name in components):
        print("File or subdirectory names must not be '.', '..' or"
              " end with a dot",
              file=sys.stderr)
        sys.exit(1)

    # Cumulative directory paths leading to the file: the last component is
    #  the file name and empty components are ignored.
    dir_names = [name for name in components[:-1] if name]
    dest_subpaths = ['\\'.join(dir_names[:depth])
                     for depth in range(1, len(dir_names) + 1)]

    # Search from the deepest path upwards for a directory that already
    #  exists; everything below it has to be created.
    first_missing = 0
    for idx in reversed(range(len(dest_subpaths))):
        if smb_dir_exists(auth_path, share, dest_subpaths[idx], extra_args):
            first_missing = idx + 1
            break

    # Create missing directories
    for subpath in dest_subpaths[first_missing:]:
        smb_mkdir(auth_path, share, subpath, extra_args)

    smb_put(auth_path, share, report_path, dest_path, extra_args)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
