#!/usr/bin/env python
# -*- coding: utf-8; indent: 4 -*-
##
#
# Vaisala software source code file
#
# Copyright (c) Vaisala Oyj 2015. All rights reserved.
#
##
"""scan-updater-service is a background daemon that watches a directory tree
and updates the scan database accordingly.

scan-updater-service shares the database with scan-http-service. The services
work in tandem: scan-updater-service is responsible for tracking what happens
on disk while scan-http-service handles the HTTP traffic.
"""
from __future__ import print_function

import collections
import datetime
import logging
import logging.handlers
import os
import posix
import sqlite3
import sys
import threading
import time
import traceback

import pyinotify

from operator import attrgetter

from scan_service import scan_db
from scan_service.hybrid_linking import link_hybrids
from scan_service.scan_service import read_scan_in_subprocess
from scan_service.read_scan import ReadScanException
from scan_service.scan_updater_utils import find_raw_files_recursively, BatchStatus
from scan_service.scan_updater_utils import iris_raw_product_name_comparator
from scan_service.utils import compress_simple
from scan_service.utils import set_proc_name, strip_prefix
from scan_service.utils import uuid_pair_to_str, lg


# systemd integration is optional. When the bindings are missing, bind the
# names to None so that later "is None" checks work instead of raising
# NameError.
try:
    from systemd import journal, daemon
except ImportError:
    journal = None
    daemon = None


# Hybrid-related details of one scan: the task name shared by the hybrid
# parts, the minor task id of this part and the number of parts.
HybridInfo = collections.namedtuple(
    'HybridInfo',
    ['task_name', 'minor_task_id', 'cardinality'])


# Only one updater runs at a time - either the periodic one or
# inotify-triggered.
# updater_loop() and the InotifyEventHandler callbacks all acquire this lock
# around their database work. It is a re-entrant lock, so the thread holding
# it may acquire it again without deadlocking.
_UPDATE_LOCK = threading.RLock()


# A one hour time span. Note that the hybrid linking calls in this file
# currently build datetime.timedelta(hours=1) inline rather than using this.
ONE_HOUR_TIMEDELTA = datetime.timedelta(hours=1)


def is_valid_raw(file_path):
    """Tells whether a path names a raw product file that should be handled.

    A path is rejected when its base name starts with "XXX" or when its
    extension (the part after the last dot) does not start with "RAW".

    Args:
        file_path: Absolute or relative path of the candidate file.

    Returns:
        True if the file should be handled, False otherwise.
    """
    if os.path.basename(file_path).startswith("XXX"):
        return False

    _, extension = os.path.splitext(file_path)
    return extension.strip('.').startswith('RAW')


def extract_hybrid_info(scan):
    """Extracts the hybrid details of a scan, if it is a hybrid part.

    Every sweep is visited in order and each one overwrites what the previous
    one produced, so the outcome is decided by the last sweep of the scan. A
    sweep counts as a hybrid part when its iris_hybrid_cardinality is set and
    greater than one.

    Args:
        scan: Scan object providing ``sweeps`` and ``task_name``.

    Returns:
        A HybridInfo when the deciding sweep is a hybrid part and all three
        fields are available, otherwise None (also for a scan with no sweeps).
    """
    name, minor_id, count = None, None, None

    for sweep in scan.sweeps:
        sweep_conf = sweep.config
        candidate = sweep_conf.iris_hybrid_cardinality
        is_hybrid_part = candidate is not None and candidate > 1

        if not is_hybrid_part:
            # A non-hybrid sweep wipes whatever an earlier sweep set.
            name, minor_id, count = scan.task_name, None, None
            continue

        count = candidate
        minor_id = sweep_conf.iris_sweep_task_id[-1]
        # The last two characters of the task name are dropped for hybrid
        # parts.
        name = scan.task_name[:-2]

    if None in (count, name, minor_id):
        return None

    return HybridInfo(task_name=name,
                      minor_task_id=minor_id,
                      cardinality=count)


