import os
import re
import json

import app
import PythonWidgets
import QueryInterfaceWidgets
import requests
import logging

from Mcule_tools import (
    Search,
    CompoundDetails,
    SEARCH_TYPE_EXACT,
    SEARCH_TYPE_SIM,
    SEARCH_TYPE_SSS,
    COLLECTION_FULL,
    COLLECTION_IN_STOCK,
)

# --- Plugin identity -------------------------------------------------------
# Formats the number 1 with one decimal place, i.e. the string '1.0'.
VERSION = f'{1:.1f}'

# --- Mcule endpoints -------------------------------------------------------
BASE_URL = 'https://mcule.com'
API_BASE_URL = f'{BASE_URL}/api/v1/'

# --- Settings persistence --------------------------------------------------
# Both paths are relative, so they resolve against the current working
# directory of the host process.
MCULE_SETTINGS_PATH = 'mcule_settings.json'
# Legacy file read once by _handle_backwards_compatibility to recover the
# API token, then removed.
MCULE_AUTH_PATH_DEPRECATED = 'mcule_auth.txt'

# Keys used inside the settings JSON document.
MCULE_SETTING_API_TOKEN = 'API_TOKEN'
MCULE_SETTING_RATE_LIMIT_MAX_WAIT_TIME = 'RATE_LIMIT_MAX_WAIT_TIME_SECONDS'

# Values written when a settings file is created from a legacy auth file.
MCULE_SETTINGS_DEFAULT = {MCULE_SETTING_RATE_LIMIT_MAX_WAIT_TIME: 180}


# Display labels of the workflow parameters. The same strings are used as the
# keys of the dict returned by _get_workflow_params.
PARAM_SEARCH_TYPE = 'Search type'
PARAM_COLLECTION = 'Search collection'
PARAM_SIM_THRESHOLD = 'Similarity threshold'
PARAM_LIMIT = 'Hit limit'
PARAM_AMOUNT = 'Price amount (mg)'

# Maps each parameter label to the list of choices offered for it.
# The insertion order is the order in which the options are handed to the
# options dialog.
FORM_OPTIONS_DEFAULT = {
    PARAM_SEARCH_TYPE: [
        'exact',
        'similarity',
        'substructure',
    ],
    PARAM_COLLECTION: [
        COLLECTION_FULL,
        COLLECTION_IN_STOCK,
    ],
    PARAM_SIM_THRESHOLD: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1],
    PARAM_LIMIT: [5, 10, 100, 1000],
    # None means "no price lookup" (see the hint text in _get_workflow_params).
    PARAM_AMOUNT: [None, 1, 5, 10, 50, 100, 1000],
}

# Templates for a log line and for its timestamp. The named placeholders are
# filled below with logging's %-style record fields and strftime directives.
LOG_FMT = '{date_and_time:s} - {logging_level:s} - {message:s}'
LOG_DATE_AND_TIME_FMT = '{weekday:s}, {month:s} {day:s}, {year:s} at {hour:s}:{minute:s}:{second:s} {period:s}'
Mcule_logger = logging.getLogger(__name__)

# Configure the logger only when it has no handler yet, so running this code
# again for the same logger does not attach a second file handler.
if not Mcule_logger.handlers:

    Mcule_logger.setLevel(logging.DEBUG)

    Mcule_logfile_formatter = logging.Formatter(
        fmt=LOG_FMT.format(
            date_and_time='%(asctime)s',
            logging_level='%(levelname)s',
            message='%(message)s',
        ),
        datefmt=LOG_DATE_AND_TIME_FMT.format(
            weekday='%A',
            month='%B',
            day='%d',
            year='%Y',
            hour='%I',
            minute='%M',
            second='%S',
            period='%p',
        ),
    )

    # The log file lives in the current working directory.
    Mcule_logfile_handler = logging.FileHandler(os.path.join(os.getcwd(), 'Mcule.log'))
    Mcule_logfile_handler.setFormatter(Mcule_logfile_formatter)
    Mcule_logfile_handler.setLevel(logging.DEBUG)

    Mcule_logger.addHandler(Mcule_logfile_handler)


def _prompt_api_token(appui):
    text = (
        'To obtain a token please create an account on mcule.com and reach\n'
        'out to Mcule staff to obtain a token at support@mcule.com.'
    )
    token = appui.query('Provide an API token for authentication', text, '')
    return token


