from time import sleep

import requests


# Compound collections that can be searched.
COLLECTION_FULL = 'full'
COLLECTION_IN_STOCK = 'in_stock'
# All valid values for the `collection` search parameter.
COLLECTION_TYPES = [
    COLLECTION_FULL,
    COLLECTION_IN_STOCK,
]

# Used as the default similarity threshold for similarity searches.
SIM_THRESHOLD_MIN = 0.7
# Maximum hit limit for sim/sss searches (not referenced in this module;
# presumably enforced by callers or the API — TODO confirm).
HIT_LIMIT_MAX = 1000


# Max number of queries batched into a single exact-search API call.
EXACT_SEARCH_CHUNK_SIZE = 1000
# Max number of MculeIDs allowed in a single pricing API call.
PRICES_CHUNK_SIZE_MAX = 100
# Default cap (in seconds) on how long to wait before retrying a
# rate-limited (HTTP 429) request.
RATE_LIMIT_MAX_WAIT_TIME_DEFAULT = 180


# Supported search types.
SEARCH_TYPE_EXACT = 'exact'
SEARCH_TYPE_SIM = 'sim'
SEARCH_TYPE_SSS = 'sss'



class SearchBase:
    """Shared plumbing for Mcule API clients.

    Provides injectable progress/debug loggers and a POST helper that
    transparently retries once after an HTTP 429 rate-limit response.
    """

    def __init__(self, progress_logger=None, general_logger=None):
        # Caller-supplied callables replace the no-op logger methods below,
        # so external code can observe API progress and debug output.
        if general_logger:
            self.general_logger = general_logger

        if progress_logger:
            self.progress_logger = progress_logger

    def progress_logger(self, *args, **kwargs):
        """Function that is called after each API call to track progress
        3 keyword arguments are passed:
            status_code: the response.status_code
            chunk_size: number of queries in the API call
            ignore_errors: whether not to consider non-OK responses failures
        """
        pass

    def general_logger(self, *args, **kwargs):
        """Function that is called after each API call for debugging
        1 positional argument is passed: the debug text
        """
        pass

    def rate_limited_post_request(self, request_kwargs, chunk_size):
        """Send a POST request that takes rate limits into account.

        On HTTP 429, if the server's Retry-After value is within the
        configured maximum, sleep that long (reporting progress each
        second) and retry the request once.
        """
        # `rate_limit_max_wait_time` is set by subclasses; fall back to the
        # module default when the attribute is absent (the original code
        # raised AttributeError if a subclass did not define it).
        max_wait_time = (
            getattr(self, 'rate_limit_max_wait_time', None)
            or RATE_LIMIT_MAX_WAIT_TIME_DEFAULT
        )
        response = requests.post(**request_kwargs)
        if response.status_code == 429:
            wait_time = int(response.headers.get('Retry-After', 0))
            if wait_time and (wait_time <= max_wait_time):
                # Spread the chunk's progress evenly over the wait time so
                # progress reporting keeps moving while we sleep.
                chunk_size_by_wait_time = chunk_size / wait_time
                for _ in range(wait_time):
                    sleep(1)
                    self.progress_logger(
                        status_code=response.status_code,
                        chunk_size=chunk_size_by_wait_time,
                        ignore_errors=True
                    )
                response = requests.post(**request_kwargs)
        self.progress_logger(
            status_code=response.status_code,
            chunk_size=chunk_size,
        )
        return response