def extract_sweep_metadata(sweep):
    """Summarises one sweep as a plain dictionary.

    The per-ray configuration values (PRF, radio frequency, unambiguous
    velocity and unambiguous range) are expected to be the same on every ray.
    The first non-None value seen is kept; a missing or differing value on a
    later ray is only logged as a warning. The range is the largest
    ``start + count * step`` over the range masks of all moments of all rays.

    Args:
        sweep: Sweep object providing ``config`` and ``rays``.

    Returns:
        Dictionary with the keys 'type', 'fixedAngle', 'range',
        'pulseRepetitionFrequency', 'radioFrequency', 'unambiguousVelocity'
        and 'unambiguousRange'.

    Raises:
        AssertionError: If the sweep has no moment to calculate a range from.
    """
    logger = lg('extract_sweep_metadata')

    def merge(known, seen, label):
        # Keep the first usable value and only complain about the rest.
        if seen is None:
            logger.warn("Value candidate with name %s is None" % label)
            return known

        if known is not None and known != seen:
            logger.warn("Differing %s between rays" % label)
            return known

        return seen

    metadata = {
        'type': sweep.config.type,
        'fixedAngle': sweep.config.fixed_angle,
    }

    # (result key, ray config attribute, label used in the log messages)
    per_ray_fields = (
        ('pulseRepetitionFrequency', 'prf', "PRF"),
        ('radioFrequency', 'rf_freq', "radio frequency"),
        ('unambiguousVelocity', 'nyquist_vel', "unambiguous velocity"),
        ('unambiguousRange', 'unamb_range', "unambiguous range"),
    )
    for key, _, _ in per_ray_fields:
        metadata[key] = None

    farthest = -1
    for ray in sweep.rays:
        ray_conf = ray.config
        for key, attribute, label in per_ray_fields:
            metadata[key] = merge(metadata[key], getattr(ray_conf, attribute), label)

        for moment in ray.moments:
            mask = moment.range_mask
            reach = mask.start + mask.count * mask.step
            if reach > farthest:
                farthest = reach

    assert farthest != -1, "No range could be calculated"

    metadata['range'] = farthest
    return metadata


def collect_moments(scan):
    """Lists the distinct moment names found anywhere in a scan.

    Args:
        scan: Scan object whose sweeps, rays and moments are walked.

    Returns:
        Sorted list of moment names with the 'RD_' prefix stripped by
        strip_prefix(); each name appears once.
    """
    names = {
        strip_prefix(moment.name, 'RD_')
        for sweep in scan.sweeps
        for ray in sweep.rays
        for moment in ray.moments
    }
    return sorted(names)


def read_and_insert_one(db_conn, volume_dir, relative_path):
    """Reads a scan from a file and inserts it into the scan database.

    Only the scan metadata is read (in a subprocess). If a row with the same
    relative path already exists, nothing is inserted. The insert is committed
    before returning.

    Args:
        db_conn: Database connection object.
        volume_dir: Path to the volume directory root.
        relative_path: Path of the scan file relative to ``volume_dir``. The
            first three characters of its base name, upper-cased, are stored
            as the radar id.

    Returns:
        The scan object that was read, whether or not a row was inserted.

    Raises:
        Whatever read_scan_in_subprocess(), the metadata extraction or the
        database raises; the cursor is closed in every case.
    """
    logger = lg('read_and_insert_one')
    absolute_path = os.path.join(volume_dir, relative_path)
    scan, scan_reading_took = read_scan_in_subprocess([absolute_path], metadata_only=True)

    three_letter_code = os.path.basename(relative_path)[0:3].upper()

    hybrid_info = extract_hybrid_info(scan)
    if hybrid_info is None:
        task_name = scan.task_name
        minor_task_id = None
        hybrid_cardinality = None
    else:
        task_name = hybrid_info.task_name
        minor_task_id = hybrid_info.minor_task_id
        hybrid_cardinality = hybrid_info.cardinality

    sweep_metadatas = [extract_sweep_metadata(sweep) for sweep in scan.sweeps]
    sweep_metadatas_compressed = compress_simple(sweep_metadatas)

    start_insert = datetime.datetime.now()
    cursor = db_conn.cursor()
    try:
        # On a full database rebuild, we hit cases where the time-triggered
        # update has already inserted the scan.
        existing = cursor.execute("SELECT rel_path FROM scans WHERE rel_path = :rel_path",
                                  {'rel_path': relative_path}).fetchone()
        if existing is not None:
            logger.info(u"Scan {0} already in the database.".format(absolute_path))
            return scan

        cursor.execute(scan_db.INSERT_SCAN_STATEMENT, {
            'id': uuid_pair_to_str(scan.input_file_sha1_hash),
            'task_name': task_name,
            'hybrid_minor_task_id': minor_task_id,
            'hybrid_cardinality': hybrid_cardinality,
            'site_name': scan.site_name,
            'radar_id': three_letter_code,
            'latitude': scan.position.latitude,
            'longitude': scan.position.longitude,
            'altitude': scan.position.altitude,
            'timestamp': scan.timestamp,
            'rel_path': relative_path,
            'moments': ",".join(collect_moments(scan)),
            'sweep_metadatas': sqlite3.Binary(sweep_metadatas_compressed),
        })
        db_conn.commit()
    finally:
        # Release the cursor even when a statement fails.
        cursor.close()

    logger.info(
        u"Scan {0} read and inserted in {1} + {2} s.".format(
            absolute_path,
            scan_reading_took,
            (datetime.datetime.now() - start_insert).total_seconds()))
    return scan