def _prompt_rate_limit_max_wait_time(appui):
    text = (
        'Set the maximum time you are willing to wait (in seconds) between API\n'
        'requests in case a HTTP 429 (too many requests) response is received.\n'
        'If the wait time is smaller than this limit, the request is retries after\n'
        'the necessary time (based on the Retry-After response header) has passed.'
    )
    wait_time = appui.query('Provide rate limit max wait time (seconds)', text, '')
    return wait_time


def _load_settings():
    settings = {}
    if os.path.isfile(MCULE_SETTINGS_PATH):
        settings = json.load(open(MCULE_SETTINGS_PATH))
    return settings

def _save_settings(settings):
    json.dump(settings, open(MCULE_SETTINGS_PATH, 'w'))


def _get_setting(key):
    """Return the stored value for *key*, or None when it is not set.

    The settings are re-read from disk on every call.
    """
    stored = _load_settings()
    return stored.get(key)


def _save_setting(key, value=None):
    """Store *value* under *key* in the settings file.

    Entries whose value is None are dropped before saving, so calling this
    with ``value=None`` removes *key* from the stored settings.
    """
    current = _load_settings()
    current[key] = value
    _save_settings({
        name: stored
        for name, stored in current.items()
        if stored is not None
    })


def _handle_backwards_compatibility():
    """Migrate a legacy auth file into the JSON settings file.

    Does nothing when a settings file already exists or when there is no
    legacy auth file. Otherwise the API token is read from
    ``MCULE_AUTH_PATH_DEPRECATED``; if it is non-empty, a settings file is
    written containing the defaults plus the token. The legacy file is
    removed afterwards.
    """
    # Skip if there is already a settings file
    if os.path.isfile(MCULE_SETTINGS_PATH):
        return
    # Nothing to migrate without an old auth file
    if not os.path.isfile(MCULE_AUTH_PATH_DEPRECATED):
        return

    with open(MCULE_AUTH_PATH_DEPRECATED) as auth_file:
        # Strip surrounding whitespace: a trailing newline left by an editor
        # is not part of the token.
        api_token = auth_file.read().strip()

    if api_token:
        # Work on a copy so the module-level defaults are never mutated.
        migrated_settings = dict(MCULE_SETTINGS_DEFAULT)
        migrated_settings[MCULE_SETTING_API_TOKEN] = api_token
        _save_settings(migrated_settings)

    # Remove the legacy file only after the token has been persisted, so a
    # failed save cannot lose the token.
    os.remove(MCULE_AUTH_PATH_DEPRECATED)


def _get_workflow_params(options):
    """Show the options dialog and return the user's choices.

    Parameters
    ----------
    options: dict
        Maps a parameter label to the list of choices offered for it.

    Returns
    -------
    dict or None
        Maps each parameter label to the selected value, or None when the
        dialog's ``run()`` returns a falsy value. Selected strings are
        converted back to None / int / float where possible, and the search
        type label is translated to its ``SEARCH_TYPE_*`` constant.
    """
    selector = PythonWidgets.ComboBoxSelector()
    selector.setWindowTitle('Choose options')

    # Hints are shown only for the parameters that are actually offered.
    hints = []
    if PARAM_LIMIT in options:
        hints.append('"Hit limit" is only applied to similarity and substructure searches.')
    if PARAM_SIM_THRESHOLD in options:
        hints.append('"Similarity threshold" is only applied to similarity search.')
    if PARAM_AMOUNT in options and None in options[PARAM_AMOUNT]:
        hints.append('If "Price amount (mg)" is None, only a search is carried out.')
    selector.setText('\n'.join(hints).strip())

    # The first option group goes through setOptions, later ones through
    # addOptions. Choices are handed over as strings.
    for position, (label, choices) in enumerate(options.items()):
        labels = [choice if isinstance(choice, str) else str(choice) for choice in choices]
        register = selector.setOptions if position == 0 else selector.addOptions
        register(label, labels)

    def type_convert(value):
        """Undo the str() conversion: 'None' -> None, then int, then float."""
        if value == 'None':
            return None
        for caster in (int, float):
            try:
                return caster(value)
            except ValueError:
                continue
        return value

    if not selector.run():
        return None

    search_type_by_label = {
        'exact': SEARCH_TYPE_EXACT,
        'similarity': SEARCH_TYPE_SIM,
        'substructure': SEARCH_TYPE_SSS
    }
    selected = {}
    for position, param in enumerate(options):
        value = type_convert(selector.getSelectedOption(position))
        if param == PARAM_SEARCH_TYPE:
            value = search_type_by_label[value]
        selected[param] = value
    return selected