class Search(SearchBase):
    """Run exact, similarity or substructure searches against the Mcule API."""

    def __init__(
        self,
        queries,
        base_url,
        api_token=None,
        collection=COLLECTION_FULL,
        search_type=SEARCH_TYPE_EXACT,
        sim_threshold=SIM_THRESHOLD_MIN,
        hit_limit=10,
        rate_limit_max_wait_time=None,
        *args, **kwargs
    ):
        """Run a search using the Mcule API

        Parameters
        ----------
        queries: list
            List of queries (as strings)
        base_url: str
            Base URL
        api_token: str
            API token
        collection: str
            The collection to search in (full or in_stock)
        search_type: str
            Accepted values exact/sim/sss
        sim_threshold: float
            Similarity threshold (only used for sim searches)
        hit_limit: int
            Hit limit for sim/sss search
        rate_limit_max_wait_time: int
            If the time obtained from the Retry-After response header is
            smaller than this limit, the request is retried after the
            necessary time has passed.
        """
        self.base_url = base_url
        self.api_token = api_token
        self.queries = queries
        self.collection = collection
        self.search_type = search_type
        self.sim_threshold = sim_threshold
        self.hit_limit = hit_limit
        self.rate_limit_max_wait_time = rate_limit_max_wait_time
        self.csv_headers = []
        self.exact_search_chunk_size = EXACT_SEARCH_CHUNK_SIZE
        self._unauthorized = False  # if set, no further API requests are made
        super().__init__(*args, **kwargs)

    def _get_post_kwargs(self, query):
        """Build the requests.post kwargs (url/json/headers) for one API call."""
        search_type = self.search_type

        # exact search accepts a batch of queries; sim/sss take a single
        # query plus a hit limit
        payload = {'queries': query}
        if search_type in {SEARCH_TYPE_SIM, SEARCH_TYPE_SSS}:
            payload = {'query': query, 'limit': self.hit_limit}

        if search_type == SEARCH_TYPE_SIM:
            payload.update({'threshold': self.sim_threshold})

        payload.update({'collection': self.collection})

        headers = {}
        if self.api_token:
            headers.update({'Authorization': f'Token {self.api_token}'})

        kwargs = {
            'url': f'{self.base_url}search/{search_type}/',
            'json': payload,
            'headers': headers
        }
        return kwargs

    def _api_call(self, query):
        """Run one search API call and return ``(query, results)``.

        ``query`` is a single query string (sim/sss) or a list of query
        strings (exact). Returns an empty result list if a previous call
        was rejected as unauthorized or the response is not OK.
        """
        results = []
        if self._unauthorized:
            return query, results

        post_kwargs = self._get_post_kwargs(query)

        query_msg = query
        chunk_size = 1
        if isinstance(query, list):
            query_msg = f'{len(query)} queries'
            chunk_size = len(query)

        response = self.rate_limited_post_request(post_kwargs, chunk_size)

        if response.ok:
            results = response.json()
            results = results.get('results', [])
            self.general_logger(f'API call done for {query_msg}. Results count: {len(results)}')
        else:
            self.general_logger(
                f'API call failed for {query_msg} | status code: {response.status_code} | content: {response.content}'
            )
            if response.status_code == requests.codes.unauthorized:
                # a bad/expired token would fail every remaining call too
                self._unauthorized = True

        return query, results

    def _process_results(self, search_results, query=None):
        """Flatten raw API hits into output dicts with ``hit_*`` keys.

        ``query`` is passed for sim/sss searches; for exact searches the
        query is taken from each result entry instead.
        """
        search_type = self.search_type
        ret = []
        for result in search_results:
            # exact search results carry the index of the matching input query
            transformed_output = {'input_index': result['input_index']} if search_type == SEARCH_TYPE_EXACT else {}
            transformed_output.update({
                'hit_smiles': result['smiles'],
                'hit_mcule_id': result['mcule_id'],
                'hit_url': result['url'],
                'query': query or result['query']
            })
            if search_type == SEARCH_TYPE_EXACT:
                transformed_output.update({'hit_matching_level': result['matching_level_display']})
            elif search_type == SEARCH_TYPE_SIM:
                transformed_output.update({'hit_tanimoto': result['sim']})
            ret.append(transformed_output)
        return ret

    def _exact_search(self):
        """Run an exact search in batches of ``exact_search_chunk_size``."""
        queries = self.queries
        num_queries = len(queries)
        results = []
        chunk = []
        chunk_idx = 0
        for idx, query in enumerate(self.queries, 1):
            chunk.append(query)
            # flush when the chunk is full or the last query was reached
            if len(chunk) == self.exact_search_chunk_size or idx == num_queries:
                _, chunk_results = self._api_call(chunk)
                chunk_results = self._process_results(chunk_results)
                # re-base the per-chunk input_index onto the position in the
                # full query list
                for result in chunk_results:
                    input_index = result['input_index'] + chunk_idx * self.exact_search_chunk_size
                    result['input_index'] = input_index
                results.extend(chunk_results)
                chunk = []
                chunk_idx += 1
        return results

    def execute(self):
        """Run the search and return ``(csv_headers, results)``.

        ``csv_headers`` is the key order of the first result row (empty
        list if there are no results).
        """
        search_type = self.search_type
        results = []
        if search_type == SEARCH_TYPE_EXACT:
            results = self._exact_search()
        else:
            # sim/sss searches take one query per API call
            for query in self.queries:
                query, results_part = self._api_call(query)
                results_part = self._process_results(results_part, query)
                results.extend(results_part)

        headers = list(results[0]) if results else []
        self.csv_headers.extend(headers)
        return self.csv_headers, results