def remove_one(db_conn, rel_path):
    """Removes the scan recorded for a relative path from the database.

    If the scan belonged to a hybrid, the hybrid is dissolved as well: all
    scans pointing at it get their hybrid_id cleared and the hybrid row is
    deleted. All of this is committed as one unit; on failure the changes are
    rolled back and the exception is re-raised.

    Args:
        db_conn: Database connection object.
        rel_path: Path of the scan file relative to the volume directory, as
            stored in the ``rel_path`` column.
    """
    logger = lg('remove_one')
    cursor = db_conn.cursor()
    try:
        row = cursor.execute("SELECT id, hybrid_id FROM scans WHERE rel_path = :rel_path",
                             {'rel_path': rel_path}).fetchone()
        if row is None:
            logger.warn(u"Couldn't find a scan for relative path {0}.".format(rel_path))
            return

        logger.info(u"Removing scan by id '{0}'...".format(row.id))
        try:
            cursor.execute("DELETE FROM scans WHERE id = :id", {'id': row.id})
            # execute() hands back the cursor itself and rowcount always
            # describes the latest statement, so remember the count now,
            # before the hybrid clean-up statements overwrite it.
            removed_count = cursor.rowcount
            if row.hybrid_id is not None:
                cursor.execute("UPDATE scans SET hybrid_id = NULL WHERE hybrid_id = :hybrid_id",
                               {'hybrid_id': row.hybrid_id})
                cursor.execute("DELETE FROM hybrids WHERE id = :id", {'id': row.hybrid_id})
            db_conn.commit()
        except Exception:
            # Don't leave a half-done removal pending for a later commit.
            db_conn.rollback()
            raise
    finally:
        cursor.close()

    if removed_count > 0:
        logger.info(u"Removed scan '{0}'.".format(rel_path))
    else:
        logger.warn(u"Couldn't remove scan {0}: affected row count {1}".format(
            rel_path, removed_count))


def remove_scans(db_conn, relative_paths):
    """Removes several scans from the database, one by one.

    A failure to remove one scan is logged as a warning and does not stop the
    remaining ones from being removed.

    Args:
        db_conn: Database connection object.
        relative_paths: Iterable of scan paths relative to the volume
            directory. May be empty.
    """
    for rel_path in relative_paths:
        try:
            remove_one(db_conn, rel_path)
        except Exception:
            # KeyboardInterrupt and SystemExit are not Exception subclasses,
            # so they pass through with their original traceback.
            lg('remove_scans').warn(u"Error removing a scan: {0}".format(traceback.format_exc()))


