#!/usr/bin/env python3
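"""
Downloads images for the Fakeddit and IFND fake news datasets and writes an
updated copy of the input CSV/TSV in which each image URL is replaced by the
local path of the successfully downloaded file.

See --help for the available options (--dataset, --input, --output_dir,
--num_samples).
"""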
import argparse
import csv
import os
import random
import requests
from PIL import Image
import imagehash
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import logging


def download_images(path, url_column, id_column, output_folder, delimiter, num_samples=10000):
    """
    Downloads the images from the image urls in the datasets
    """
    os.makedirs(output_folder, exist_ok=True)

    session = requests.Session()

    # retry transient server errors up to 2 extra times per request,
    # with exponential backoff
    retries = Retry(
        total=2,
        backoff_factor=0.3,
        status_forcelist=[500, 502, 503, 504]
    )

    adapter = HTTPAdapter(max_retries=retries)
    session.mount("http://", adapter)
    session.mount("https://", adapter)

    # curl-like headers; some image hosts reject the default
    # python-requests user agent
    headers = {
        "user-agent": "curl/7.68.0",
        "accept": "*/*"
    }

    # prepend a row pointing at the known placeholder image so it is always
    # downloaded; update_paths() later hashes it to detect dead links
    if 'ifnd' in os.path.basename(path).lower():
        placeholder_row = {
            'Image': 'https://images.indianexpress.com/2020/12/express-photo-38-11.jpg?resize=450,250', 'id': '3199'}
    else:
        placeholder_row = {'id': 'ckm2j5y',
                           'image_url': 'http://i.imgur.com/KC2WyYB%2ejpg'}

    with open(path, 'r', encoding="ISO-8859-1") as file:
        reader = list(csv.DictReader(file, delimiter=delimiter))
        # shuffle so the num_samples cutoff selects a random subset, then put
        # the placeholder row first so it is always fetched
        random.shuffle(reader)
        reader.insert(0, placeholder_row)
        count = 0
        for row in reader:
            if url_column in row and row[url_column]:
                url = row[url_column]
                img_id = row[id_column]

                file_name = os.path.join(output_folder, f'{img_id}.jpg')

                if os.path.exists(file_name):
                    continue
                try:
                    # (connect, read) timeouts in seconds; short limits keep
                    # unresponsive hosts from stalling the whole run
                    response = session.get(
                        url, stream=True, headers=headers, timeout=(1, 3))
                    response.raise_for_status()

                    # write the file to folder
                    with open(file_name, 'wb') as img_file:
                        for chunk in response.iter_content(1024):
                            img_file.write(chunk)
                    count += 1
                    logging.debug(f'Downloaded: {file_name}')
                    # log progress every 100 successful downloads
                    if count % 100 == 0:
                        logging.info(f"Downloaded {count} images successfully!")
                except requests.RequestException as e:
                    logging.warning(f'Failed to download {url}: {e}')
                    # remove any partially written file so a rerun retries it
                    if os.path.exists(file_name):
                        os.remove(file_name)

            if count >= num_samples:
                break
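        # summary once the loop finishes (input exhausted or num_samples hit)
        logging.info(f"Finished: downloaded {count} new images to {output_folder}")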


def update_paths(path, url_column, id_column, title_column, output_folder, updated_path, delimiter, placeholder_image):
    """
    Replaces url links to filesystem paths of properly downloaded images.
    Ommits images that are placeholder imgur image, i.e. when the image is no longer available.
    Also omits images with incorrectly encoded titles.
    """
    # hash of the placeholder image (e.g. imgur's "image removed" graphic);
    # downloads that match it are treated as missing images
    hash0 = None
    if placeholder_image is not None and os.path.exists(placeholder_image):
        with Image.open(placeholder_image) as placeholder:
            hash0 = imagehash.average_hash(placeholder)
    updated_rows = []

    with open(path, 'r', encoding='ISO-8859-1') as file:
        reader = csv.DictReader(file, delimiter=delimiter)
        fieldnames = reader.fieldnames
        rows = list(reader)
        if len(rows) == 0:
            logging.error(f"Unable to load any sample from the input {path}")
            return
        for row in rows:
            if url_column in row and row[url_column]:
                img_id = row[id_column]
                image_path = os.path.join(output_folder, f'{img_id}.jpg')
                if os.path.exists(image_path):
                    try:
                        # the CSV was read as ISO-8859-1; round-tripping the raw
                        # bytes through UTF-8 recovers the intended text, so a
                        # mismatch (or UnicodeDecodeError) flags a mangled title
                        text = row[title_column]
                        encoded_decoded_text = bytes(
                            text, 'iso-8859-1').decode('utf-8')
                        if text != encoded_decoded_text:
                            raise ValueError(
                                f"title has an encoding error: {text}")

                        # confirm the image file is actually readable
                        with Image.open(image_path) as im:
                            if hash0 is not None:
                                hash1 = imagehash.average_hash(im)
                                # hashes within 5 bits are considered the same image
                                cutoff = 5
                                if hash0 - hash1 < cutoff:
                                    raise ValueError("image is a placeholder")

                        row[url_column] = image_path
                        updated_rows.append(row)
                    except Exception as e:
                        logging.warning(f"failed at {image_path} {str(e)}")

    if len(updated_rows) == 0:
        logging.error(f"No valid samples remained after filtering {path}")
    with open(updated_path, 'w', encoding='utf-8', newline='') as file:
        writer = csv.DictWriter(
            file, fieldnames=fieldnames, delimiter=delimiter)
        writer.writeheader()
        writer.writerows(updated_rows)
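    # summary of how many rows survived the filtering
    logging.info(f"Wrote {len(updated_rows)} rows with local paths to {updated_path}")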


def main():
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')

    parser = argparse.ArgumentParser(
        description='Download images for fake news datasets',
        epilog='Recommended way to download Fakeddit: curl -L "https://docs.google.com/uc?export=download&id=1Z99QrwpthioZQY2U6HElmnx8jazf7-Kv" -o fakeddit.tsv\n'
        'Recommended way to download IFND (Kaggle): https://www.kaggle.com/datasets/sonalgarg174/ifnd-dataset?select=IFND.csv',
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument('--dataset', type=str, choices=['fakeddit', 'ifnd'], required=True,
                        help='Which dataset to download images for: fakeddit or ifnd')
    parser.add_argument('--input', type=str, required=True,
                        help='Input file path (TSV or CSV) with image URLs')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save downloaded images')
    parser.add_argument('--num_samples', type=int, default=30000,
                        help='Maximum number of images to download')

    args = parser.parse_args()

    if args.dataset == 'fakeddit':
        # Download fakeddit images
        download_images(
            path=args.input,
            url_column='image_url',
            id_column='id',
            output_folder=args.output_dir,
            delimiter='\t',
            num_samples=args.num_samples
        )

        # Update paths for fakeddit
        update_paths(
            path=args.input,
            url_column='image_url',
            id_column='id',
            title_column='clean_title',
            output_folder=args.output_dir,
            updated_path=f"{args.dataset}_with_paths.tsv",
            delimiter='\t',
            placeholder_image=os.path.join(args.output_dir, "ckm2j5y.jpg")
        )
    else:
        # Download IFND images
        download_images(
            path=args.input,
            url_column='Image',
            id_column='id',
            output_folder=args.output_dir,
            delimiter=',',
            num_samples=args.num_samples
        )

        # Update paths for IFND
        update_paths(
            path=args.input,
            url_column='Image',
            id_column='id',
            title_column='Statement',
            output_folder=args.output_dir,
            updated_path=f"{args.dataset}_with_paths.csv",
            delimiter=',',
            placeholder_image=os.path.join(args.output_dir, "3199.jpg")
        )


if __name__ == '__main__':
    main()