class CompoundDetails(SearchBase):
    """Fetch availability and pricing details for compounds from the Mcule API."""

    def __init__(
        self,
        base_url,
        api_token,
        input_data,
        mcule_id_colname='hit_mcule_id',
        amounts=None,
        csv_headers=None,
        rate_limit_max_wait_time=None,
        *args, **kwargs
    ):
        """
        Parameters
        ----------
        base_url: str
            Base URL
        api_token: str
            API token
        input_data: list
            List of input entries to get the details for. The entries can be either
            MculeIDs or dicts containing the MculeIDs and other associated data.
        mcule_id_colname: str
            Column name of the MculeID to fetch the data for if the input_data list
            contains dicts.
        amounts: list
            The amounts to fetch the prices for.
        csv_headers: list
            List of dict keys in order to keep in the output.
        rate_limit_max_wait_time: int
            If the time obtained from the Retry-After response header is
            smaller than this limit, the request is retried after the
            necessary time has passed.
        """
        self.base_url = base_url
        self.api_token = api_token

        # The first entry decides how the whole list is interpreted.
        input_is_dict = bool(input_data) and isinstance(input_data[0], dict)

        self.input_data = input_data if input_is_dict else [{'hit_mcule_id': entry} for entry in input_data]
        # Deduplicate IDs in both branches so repeated MculeIDs do not cause
        # redundant API calls (previously only the dict branch deduplicated).
        # Output rows are unaffected: they are built from `input_data`.
        if input_is_dict:
            self.mcule_ids = {entry[mcule_id_colname] for entry in input_data}
        else:
            self.mcule_ids = set(input_data)
        self.amounts = amounts or [1]
        self.mcule_id_colname = mcule_id_colname
        self.api_results = {}  # MculeID -> raw pricing result from the API
        self.csv_headers = csv_headers or []
        self.rate_limit_max_wait_time = rate_limit_max_wait_time
        self._set_csv_headers()
        self._unauthorized = False  # if set, no further API requests are made

        # Max MculeIDs allowed in a single API call
        self.chunk_size = PRICES_CHUNK_SIZE_MAX

        self.url = f'{self.base_url}compounds/'
        self.headers = {}
        if self.api_token:
            self.headers.update({
                'Authorization': f'Token {self.api_token}',
            })
        super().__init__(*args, **kwargs)

    def _api_call(self, chunk):
        """POST one chunk of MculeIDs and return the list of raw results."""
        results = []
        if self._unauthorized:
            return results

        chunk_size = len(chunk)
        payload = {
            'availability': True,
            'components': False,
            'mcule_ids': chunk,
            'price_amounts': self.amounts
        }
        post_kwargs = {
            'url': self.url,
            'json': payload,
            'headers': self.headers
        }
        response = self.rate_limited_post_request(post_kwargs, chunk_size)

        if response.ok:
            results = response.json().get('results', [])
            self.general_logger(
                f'API call done for {chunk_size} queries. Results count: {len(results)}.'
            )
        else:
            # note: "faled" typo in this log message fixed
            self.general_logger(
                f'API call failed for {chunk_size} queries. | '
                f'status code: {response.status_code} | content: {response.content}.'
            )
            if response.status_code == requests.codes.unauthorized:
                # a bad/expired token would fail every remaining call too
                self._unauthorized = True

        return results

    def _get_prices(self):
        """Fetch pricing for all MculeIDs in chunks and cache results by ID."""
        num_compounds = len(self.mcule_ids)
        chunk = []
        for idx, mcule_id in enumerate(self.mcule_ids, 1):
            chunk.append(mcule_id)
            # flush when the chunk is full or the last ID was reached
            if len(chunk) == self.chunk_size or idx == num_compounds:
                for result in self._api_call(chunk):
                    self.api_results[result['mcule_id']] = result
                chunk = []

    def _set_csv_headers(self):
        """Append identifier, availability and per-amount pricing columns."""
        # compound identifier columns (skip if already present, e.g. when
        # the headers come from a preceding search)
        for col in ['hit_mcule_id', 'hit_smiles']:
            if col not in self.csv_headers:
                self.csv_headers.append(col)
        # availability columns
        self.csv_headers += [
            'hit_availability_type',
            'hit_confirmed_amount',
            'hit_confirmed_amount_unit',
        ]
        # pricing columns: one group per requested amount
        for idx in range(1, len(self.amounts) + 1):
            pricing_cols = [
                f'hit_amount_{idx}',
                f'hit_unit_{idx}',
                f'hit_price_{idx}',
                f'hit_currency_{idx}',
                f'hit_delivery_time_wdays_{idx}',
                f'hit_purity%_{idx}',
            ]
            self.csv_headers += pricing_cols

    @staticmethod
    def _parse_pricing_info(pricing_info):
        """Flatten one raw API pricing result into a dict of ``hit_*`` keys."""
        ret = {
            'hit_mcule_id': pricing_info['mcule_id'],
            'hit_smiles': pricing_info['smiles'],
        }
        availability_info = pricing_info['availability']
        ret.update({
            'hit_availability_type': availability_info['availability_type'],
            'hit_confirmed_amount': availability_info['confirmed_amount'],
            'hit_confirmed_amount_unit': availability_info['confirmed_amount_unit'],
        })

        # one numbered column group per best-price entry
        for idx, entry in enumerate(pricing_info['best_prices'], 1):
            ret.update({
                f'hit_amount_{idx}': entry['amount'],
                f'hit_unit_{idx}': entry['unit'],
                f'hit_price_{idx}': entry['price'],
                f'hit_currency_{idx}': entry['currency'],
                f'hit_delivery_time_wdays_{idx}': entry['delivery_time_working_days'],
                # 'purity' may be absent from an entry — keep None in that case
                f'hit_purity%_{idx}': entry.get('purity'),
            })
        return ret

    def execute(self):
        """Fetch prices and merge them into each input entry.

        Returns
        -------
        tuple
            ``(csv_headers, rows)`` where rows are the input entries with
            pricing info merged in (entries with no API result are kept
            unchanged).
        """
        self._get_prices()
        ret = []
        for entry in self.input_data:
            pricing_info = self.api_results.get(entry[self.mcule_id_colname])
            if pricing_info:
                entry.update(self._parse_pricing_info(pricing_info))
            ret.append(entry)
        return self.csv_headers, ret