def add_scans(db_conn, volume_dir, relative_paths, scan_budget=100):
    """Adds missing scans to the database.

    Args:
        db_conn: Database connection object.
        volume_dir: Path to the volume directory root. Volumes have relative
            paths recorded and that relative path is then added to this volume
            directory root path to get the full path to a volume.
        relative_paths: List of paths to the volume files to add to the
            database.
        scan_budget: Only add this many scans. Useful for when the system is in
            operational use and one wants to ensure that the newest scans are
            regularly added to the database.

    Returns:
        A tuple of (batch_status, [list_of_scans_added]). The status is
        PARTLY_HANDLED when the budget ran out before all paths were tried,
        ERROR when every path was tried but at least one failed, and
        ALL_HANDLED otherwise.

    """
    if len(relative_paths) == 0:
        return BatchStatus.ALL_HANDLED, []

    logger = lg('add_scans')
    total = len(relative_paths)
    logger.info(
        u"Adding {0} scans...".format(total))

    added = []
    errors = 0
    for rel_path in relative_paths:
        if len(added) == scan_budget:
            logger.info(u"Maximum scans per update run achieved ({0})".format(scan_budget))
            return BatchStatus.PARTLY_HANDLED, added

        try:
            scan = read_and_insert_one(db_conn, volume_dir, rel_path)
        except ReadScanException as rse:
            logger.warn(u"Error reading scan from file {0}: {1}".format(
                os.path.join(volume_dir, rel_path), str(rse)))
            errors += 1
            continue
        except Exception:
            # KeyboardInterrupt and SystemExit are not Exception subclasses,
            # so they pass through with their original traceback.
            logger.warn(
                u"Error reading and inserting scan from file {0}: {1}".format(
                    os.path.join(volume_dir, rel_path), traceback.format_exc()))
            errors += 1
            continue

        added.append(scan)
        # Report progress only when the count actually changes; otherwise a
        # run of failing files would repeat the same progress line.
        if len(added) % 10 == 0:
            logger.info(
                u"Progress @ {0}: {1}/{2}...".format(
                    datetime.datetime.now(), len(added), total))

    logger.info(u"Added {0} of {1} missing scans to database.".format(
        len(added), total))

    if errors > 0:
        return BatchStatus.ERROR, added
    return BatchStatus.ALL_HANDLED, added


def refresh_db(db_conn, volume_dir, full_hybrid_rebuild=False):
    """Runs one update cycle for the scan database.

    On each cycle, find all raw product files on disk and compare the list
    with the database. First remove all the files from the database that are
    not on disk and after that add all the scans that are on disk, but not in
    the database. Finally link hybrid parts together.

    Args:
        db_conn: Database connection object.
        volume_dir: Path to the volume directory root. Volumes have relative
            paths recorded and that relative path is then added to this volume
            directory root path to get the full path to a volume.
        full_hybrid_rebuild: All the hybrids in the database will be
            re-examined for linking. When False, only the time span covered
            by the scans added on this cycle (plus one hour on both sides) is
            examined, and nothing is examined if no scan was added.

    Returns:
        Two-tuple of BatchStatus enum-values that specifies whether all the
            scans that needed to be added were added and whether all the
            combinable hybrids were combined. This is used for launching
            another refresh round when we are rebuilding a database instead of
            sleeping a full refresh period.
    """
    cursor = db_conn.cursor()
    try:
        in_db = set(row.rel_path for row in cursor.execute("SELECT rel_path FROM scans"))
    finally:
        cursor.close()

    all_raw_abs_paths = list(find_raw_files_recursively(volume_dir))
    raw_abs_paths = [path for path in all_raw_abs_paths if is_valid_raw(path)]
    skipped_raws = len(all_raw_abs_paths) - len(raw_abs_paths)

    if skipped_raws > 0:
        lg('refresh_db').info(u"Skipping {} scan(s) with XXX site id.".format(skipped_raws))

    on_disk = set(os.path.relpath(path, volume_dir) for path in raw_abs_paths)

    to_be_removed = in_db - on_disk
    to_be_added = on_disk - in_db

    # Remove first, otherwise we'll end up with duplicates if the operation is
    # actually a reorganization rather than a deletion or an insertion.
    remove_scans(db_conn, to_be_removed)

    # Order the scans so that newest scans are added first
    to_be_added_sorted = sorted(to_be_added,
                                cmp=iris_raw_product_name_comparator,
                                reverse=True)

    batch_size = 100
    add_status, added_scans = add_scans(db_conn, volume_dir,
                                        to_be_added_sorted, scan_budget=batch_size)

    if full_hybrid_rebuild is False:
        if not added_scans:
            # Nothing new arrived, so there is no time span to examine.
            # Indexing the empty list below used to raise IndexError.
            return add_status, BatchStatus.ALL_HANDLED

        margin = datetime.timedelta(hours=1)
        added_scans_sorted = sorted(added_scans, key=attrgetter('timestamp'))
        combine_status = link_hybrids(db_conn,
                                      (added_scans_sorted[0].timestamp - margin,
                                       added_scans_sorted[-1].timestamp + margin),
                                      hybrid_budget=batch_size)
    else:
        combine_status = link_hybrids(db_conn,
                                      (datetime.datetime(year=1970, month=1, day=1),
                                       datetime.datetime.now() + datetime.timedelta(days=1)),
                                      hybrid_budget=99999)
    return add_status, combine_status


