from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os.path
from datetime import datetime
from PIL import Image
import numpy as np
from io import BytesIO
import requests
import base64

import tensorflow as tf
from tensorflow.python.platform import gfile
import captcha_model3 as captcha

import config
import logging

# Input image size, taken from the shared config module. input_data() resizes
# each image to (IMAGE_WIDTH, IMAGE_HEIGHT) and flattens it into
# IMAGE_HEIGHT * IMAGE_WIDTH values.
IMAGE_WIDTH = config.IMAGE_WIDTH
IMAGE_HEIGHT = config.IMAGE_HEIGHT

# CHAR_SETS is indexed by one_hot_to_texts() to map predicted indices to
# characters. CLASSES_NUM and CHARS_NUM are not referenced in the visible part
# of this file -- presumably the number of classes and the number of characters
# per captcha; confirm against config.
CHAR_SETS = config.CHAR_SETS
CLASSES_NUM = config.CLASSES_NUM
CHARS_NUM = config.CHARS_NUM

# Parsed command-line flags; assigned in the __main__ block before main() runs.
FLAGS = None

# Log file that the __main__ block passes to log_message().
log_file_path = './your_log_file.log'


def setup_logger(log_file):
    """
    Sets up a logger that writes to a specified log file.

    The logger is the module-level one (``logging.getLogger(__name__)``), so
    every call returns the same object. A file handler for ``log_file`` is
    attached only once: calling this function repeatedly with the same path
    does not stack handlers, which would otherwise write every message
    several times and leak open file handles.

    Args:
    log_file (str): The path to the log file.

    Returns:
    logger (logging.Logger): Configured logger instance.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)  # Let every level through to the handlers

    # FileHandler stores its target as an absolute path, so compare like with like.
    target_path = os.path.abspath(log_file)
    for handler in logger.handlers:
        if (isinstance(handler, logging.FileHandler)
                and handler.baseFilename == target_path):
            # Already writing to this file; reuse the existing handler.
            return logger

    # Create file handler that logs to the specified file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

    logger.addHandler(file_handler)
    return logger


def log_message(log_file, message, level="info"):
    """
    Writes a single message to the given log file at the requested level.

    Args:
    log_file (str): The path to the log file.
    message (str): The message to log.
    level (str): Case-insensitive level name: 'info', 'debug', 'warning',
        'error' or 'critical'. Any other value is logged at info level.
        Default is 'info'.
    """
    logger = setup_logger(log_file)

    # Map each explicitly supported level name to the matching logger method.
    emitters = {
        "debug": logger.debug,
        "warning": logger.warning,
        "error": logger.error,
        "critical": logger.critical,
    }

    # Unrecognised names (and "info" itself) fall back to logger.info.
    emit = emitters.get(level.lower(), logger.info)
    emit(message)


def one_hot_to_texts(recog_result):
    """
    Converts recognition results into text strings.

    Each element of a row in ``recog_result`` is used as an index into
    CHAR_SETS; the characters of one row are joined into one string.

    Args:
    recog_result: Indexable result with a ``shape`` attribute; one row per
        image (presumably a 2-D integer numpy array -- confirm against
        captcha.output).

    Returns:
    list of str: One decoded string per row.
    """
    row_count = recog_result.shape[0]
    return [
        ''.join(CHAR_SETS[char_index] for char_index in recog_result[row])
        for row in range(row_count)
    ]


def input_data(image_dir):
    """
    Builds a one-image input batch from the base64 image on the command line.

    The image is read from ``sys.argv[1]`` (base64-encoded image bytes),
    converted to grayscale, resized to (IMAGE_WIDTH, IMAGE_HEIGHT), flattened
    and scaled from [0, 255] to [-0.5, 0.5]. The decoded original image is
    also written to 'image.jpg' in the current working directory.

    Args:
    image_dir (str): Not used by this function; kept so existing callers
        keep working.

    Returns:
    tuple: ``(images, filenames)`` where ``images`` is a float32 numpy array
        of shape (1, IMAGE_HEIGHT * IMAGE_WIDTH) and ``filenames`` is a
        one-element list with the name of the file the image was saved to.

    Raises:
    ValueError: If no base64 image was passed on the command line.
    """
    # Without this guard the function failed later with an UnboundLocalError,
    # because ``image`` was only assigned when an argument was present.
    if len(sys.argv) <= 1:
        raise ValueError(
            'Expected a base64-encoded image as the first command-line argument.')

    batch_size = 1
    saved_name = 'image.jpg'
    images = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH], dtype='float32')

    # The context manager closes the source image even if convert/resize/save fails.
    with Image.open(BytesIO(base64.b64decode(sys.argv[1]))) as image:
        image_gray = image.convert('L')
        image_resize = image_gray.resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))
        image.save(saved_name)

    input_img = np.array(image_resize, dtype='float32')
    # Scale pixel values from [0, 255] to [-0.5, 0.5].
    input_img = np.multiply(input_img.flatten(), 1. / 255) - 0.5
    images[0, :] = input_img

    # Report the file that was actually written (the original returned
    # 'image.jpeg', which did not match the saved 'image.jpg').
    return images, [saved_name]

def run_predict():
    """
    Runs captcha recognition on the command-line image and prints the result.

    Builds a fresh graph on the CPU, feeds the batch from input_data() through
    captcha.inference() and captcha.output(), restores the latest checkpoint
    found in FLAGS.checkpoint_dir, and prints the decoded text of the first
    image as ``captcha_code_is:<text>!``.

    Relies on the module-level FLAGS having been populated by the argument
    parser before this function is called.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        input_images, input_filenames = input_data(FLAGS.captcha_dir)
        images = tf.constant(input_images)
        # keep_prob=1 -- presumably disables dropout for inference; confirm
        # against captcha_model3.
        logits = captcha.inference(images, keep_prob=1)
        result = captcha.output(logits)
        saver = tf.compat.v1.train.Saver()

        # The context manager closes the session even when restoring the
        # checkpoint or running the graph raises.
        with tf.compat.v1.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
            recog_result = sess.run(result)

        text = one_hot_to_texts(recog_result)
        print('captcha_code_is:' + text[0] + '!')


def main(_):
    """
    Entry point handed to tf.compat.v1.app.run.

    The single positional argument (the argv list passed by the runner) is
    ignored; all settings are read from the module-level FLAGS.
    """
    run_predict()


if __name__ == '__main__':
    # Record a start-up line in the log file before doing anything else.
    log_message(log_file_path, 'This is an info message.')

    # Command-line options; anything not recognised here is left in
    # ``unparsed`` and forwarded to the TensorFlow app runner.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', type=str, default='./captcha_train',
                        help='Directory where to restore checkpoint.')
    parser.add_argument('--captcha_dir', type=str, default='./data/test_data',
                        help='Directory where to get captcha images.')

    FLAGS, unparsed = parser.parse_known_args()

    remaining_argv = [sys.argv[0]] + unparsed
    tf.compat.v1.app.run(main=main, argv=remaining_argv)