# Workflow-related functions
def workflow(
    base_url,
    api_token,
    queries,
    search_type=None,
    search_collection=COLLECTION_FULL,
    sim_threshold=SIM_THRESHOLD_MIN,
    hit_limit=10,
    price_amounts=None,
    skip_if_no_price=False,
    max_price=None,
    progress_logger=None,
    general_logger=None,
    rate_limit_max_wait_time=None,
):
    """High level function to access the Mcule API

    Parameters
    ----------
    base_url: str
        The API base URL.
    api_token: str
        Your API token.
    queries: list of str
        List of queries accepted by the API. An optional identifier can
        be placed after the molecular identifier, separated by a tab character.
    search_type: str
        The following values are accepted: exact, sim, sss, None. If the
        value is None, only prices are fetched and no search is carried out
        beforehand. In this case the queries must be MculeIDs.
    search_collection: str
        Accepted values: full, in_stock
    sim_threshold: float
        Similarity threshold.
    hit_limit: int
        Hit limit for sim and sss searches.
    price_amounts: list of int
        List of amounts for which the prices are fetched. If no values
        are specified, only a search is carried out.
    skip_if_no_price: bool
        If True, the compounds with no price information will be excluded
        from the output.
    max_price: float
        If specified, compounds with a price greater than max_price will
        be excluded from the output. The price field corresponds to the
        first amount specified in price_amounts.
    progress_logger: func
        Function that is called after each API call, with the keyword arguments
        `status_code` and `chunk_size` (plus `ignore_errors` during
        rate-limit waits). The goal of this implementation is to have access
        to the API progress outside of the class and to build outside logic
        on the status_codes.
    general_logger: func
        Function that is called for debug purposes with a single positional
        argument (the debug text).
    rate_limit_max_wait_time: int
        If the time obtained from the Retry-After response header is
        smaller than this limit, a rate-limited request is retried after
        the necessary time has passed.

    Returns
    -------
    tuple
        List of column names, list of results.

    Raises
    ------
    ValueError
        If neither search_type nor price_amounts is specified, or if
        skip_if_no_price is requested without fetching prices.
    """
    # validate options (ValueError is more specific than the bare Exception
    # raised previously and remains compatible with `except Exception`)
    if not search_type and not price_amounts:
        raise ValueError('You must either specify a search type or price_amounts.')
    if not price_amounts and skip_if_no_price:
        raise ValueError('You must fetch prices to skip entries with no price.')

    # do a search before fetching prices
    if search_type is not None:
        search = Search(
            base_url=base_url,
            api_token=api_token,
            queries=queries,
            collection=search_collection,
            search_type=search_type,
            sim_threshold=sim_threshold,
            hit_limit=hit_limit,
            rate_limit_max_wait_time=rate_limit_max_wait_time,
            progress_logger=progress_logger,
            general_logger=general_logger,
        )
        csv_headers, results = search.execute()

    # only fetch prices: the queries are assumed to be MculeIDs
    else:
        results = queries
        csv_headers = None

    if price_amounts:
        get_details = CompoundDetails(
            base_url=base_url,
            api_token=api_token,
            input_data=results,
            amounts=price_amounts,
            csv_headers=csv_headers,
            rate_limit_max_wait_time=rate_limit_max_wait_time,
            progress_logger=progress_logger,
            general_logger=general_logger,
        )
        csv_headers, results = get_details.execute()

    def _price_filter(entry):
        """Keep an entry unless it lacks a price (when skip_if_no_price)
        or its first-amount price exceeds max_price."""
        price = entry.get('hit_price_1')
        if skip_if_no_price and not price:
            return False
        if price and max_price and max_price < price:
            return False
        return True

    results = list(filter(_price_filter, results))

    return csv_headers, results