def _create_dataset_from_api_results(appui, columns, results):
    """Convert API result rows into a dataset and export it as 'Mcule results'.

    Parameters
    ----------
    appui: appui
        Object whose ``exportDS`` receives the packed dataset.
    columns: iterable of str
        Column names in output order. Names that are not listed in one of
        the column groups below are ignored.
    results: list of dict
        One mapping per row; cell values are looked up with
        ``result.get(column)``.

    Notes
    -----
    A missing cell becomes an invalid entry of the column's type. For
    molecule, string and URL columns "missing" means any falsy value; for
    number columns only ``None`` and the empty string count as missing, so
    a genuine ``0`` / ``0.0`` is kept as a number.
    """
    str_cols = {
        'hit_availability_type',
        'hit_confirmed_amount_unit',
        'hit_currency_1',
        'hit_matching_level',
        'hit_mcule_id',
        'hit_unit_1',
    }
    url_cols = {
        'hit_url',
    }
    num_cols = {
        'hit_confirmed_amount',
        'hit_amount_1',
        'hit_price_1',
        'hit_delivery_time_wdays_1',
        'hit_purity%_1',
        'hit_tanimoto',
    }
    mol_cols = {
        'query',
        'hit_smiles'
    }

    ds_factory = app.Dataset

    def number_is_present(value):
        # Zero is a valid number; a plain truthiness test would discard it.
        return value is not None and value != ''

    def build_entries(column, make_entry, entry_type, is_present=bool):
        """Return one entry per result row for *column*."""
        entries = []
        for result in results:
            value = result.get(column)
            if is_present(value):
                entries.append(make_entry(value))
            else:
                entries.append(ds_factory.createInvalidEntry(entry_type))
        return entries

    # transform result data into dataset representation
    dataset_columns = {}
    for column in columns:
        # molecule
        if column in mol_cols:
            header = ds_factory.meta(column, app.mol)
            entries = build_entries(column, ds_factory.molecule, app.mol)
        # number
        elif column in num_cols:
            header = ds_factory.createNumberHeader(column)
            entries = build_entries(
                column,
                lambda value: ds_factory.cont(value, 0.0),
                app.cont,
                is_present=number_is_present,
            )
        # string
        elif column in str_cols:
            header = ds_factory.meta(column, app.string)
            entries = build_entries(column, ds_factory.string, app.string)
        # URL, stored as an HTML anchor in a string column
        elif column in url_cols:
            header = ds_factory.meta(column, app.string)
            entries = build_entries(
                column,
                lambda value: ds_factory.string(f'<a href="{value}">{value}</a>'),
                app.string,
            )
        # ignored columns
        else:
            continue

        dataset_columns[column] = (header, entries)

    ds = ds_factory([header for header, _ in dataset_columns.values()])
    for idx, (_, entries) in enumerate(dataset_columns.values()):
        ds.setColumn(idx, entries)
    appui.exportDS('Mcule results', ds.pack())


def _display_error(msg):
    """Log *msg* at CRITICAL level and show it in a warning dialog."""
    Mcule_logger.critical(msg)
    dialog = PythonWidgets.Warning()
    dialog.setMessage(msg)
    dialog.run()


def _display_text(title, label, text):
    """Show *text* in a text-display dialog with the given title and label."""
    viewer = PythonWidgets.TextDisplay()
    for apply_setting, content in (
        (viewer.setWindowTitle, title),
        (viewer.setLabel, label),
        (viewer.setText, text),
    ):
        apply_setting(content)
    viewer.exec()