def _make_db_connection(db_file):
    """Opens the scan database and prepares the connection for use.

    The connection is handed to scan_db.create_db(), rows are produced with
    scan_db.namedtuple_factory and the busy timeout is set to 15000 ms.

    Args:
        db_file: Scan database file name / path.

    Returns:
        The opened sqlite3 connection. Closing it is up to the caller.
    """
    busy_timeout_ms = 15000

    connection = sqlite3.connect(db_file, detect_types=sqlite3.PARSE_DECLTYPES)
    scan_db.create_db(connection)
    connection.row_factory = scan_db.namedtuple_factory
    # Wait for a locked database instead of failing right away, see
    # http://stackoverflow.com/a/27165929
    connection.execute("PRAGMA busy_timeout = %i" % busy_timeout_ms)
    return connection

def _count_unlinked_hybrid_parts(cursor):
    """Returns the number of hybrid scan parts not linked to a hybrid."""
    return cursor.execute(
        """SELECT COUNT(id) AS unlinked_scan_count FROM scans
           WHERE hybrid_minor_task_id IS NOT NULL AND hybrid_id IS NULL""") \
        .fetchone().unlinked_scan_count


def _log_database_summary(db_file, logger):
    """Opens the database once and logs how many rows it holds."""
    db_conn = _make_db_connection(db_file)
    try:
        with db_conn:
            cursor = db_conn.cursor()
            try:
                scan_count = cursor.execute(
                    "SELECT COUNT(id) AS scan_count FROM scans WHERE hybrid_id IS NULL") \
                    .fetchone().scan_count
                hybrid_count = cursor.execute(
                    "SELECT COUNT(id) AS hybrid_count FROM hybrids") \
                    .fetchone().hybrid_count
                unlinked_scan_count = _count_unlinked_hybrid_parts(cursor)
            finally:
                cursor.close()
    finally:
        db_conn.close()

    logger.info(u"Database ready; {} scans, {} hybrids, {} unlinked hybrid parts.".format(
        scan_count, hybrid_count, unlinked_scan_count))


def _log_unlinked_hybrid_parts(db_file, logger):
    """Logs every hybrid scan part that is still not linked to a hybrid."""
    db_conn = _make_db_connection(db_file)
    try:
        with db_conn:
            cursor = db_conn.cursor()
            try:
                unlinked_scan_count = _count_unlinked_hybrid_parts(cursor)
                if unlinked_scan_count == 0:
                    return

                logger.warn(u"There are {} unlinked hybrid scan parts on disk.".format(
                    unlinked_scan_count))
                unlinked_scans = cursor.execute(
                    """SELECT radar_id, task_name, timestamp, hybrid_minor_task_id, hybrid_cardinality FROM scans
                    WHERE hybrid_minor_task_id IS NOT NULL AND hybrid_id IS NULL ORDER BY timestamp ASC"""
                )
                for row in unlinked_scans:
                    logger.info(
                        u"Hybrid scan part {radar_id} {task_name} {timestamp:%Y-%m-%dT%H:%M:%S} {hybrid_minor_task_id}/{hybrid_cardinality} not linked.".format(**row._asdict()))
            finally:
                cursor.close()
    finally:
        db_conn.close()


def updater_loop(volume_dir, db_file, interval_secs=600, started_callback=None):
    """Runs the main scan-updater-service daemon loop.

    The loop never returns normally. Each round runs refresh_db() with a full
    hybrid rebuild while holding _UPDATE_LOCK. If the round could not handle
    everything, the next round starts immediately; otherwise the remaining
    unlinked hybrid parts are logged and the loop sleeps for the interval.

    Every database connection opened here is closed explicitly, because using
    a sqlite3 connection as a context manager only commits or rolls back.

    Args:
        volume_dir: Path to the volume directory root.
        db_file: Scan database file name / path.
        interval_secs: Interval between regular update runs. Default is 600.
        started_callback: Optional callable without arguments. It is called
            once, after the database has been opened successfully and before
            the first update run."""
    assert volume_dir is not None
    assert db_file is not None

    logger = lg('updater_loop')
    logger.info(u"Opening database in {0}...".format(db_file))
    _log_database_summary(db_file, logger)

    if started_callback is not None:
        started_callback()

    while True:
        with _UPDATE_LOCK:
            start_update = datetime.datetime.now()
            logger.debug("Starting time-triggered update run...")
            db_conn = _make_db_connection(db_file)
            try:
                with db_conn:
                    added, combined = refresh_db(db_conn, volume_dir, full_hybrid_rebuild=True)
            finally:
                db_conn.close()
            logger.debug("Time-triggered update run finished in {0} s.".format(
                (datetime.datetime.now() - start_update).total_seconds()))

        if added == BatchStatus.PARTLY_HANDLED or combined == BatchStatus.PARTLY_HANDLED:
            logger.info(
                u"Launching a new update run as all the scans were not added on the previous one.")
            continue

        _log_unlinked_hybrid_parts(db_file, logger)
        time.sleep(float(interval_secs))


