#!/usr/bin/env python
# -*- coding: utf-8 -*-
##
#
# Vaisala software source code file
#
# Copyright (c) Vaisala Oyj 2015. All rights reserved.
#
##
"""
Configures the PostgreSQL server's client authentication file (pg_hba.conf)
according to the radar distribution standard.
"""
import datetime
import sys
import os
import shutil
import grp
import pwd
import logging
import argparse
import subprocess

_tool_name = 'rsw-postgresql-configure-auth'
logger = logging.getLogger(_tool_name)

# Mode applied to pg_hba.conf: read/write for the owner only.
# The 0o prefix is accepted by Python 2.6+ and Python 3 alike. A bare
# leading-zero literal (0600) is a SyntaxError on Python 3 that stops the
# whole script from being parsed, so no runtime version check can rescue it.
pg_hba_file_perms = 0o600


def _hba_entry_tokens(text):
    """Return the whitespace-separated fields of one pg_hba.conf line.

    Everything from the first ``#`` onwards is treated as a comment and
    ignored. Blank and comment-only lines therefore yield an empty list, and a
    record followed by a trailing comment is still recognised.

    >>> _hba_entry_tokens("host  all all 127.0.0.1/32 ident  # default")
    ['host', 'all', 'all', '127.0.0.1/32', 'ident']
    """
    return text.split("#", 1)[0].split()


# Authentication configuration
def configure_postgres_auth(file_path):
    """Rewrite the pg_hba.conf at *file_path* to use password authentication.

    Records listed in ``banned_entries`` are commented out, with a marker
    comment. Every record in ``required_entries`` that is not already present
    is appended to the end of the file. All other lines are kept verbatim,
    including required records that already exist. Records are compared
    field by field, so differences in spacing do not matter.

    If anything changes, the original file is backed up to
    ``<file_path>.<YYYYmmdd-HHMMSS>`` and the new content is put in place by
    renaming a staged ``<file_path>.new`` over it. The staged file is owned by
    the ``postgres`` user and group and gets mode ``pg_hba_file_perms`` before
    the rename. The PostgreSQL service must be restarted to pick up the
    change.

    :param file_path: path of the pg_hba.conf file to update.
    :raises KeyError: if the ``postgres`` user or group does not exist; the
        original file is not replaced in that case.
    :raises EnvironmentError: on I/O or permission errors (e.g. when not
        running as root); the original file is only replaced by the final
        rename, so it is not replaced when an earlier step fails.
    """
    # Enable password authentication
    required_entries = [
        "local   all             postgres                                peer",
        "local   all             all                                     md5",
        "host    all             all              127.0.0.1/32           md5",
        "host    all             all              ::1/128                md5",
    ]

    # Ensure the application doesn't try to authenticate via ident
    banned_entries = [
        "local   all             all                                     peer",
        "host    all             all              127.0.0.1/32           ident",
        "host    all             all              ::1/128                ident"
    ]

    banned = [_hba_entry_tokens(entry) for entry in banned_entries]

    with open(file_path, 'r') as f:
        lines = f.readlines()

    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output = []
    present = []
    changes = 0
    for line in lines:
        tokens = _hba_entry_tokens(line)
        if tokens in banned:
            output.append("\n# Automatically commented out by " + _tool_name + " at " + now + ":\n")
            output.append("#" + line)
            changes += 1
        else:
            # Every other line is preserved exactly as written, including
            # required records that are already in the file.
            output.append(line)
            present.append(tokens)

    # NOTE(review): PostgreSQL uses the first pg_hba.conf record that matches
    # a connection, so entries appended here only take effect if no earlier
    # record matches the same connection. Confirm this is acceptable for the
    # files this tool is run on.
    for entry in required_entries:
        # Append the text of the very record found missing, so the line that
        # is written always corresponds to the check that triggered it.
        if _hba_entry_tokens(entry) not in present:
            output.append("\n# Automatically added by " + _tool_name + " at " + now + ":\n")
            output.append(entry + "\n")
            changes += 1

    if changes == 0:
        logger.info("PostgreSQL authentication configuration not changed.")
        return

    new_file_path = file_path + ".new"
    with open(new_file_path, 'w') as f:
        f.writelines(output)

    try:
        # Fix ownership and mode on the staged copy *before* it goes live.
        # The live pg_hba.conf then never has the wrong owner or mode, and a
        # missing "postgres" account or insufficient privileges aborts
        # without replacing the original.
        os.chown(new_file_path, pwd.getpwnam("postgres")[2], grp.getgrnam("postgres")[2])
        os.chmod(new_file_path, pg_hba_file_perms)
        # copy2 also copies the permission bits, so the backup is no more
        # readable than the original file.
        shutil.copy2(file_path, file_path + "." + now)
        os.rename(new_file_path, file_path)
    except (KeyError, EnvironmentError):
        # Do not leave a half-installed staging file behind.
        if os.path.exists(new_file_path):
            os.unlink(new_file_path)
        raise

    logger.info("PostgreSQL authentication configuration changed. " +
                "Restart the service to take modifications into use.")


if __name__ == '__main__':
    # Command line: exactly one of --pg-hba-file / --test is required.
    arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    mode_group = arg_parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--pg-hba-file", dest="pg_hba_file",
                            metavar="FILE")
    mode_group.add_argument("--test", dest="test", action="store_true",
                            help="Test that this tool is executable.")
    arg_parser.add_argument("-d", "--debug", dest="debug", action="store_true",
                            help="Enable debug logging.")
    options = arg_parser.parse_args()

    # --test only proves that the script can be launched; nothing else runs.
    if options.test:
        sys.exit(0)

    # Route every log record to stdout with a timestamped format. The root
    # logger's level decides whether DEBUG messages appear.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if options.debug else logging.INFO)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(stdout_handler)

    # Validate the target path up front; parser.error() prints usage and exits.
    hba_path = options.pg_hba_file
    if not os.path.exists(hba_path):
        arg_parser.error(u"'{}' doesn't exist!".format(hba_path))
    if not os.path.isfile(hba_path):
        arg_parser.error(u"'{}' is not a file!".format(hba_path))

    configure_postgres_auth(hba_path)