class ProgressTracker:
    """Progress bar that also tallies failed requests.

    ``update`` is handed to the API helpers as their progress callback.
    It advances the bar and records every non-OK status code together with
    the number of entries affected.
    """

    def __init__(self, appui, title, message, maximum):
        """Display progress bar and collect errors

        Parameters
        ----------
        appui: appui
            The Stardrop appui object.
        title: str
            Progress bar title.
        message: str
            Progress bar message.
        maximum: int
            The number of entries that are queried.
        """
        self.appui = appui
        self.bar = PythonWidgets.SimpleProgress()
        self.bar.modal = True
        self.bar.setWindowTitle(title)
        self.bar.setMessage(message)
        self.maximum = maximum
        # The bar itself always runs from 0 to 100 (percent).
        self.bar.setMaximum(100)
        self.progress = 0
        self.bar.setProgress(self.progress)
        self.bar.show()
        self.appui.updateUI()

        self.error_status_codes = set()
        self.error_failed_entry_count = 0

    def delete(self):
        """Hide the progress bar."""
        self.bar.hide()

    def update(self, status_code, chunk_size, ignore_errors=False):
        """Record the outcome of one request and advance the bar.

        Parameters
        ----------
        status_code: int
            HTTP status code of the request.
        chunk_size: int
            Number of entries covered by the request.
        ignore_errors: bool
            When True, a non-OK status code is not recorded as a failure.
        """
        if not ignore_errors and (status_code != requests.codes.ok):
            self.error_status_codes.add(status_code)
            self.error_failed_entry_count += chunk_size
        # cap progress at max
        self.progress = min(self.progress + chunk_size, self.maximum)
        if self.maximum > 0:
            progress_percent_display = int((self.progress / self.maximum) * 100)
        else:
            # Nothing to process: show the bar as complete instead of
            # dividing by zero.
            progress_percent_display = 100
        self.bar.setProgress(progress_percent_display)
        self.appui.updateUI()


def _get_valid_mcule_ids(mcule_ids):
    valid_ids = {
        mcule_id for mcule_id in mcule_ids
        if re.match(r'^MCULE-[\d]+$', mcule_id)  # Regex checks to make sure Mcule ID fits format (e.g., MCULE-4899719484 for Benzene) # noqa: E501
    }
    error_ids = set(mcule_ids) - set(valid_ids)
    if error_ids:
        _display_error(f'Skipping {len(error_ids)} invalid Mcule IDs')
    if not valid_ids:
        _display_error('No valid MculeIDs')
    return list(valid_ids)