def relative_to_base(base, path):
    """Returns ``path`` expressed relative to the directory ``base``.

    os.path.relpath() makes both arguments absolute and normalizes them
    itself, so no separate normalization is needed here.

    Args:
        base: Directory the result is relative to.
        path: Path to convert.

    Returns:
        The relative path as a string.
    """
    return os.path.relpath(path, base)


class InotifyEventHandler(pyinotify.ProcessEvent):
    """Keeps the scan database in step with inotify file events.

    Files that are created in or moved into the watched tree are read and
    inserted into the database; files that are deleted or moved away are
    removed from it. Paths rejected by is_valid_raw() are ignored. All
    database work happens while holding _UPDATE_LOCK, and a failure is
    logged as a warning instead of being raised.

    NOTE(review): IN_CREATE may arrive before the writer has finished the
    file; whether incomplete files can be seen here is not visible from this
    module - confirm against how files are delivered.
    """

    def __init__(self, db_file, volume_dir, *args, **kwargs):
        """Stores the database file and the volume directory root.

        Args:
            db_file: Scan database file name / path.
            volume_dir: Path to the volume directory root.
            *args: Passed on to pyinotify.ProcessEvent.
            **kwargs: Passed on to pyinotify.ProcessEvent.
        """
        # Explicit base call instead of super(): works whether or not the
        # pyinotify base is a new-style class.
        pyinotify.ProcessEvent.__init__(self, *args, **kwargs)
        self.db_file = db_file
        self.volume_dir = volume_dir
        self.logger = lg(self.__class__.__name__)

    def handle_incoming(self, event):
        """Inserts the file of the event and links hybrids around its time."""
        with _UPDATE_LOCK:
            try:
                db_conn = _make_db_connection(self.db_file)
                try:
                    # The with-block only commits or rolls back; the
                    # connection itself is closed in the finally clause.
                    with db_conn:
                        scan = read_and_insert_one(
                            db_conn, self.volume_dir,
                            relative_to_base(self.volume_dir, event.pathname))
                        margin = datetime.timedelta(hours=1)
                        link_hybrids(db_conn,
                                     (scan.timestamp - margin,
                                      scan.timestamp + margin),
                                     hybrid_budget=20)
                finally:
                    db_conn.close()
            except Exception:
                self.logger.warn(
                    u"Error adding scan '{0}' or combining hybrids: {1}".format(
                        event.pathname, traceback.format_exc()))

    def handle_outgoing(self, event):
        """Removes the file of the event from the database."""
        with _UPDATE_LOCK:
            try:
                db_conn = _make_db_connection(self.db_file)
                try:
                    with db_conn:
                        remove_one(db_conn,
                                   relative_to_base(self.volume_dir, event.pathname))
                finally:
                    db_conn.close()
            except Exception:
                self.logger.warn(
                    u"Error removing scan '{0}': {1}".format(
                        event.pathname, traceback.format_exc()))

    def process_IN_CREATE(self, event):
        """Handles a file created in the watched tree."""
        self.logger.debug("create of {0}".format(event.pathname))
        if is_valid_raw(event.pathname):
            self.handle_incoming(event)

    def process_IN_DELETE(self, event):
        """Handles a file deleted from the watched tree."""
        self.logger.debug("delete of {0}".format(event.pathname))
        if is_valid_raw(event.pathname):
            self.handle_outgoing(event)

    def process_IN_MOVED_TO(self, event):
        """Handles a file moved into the watched tree."""
        self.logger.debug("moved to of {0}".format(event.pathname))
        if is_valid_raw(event.pathname):
            self.handle_incoming(event)

    def process_IN_MOVED_FROM(self, event):
        """Handles a file moved out of the watched tree."""
        self.logger.debug("moved from of {0}".format(event.pathname))
        if is_valid_raw(event.pathname):
            self.handle_outgoing(event)