def _run_workflow(appui, mode):
    """Run one Mcule workflow: collect input, query the API, export results.

    Parameters
    ----------
    appui: appui
        The Stardrop appui object.
    mode: str
        ``'selection'`` searches with the currently selected molecules,
        ``'draw'`` searches with a sketched structure, and ``'details'``
        skips the search and fetches price details for a list of Mcule IDs.

    Raises
    ------
    ValueError
        If *mode* is not one of the three supported values.

    The function returns early (after showing an error where appropriate)
    when no API token, no input or no options are provided. Failed requests
    are summarised in a single error dialog at the end.
    """
    _handle_backwards_compatibility()
    workflow_param_options = FORM_OPTIONS_DEFAULT
    perform_search = True

    # get the API token from settings or prompt the user for one
    api_token = _get_setting(MCULE_SETTING_API_TOKEN)
    if not api_token:
        api_token = _prompt_api_token(appui)
        if not api_token:
            _display_error('No token provided')
            return
        _save_setting(MCULE_SETTING_API_TOKEN, api_token)
        appui.show(f'{MCULE_SETTING_API_TOKEN} updated: {api_token}')

    # get the search limit max wait time
    max_wait_time = _get_setting(MCULE_SETTING_RATE_LIMIT_MAX_WAIT_TIME)
    max_wait_time = int(max_wait_time) if max_wait_time is not None else None

    # get the input
    if mode == 'selection':
        queries = appui.getCurrentMolecules()
        if not queries:
            _display_error('No entries selected')
            return
    elif mode == 'draw':
        # request input structure
        dlg = PythonWidgets.MoleculeSketcher()
        if dlg.run():
            query = dlg.getMoleculeAsSMILES()
            queries = [query]
        else:
            return
        if not query:
            _display_error('Empty structure')
            return
    elif mode == 'details':
        perform_search = False
        # A price amount is mandatory here, so the None choice is not offered.
        workflow_param_options = {
            PARAM_AMOUNT: [opt for opt in FORM_OPTIONS_DEFAULT[PARAM_AMOUNT] if opt is not None]
            }

        dlg = QueryInterfaceWidgets.IDQueryDialog()
        if dlg.run():
            queries = dlg.selectedQuery()
        else:
            return
        if not queries:
            _display_error('No Mcule IDs specified')
            return
        # check if valid MCULE ID
        queries = _get_valid_mcule_ids(queries)
        if not queries:
            return
    else:
        # Fail early and clearly instead of hitting an unbound `queries` later.
        raise ValueError(f'Unknown workflow mode: {mode!r}')

    # request search params
    workflow_params = _get_workflow_params(workflow_param_options)
    if not workflow_params:
        return

    price_amount = workflow_params[PARAM_AMOUNT]
    price_amounts = [price_amount] if price_amount else None
    # (step name, status codes seen, number of failed entries) per API step
    errors = []

    if perform_search:
        num_queries = len(queries)
        progress_bar = ProgressTracker(
            appui=appui,
            title='Mcule Search',
            message='Searching using the Mcule API',
            maximum=num_queries,
        )
        try:
            search = Search(
                base_url=API_BASE_URL,
                api_token=api_token,
                queries=queries,
                collection=workflow_params[PARAM_COLLECTION],
                search_type=workflow_params[PARAM_SEARCH_TYPE],
                sim_threshold=workflow_params[PARAM_SIM_THRESHOLD],
                hit_limit=workflow_params[PARAM_LIMIT],
                progress_logger=progress_bar.update,
                general_logger=Mcule_logger.debug,
                rate_limit_max_wait_time=max_wait_time,
            )
            headers, results = search.execute()
        finally:
            # Always hide the modal bar, even when the request raises.
            progress_bar.delete()
        errors.append(('Search', progress_bar.error_status_codes, progress_bar.error_failed_entry_count))
    else:
        results = queries
        headers = None

    if price_amounts:
        num_queries = len(results)
        progress_bar = ProgressTracker(
            appui=appui,
            title='Mcule CompoundDetails',
            message='Fetching compound details using the Mcule API',
            maximum=num_queries,
        )
        try:
            get_details = CompoundDetails(
                base_url=API_BASE_URL,
                api_token=api_token,
                input_data=results,
                amounts=price_amounts,
                csv_headers=headers,
                progress_logger=progress_bar.update,
                general_logger=Mcule_logger.debug,
                rate_limit_max_wait_time=max_wait_time,
            )
            headers, results = get_details.execute()
        finally:
            # Always hide the modal bar, even when the request raises.
            progress_bar.delete()
        errors.append(('CompoundDetails', progress_bar.error_status_codes, progress_bar.error_failed_entry_count))

    if results:
        _create_dataset_from_api_results(appui, headers, results)
    else:
        warning_dlg = PythonWidgets.Warning()
        warning_dlg.setMessage('Could not find any molecules')
        warning_dlg.run()

    error_msgs = []
    for place, status_codes, failed_entry_count in errors:
        # An authorization failure replaces all other messages.
        if requests.codes.unauthorized in status_codes:
            error_msgs = ['Unauthorized, please provide a valid API token.']
            break

        if not failed_entry_count:
            continue

        status_codes = ', '.join(str(sc) for sc in status_codes)
        error_msgs.append(
            f'{place} request failed for {failed_entry_count} entries. Status code(s): {status_codes}'
        )

    if error_msgs:
        _display_error('\n'.join(error_msgs))


def structure_query(appui):
    """Search Mcule with a structure drawn by the user."""
    start_message = f'Running Structure Query for Mcule Plugin v.{VERSION:s}'
    Mcule_logger.info(start_message)
    _run_workflow(appui, 'draw')


def selection_query(appui):
    """Search Mcule with the currently selected molecules."""
    start_message = f'Running Selection Query for Mcule Plugin v.{VERSION:s}'
    Mcule_logger.info(start_message)
    _run_workflow(appui, 'selection')


def id_query(appui):
    """Fetch price details for a list of Mcule IDs."""
    start_message = f'Running ID Query for Mcule Plugin v.{VERSION:s}'
    Mcule_logger.info(start_message)
    _run_workflow(appui, 'details')


def request_quote(appui):
    """Show a mcule.com quote URL built from a column of Mcule IDs.

    The user picks the column holding the Mcule IDs; the valid IDs found in
    it are joined with commas into a URL that can be copied into a browser.
    Nothing is shown when no column is picked or no valid ID is found.
    """
    Mcule_logger.info(f'Requesting quote for Mcule Plugin v.{VERSION:s}')

    dataset = app.Dataset()
    dataset.parse(appui.importDS(''))

    col_names = {
        dataset.columnName(idx): 0
        for idx in range(dataset.columnSz())
    }

    picker = PythonWidgets.ColumnPicker()
    picker.setColumns(col_names)
    picker.setWindowTitle('Select the Mcule ID column')

    # Guard clauses: bail out when the picker is cancelled or returns nothing.
    if not picker.run():
        return
    col_name = picker.getColumn()
    if not col_name:
        return

    col_idx = dataset.columnIndex(col_name)
    column_values = {
        dataset.stringval(col_idx, row_idx)
        for row_idx in range(dataset.rowSz())
    }
    valid_ids = _get_valid_mcule_ids(column_values)
    if not valid_ids:
        return

    queries = ','.join(valid_ids)
    _display_text(
        'Generate a quote on mcule.com',
        'To generate a quote copy the following URL into your browser:',
        f'{BASE_URL}/search/multi/?queries={queries}'
    )


def update_api_token(appui):
    """Prompt for a new API token and store it in the settings.

    An empty answer leaves the stored token untouched.
    """
    Mcule_logger.info(f'Updating API Token for Mcule Plugin v.{VERSION:s}')
    token = _prompt_api_token(appui)
    if token:
        _save_setting(MCULE_SETTING_API_TOKEN, token)
        appui.show(f'{MCULE_SETTING_API_TOKEN} updated: {token}')
    else:
        appui.show('No updated token provided, old token left unchanged.')


def update_rate_limit_max_wait_time(appui):
    """Prompt for the rate limit max wait time and store it in the settings.

    The answer must be a non-negative whole number of seconds. An empty
    answer or an invalid one leaves the stored value untouched and shows a
    matching message. The value is stored as an int, like the default in
    ``MCULE_SETTINGS_DEFAULT``.
    """
    Mcule_logger.info(f'Updating rate limit max wait time v.{VERSION:s}')
    answer = _prompt_rate_limit_max_wait_time(appui)
    # Treat None, '' and whitespace-only answers alike as "nothing provided".
    wait_time = '' if answer is None else str(answer).strip()
    if not wait_time:
        appui.show('No wait time provided, old value left unchanged.')
        return
    # isdecimal() accepts exactly the digit strings int() can parse;
    # isnumeric() would also let through characters such as '²' or '½',
    # which make the later int() conversion fail.
    if not wait_time.isdecimal():
        appui.show('Invalid wait time, old value left unchanged.')
        return
    wait_time = int(wait_time)
    _save_setting(MCULE_SETTING_RATE_LIMIT_MAX_WAIT_TIME, wait_time)
    appui.show(f'{MCULE_SETTING_RATE_LIMIT_MAX_WAIT_TIME} updated: {wait_time}')


def show_settings(appui):
    """Display the stored settings, one ``key: value`` pair per line.

    Shows a short notice instead when no settings are stored.
    """
    settings = _load_settings()
    if settings:
        lines = [f'{k}: {v}' for k, v in settings.items()]
        _display_text('Show settings', '', '<br>'.join(lines))
        return
    appui.show('There are no settings yet.')


def get_stardrop_definitions():
    """Return the script entries exposed by this module.

    Returns
    -------
    list of dict
        One dict per entry, with the keys ``'script_name'`` (menu path) and
        ``'callback'`` (function taking the appui object).
    """
    menu_entries = (
        ('Mcule/Query from selection', selection_query),
        ('Mcule/Draw molecule', structure_query),
        ('Mcule/Price details for Mcule ID list', id_query),
        ('Mcule/Request quote for Mcule ID column', request_quote),
        ('Mcule/Settings/Update API token', update_api_token),
        ('Mcule/Settings/Set rate limit max wait time', update_rate_limit_max_wait_time),
        ('Mcule/Settings/Show current settings', show_settings),
    )
    return [
        {'script_name': script_name, 'callback': callback}
        for script_name, callback in menu_entries
    ]