def main(args):
    """Sets up logging and inotify watching and runs the updater forever.

    The updater loop is restarted after a 10 second pause whenever it fails
    with an ordinary exception. KeyboardInterrupt and SystemExit end the
    service; the inotify notifier thread is stopped on the way out.

    Args:
        args: Parsed command line arguments with the attributes ``systemd``,
            ``debug``, ``volume_dir``, ``db_file`` and ``update_interval``.
    """
    # The systemd bindings are imported optionally at module level, so the
    # names may be missing altogether or be None; treat both as "not there".
    systemd_unavailable = (globals().get('journal') is None or
                           globals().get('daemon') is None)
    if args.systemd is True and systemd_unavailable:
        sys.exit("Couldn't import systemd Python module!")

    root_logger = logging.getLogger('')
    if args.debug is True:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)

    if args.systemd is True:
        log_handler = journal.JournalHandler(SYSLOG_IDENTIFIER='scan-updater-service')
    else:
        log_handler = logging.StreamHandler()
        log_handler.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))

    root_logger.addHandler(log_handler)

    logger = lg('scan-updater-service')
    logger.info(u"Started with command line {}".format(repr(sys.argv)))
    if args.debug:
        logger.info(u"Debug logging enabled.")

    if set_proc_name('scan-updater-service') is False:
        logger.warn("Couldn't set process name!")

    volume_dir_abs_path = os.path.abspath(os.path.normpath(args.volume_dir))
    if os.path.exists(volume_dir_abs_path) is False:
        message = u"Volume dir at '{}' doesn't exist!".format(volume_dir_abs_path)
        logger.error(message)
        print(message, file=sys.stderr)
        if args.systemd is True:
            daemon.notify(u"ERRNO={}".format(posix.EX_NOINPUT))
        sys.exit(posix.EX_NOINPUT) # EX_NOINPUT is 66, see sysexits.h for details

    # See https://github.com/seb-m/pyinotify/wiki/Tutorial for tutorial.
    watch_manager = pyinotify.WatchManager()
    inotify_mask = pyinotify.IN_DELETE | pyinotify.IN_CREATE | \
                   pyinotify.IN_MOVED_FROM | pyinotify.IN_MOVED_TO

    notifier = pyinotify.ThreadedNotifier(
        watch_manager,
        InotifyEventHandler(args.db_file, volume_dir_abs_path))
    # NOTE(review): rec=True watches the directories that exist right now;
    # whether directories created later need watching too (auto_add) should
    # be confirmed against how the volume tree grows.
    watch_manager.add_watch(volume_dir_abs_path, inotify_mask, rec=True)
    notifier.start()

    logger.info(u"Time-triggered update interval is {0} s.".format(args.update_interval))

    def started_callback():
        if args.systemd is True:
            daemon.notify("READY=1")

    try:
        while True:
            try:
                updater_loop(volume_dir_abs_path, args.db_file,
                             interval_secs=args.update_interval,
                             started_callback=started_callback)
            except Exception:
                # KeyboardInterrupt and SystemExit are not Exception
                # subclasses, so they end the service instead of being
                # retried.
                logger.error(
                    u"Encountered exception: {0}".format(traceback.format_exc()))
                logger.info("Sleeping for 10 seconds before restarting updater")
                time.sleep(10)
    finally:
        # Also reached when the interrupt arrives during the sleep above.
        notifier.stop()


# Script entry point: parse the command line and start the service.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--db-file', dest='db_file', required=True,
                        help='scan database file path')
    parser.add_argument('--volume-dir', dest='volume_dir', required=True,
                        help='directory root to recursively include in the scan database')
    # type=float makes argparse reject a non-numeric interval right away.
    # Without it the bad value only failed after the first update run and
    # was then retried endlessly.
    parser.add_argument('--update-interval', dest='update_interval',
                        type=float, default=600, required=False,
                        help='interval between timed full update runs in seconds')
    parser.add_argument('--systemd', dest='systemd', action='store_true', default=False)
    parser.add_argument('-d', '--debug', dest='debug', action='store_true', default=False)
    args = parser.parse_args()
    main(args)
