#!/usr/bin/env python3

import concurrent.futures
from PIL import Image
import numpy
import os
import sys
import struct
import math
import datetime

TGA_HEAD_SIZE = 18                  # size of the fixed TGA file header, in bytes
OFFSET_BYTES = 64 - TGA_HEAD_SIZE   # image-ID length written to header byte 0 (header + ID = 64 bytes)
ALIGN_PIXEL = 16                    # row alignment in pixels; processImageDir overrides it per format
ALPHA_BLEND = 1                     # non-zero: premultiply colours by alpha; overridden per format

def help_png2vglite():
    """Print command-line usage, including every supported image format."""
    print("Usage: png2vglite.py <input dir> <output dir> <image format>")
    print("  <input dir>:    The source directory.")
    print("  <output dir>:   The target output directory.")
    print("  <image format>: The target image format to convert to.")
    print("                  Supported: BITMAP_ARGB8888, L8_ARGB8888, L8_ARGB8888_DAVE2D.")
    print("                  BITMAP_ARGB8888     The images will be converted to 32-bit ARGB bitmaps, size reference: 480x480 storage space 901KB.")
    print("                  L8_ARGB8888         The images will be converted to 8-bit indexed images with their own palette of up to 256 colors, size reference: 480x480 storage space 225KB.")
    print("                  L8_ARGB8888_DAVE2D  Same as L8_ARGB8888, but without alpha premultiplication and without 16-pixel row alignment.")

# BITMAP_ARGB8888 Start

def formatForBitmapInternal(matrix, width, height, align=None):
    """Premultiply RGBA pixels by alpha, reorder them to B, G, R, A and pad rows.

    Args:
        matrix: array indexed as ``matrix[row, col] -> (r, g, b, a)``; in this
            file it comes from ``numpy.array(img)`` of an RGBA image.
        width: number of source columns to convert.
        height: number of source rows to convert.
        align: row alignment in pixels. Defaults to the module-level
            ``ALIGN_PIXEL``, read at call time because ``processImageDir``
            changes it per output format.

    Returns:
        ``(dst, width, padding)``. ``dst`` is a ``uint8`` array of shape
        ``(height, width + padding, 4)`` whose padding columns are all zero.
        ``width`` is the unpadded width and ``padding`` the number of extra
        columns per row.
    """
    if align is None:
        align = ALIGN_PIXEL
    # Smallest padding that makes the row length a multiple of `align`.
    padding = -width % align

    src = numpy.asarray(matrix)[:height, :width]
    # Compute in float64. Multiplying a uint8 channel by a uint8 alpha
    # directly wraps around: under NumPy >= 2 (NEP 50), `int * numpy.uint8`
    # stays uint8.
    rgb = src[:, :, :3].astype(numpy.float64)
    alpha = src[:, :, 3].astype(numpy.float64)[:, :, numpy.newaxis]
    # numpy.round rounds half to even, matching the built-in round().
    premultiplied = numpy.round(rgb * alpha / 255.0).astype(numpy.uint8)

    dst = numpy.zeros((height, width + padding, 4), dtype=numpy.uint8)
    dst[:, :width, :3] = premultiplied[:, :, ::-1]  # (R, G, B) -> (B, G, R)
    dst[:, :width, 3] = src[:, :, 3]
    return dst, width, padding

def formatForBitmap(img):
    """Convert an RGBA PIL image into its padded, alpha-premultiplied form.

    The pixel work is delegated to ``formatForBitmapInternal``. The caller
    (``processImageBitmap``) converts the image to RGBA beforehand.

    Returns:
        ``(image, width, padding)``: the new PIL image, the original pixel
        width, and the number of columns appended to each row.
    """
    cols, rows = img.width, img.height
    converted, orig_width, pad = formatForBitmapInternal(numpy.array(img), cols, rows)
    return Image.fromarray(converted), orig_width, pad

def changeOffsetAndWidthBitmap(fname, width, padding):
    """Patch, in place, a TGA file written by Pillow for the bitmap format.

    The rewritten file is laid out as:

    * the 18-byte TGA header, with the image-ID length (byte 0) set to
      ``OFFSET_BYTES`` and the image width (bytes 12-13) set to
      ``width + padding``;
    * an ``OFFSET_BYTES``-long image-ID block: uint32 marker 0x484D4F53,
      uint32 original ``width``, then zero fill;
    * the rest of the original file after its header and image ID.

    Multi-byte fields are packed little-endian, the byte order TGA uses.

    Args:
        fname: path of the TGA file to patch.
        width: original (unpadded) image width in pixels.
        padding: number of columns appended to each row.
    """
    with open(fname, 'rb+') as f:
        data = f.read()
        header = bytearray(data[:TGA_HEAD_SIZE])
        # Skip any image ID already present so it is replaced rather than kept.
        old_id_len = header[0]
        header[0] = OFFSET_BYTES
        struct.pack_into('<H', header, 12, width + padding)
        image_id = struct.pack('<II', 0x484D4F53, width).ljust(OFFSET_BYTES, b'\0')
        f.seek(0)
        f.write(header)
        f.write(image_id)
        f.write(data[TGA_HEAD_SIZE + old_id_len:])
        # Drop stale trailing bytes should the new content be shorter.
        f.truncate()

def processImageBitmap(fList):
    """Convert one PNG into a BITMAP_ARGB8888 TGA file.

    Args:
        fList: ``[source_path, destination_path]`` pair, as built by
            ``processImageDir``.
    """
    src = fList[0]
    dstFname = fList[1]
    # Close the source file deterministically; convert() returns a new,
    # independent image, so the open file is not needed afterwards.
    with Image.open(src) as source:
        img = source.convert('RGBA')
    img2, w, padding = formatForBitmap(img)
    img2.save(dstFname, 'TGA', orientation=1)
    changeOffsetAndWidthBitmap(dstFname, w, padding)
    print('\t- convert: %s %s' % (src, img.mode))

# BITMAP_ARGB8888 End

def genPaletteAndGetZeroIdx(img, alpha_blend=None):
    """Build the 32-bit palette table for an L8 image and find its "zero" entry.

    Each palette entry is packed as a little-endian uint32:

    * alpha blending on (``alpha_blend > 0``): colours are premultiplied by
      alpha and packed as ``0xAABBGGRR``;
    * alpha blending off: colours are packed as ``0xAARRGGBB``.

    RGB palettes are treated as fully opaque (alpha 255) and follow the same
    packing rule as RGBA palettes.

    Args:
        img: palette image. Only ``img.palette.mode`` ('RGB' or 'RGBA') and
            the raw ``img.palette.palette`` bytes are read.
        alpha_blend: overrides the module-level ``ALPHA_BLEND`` flag, which
            is read at call time when this is omitted.

    Returns:
        ``(palette_bytes, zeroidx)``. ``palette_bytes`` is the packed table.
        ``zeroidx`` is the index of the last entry that packs to 0 (RGBA) or
        is opaque black (RGB), or -1 if there is none.

    Raises:
        RuntimeError: if the palette mode is neither 'RGB' nor 'RGBA'.
    """
    if alpha_blend is None:
        alpha_blend = ALPHA_BLEND
    mode = img.palette.mode
    if mode == 'RGBA':
        stride = 4
    elif mode == 'RGB':
        stride = 3
    else:
        raise RuntimeError('unsupport palette mode!!!')

    palette = img.palette.palette
    entries = []
    zeroidx = -1
    for i in range(len(palette) // stride):
        channels = tuple(palette[i * stride:(i + 1) * stride])
        r, g, b = channels[:3]
        a = channels[3] if stride == 4 else 255
        if alpha_blend > 0:
            r, g, b = (int(round(c * a / 255.0)) for c in (r, g, b))
            v = (a << 24) | (b << 16) | (g << 8) | r
        else:
            # Without blending (DAVE2D), RGB palettes must use the same
            # 0xAARRGGBB order as RGBA palettes; otherwise red and blue
            # swap depending on the source image type.
            v = (a << 24) | (r << 16) | (g << 8) | b
        entries.append(struct.pack('<I', v))
        is_zero = (v == 0) if stride == 4 else (r == g == b == 0)
        if is_zero:
            zeroidx = i
    return b''.join(entries), zeroidx

def formatForL8Internal(matrix, width, height, zeroidx, align=None):
    """Copy palette indices into a row-padded ``uint8`` buffer.

    Args:
        matrix: 2-D array of palette indices, indexed ``matrix[row, col]``.
        width: number of source columns to copy.
        height: number of source rows to copy.
        zeroidx: palette index written into the padding columns. -1 (no
            suitable entry was found) falls back to index 0 when padding is
            needed.
        align: row alignment in pixels. Defaults to the module-level
            ``ALIGN_PIXEL``, read at call time.

    Returns:
        ``(dst, width, padding)``. ``dst`` has shape
        ``(height, width + padding)``. ``width`` is the unpadded width.
    """
    if align is None:
        align = ALIGN_PIXEL
    # Smallest padding that makes the row length a multiple of `align`.
    padding = -width % align
    if padding > 0 and zeroidx == -1:
        zeroidx = 0
    dst = numpy.empty((height, width + padding), dtype=numpy.uint8)
    dst[:, :width] = numpy.asarray(matrix)[:height, :width]
    if padding:
        # Guarded so an unused -1 is never cast to uint8.
        dst[:, width:] = zeroidx
    return dst, width, padding

def formatForL8(img, zeroidx):
    """Pad the rows of a palette ('P') image and return a new palette image.

    Args:
        img: palette image whose indices and palette are copied.
        zeroidx: palette index used to fill the padding columns.

    Returns:
        ``(image, width, padding)``: the padded 'P' image carrying ``img``'s
        palette, the original width, and the number of padding columns.
    """
    indices = numpy.array(img)
    padded, w, padding = formatForL8Internal(indices, img.width, img.height, zeroidx)
    # A 2-D uint8 array becomes an 'L' image, and putpalette() switches it to
    # 'P'. This avoids the `mode` argument of Image.fromarray(), which recent
    # Pillow releases deprecate.
    im = Image.fromarray(padded)
    im.putpalette(img.getpalette())
    return im, w, padding

def changeOffsetAndWidthL8(img, fname, width, padding, palette_bytes):
    """Patch, in place, a palette TGA file written by Pillow for the L8 formats.

    The rewritten file is laid out as:

    * the 18-byte TGA header, with the image-ID length (byte 0) set to
      ``OFFSET_BYTES``, the colour-map entry count (bytes 5-6) set to
      ``len(palette_bytes) // 4``, the colour-map entry size (byte 7) set to
      32 bits, and the image width (bytes 12-13) set to ``width + padding``;
    * an ``OFFSET_BYTES``-long image-ID block: uint32 marker 0x484D4F53,
      uint32 original ``width``, then zero fill;
    * ``palette_bytes``, replacing Pillow's original colour map;
    * the rest of the original file after its colour map.

    Multi-byte fields are packed little-endian, the byte order TGA uses.

    Args:
        img: unused; kept so existing callers keep working.
        fname: path of the TGA file to patch.
        width: original (unpadded) image width in pixels.
        padding: number of columns appended to each row.
        palette_bytes: packed 32-bit palette entries, as produced by
            ``genPaletteAndGetZeroIdx``.
    """
    with open(fname, 'rb+') as f:
        data = f.read()
        header = bytearray(data[:TGA_HEAD_SIZE])
        old_id_len = header[0]
        # Colour-map specification (bytes 3-7): first entry index, entry
        # count, bits per entry.
        _first_entry, old_map_len, old_entry_bits = struct.unpack_from('<HHB', header, 3)
        # The stored map holds exactly `old_map_len` entries whatever the
        # first-entry index is. Each entry takes a whole number of bytes
        # (e.g. a 15-bit entry occupies 2).
        pixels_start = (TGA_HEAD_SIZE + old_id_len
                        + old_map_len * ((old_entry_bits + 7) // 8))

        header[0] = OFFSET_BYTES
        struct.pack_into('<H', header, 5, len(palette_bytes) // 4)
        header[7] = 32
        # Width field lives at offset 12; the original code wrote it at the
        # current file position (18) where it was immediately overwritten.
        struct.pack_into('<H', header, 12, width + padding)
        image_id = struct.pack('<II', 0x484D4F53, width).ljust(OFFSET_BYTES, b'\0')

        f.seek(0)
        f.write(header)
        f.write(image_id)
        f.write(palette_bytes)
        f.write(data[pixels_start:])
        # Drop stale trailing bytes should the new content be shorter.
        f.truncate()

def processImageL8(fList):
    """Convert one PNG into an L8 (8-bit palette index) TGA file.

    Args:
        fList: ``[source_path, destination_path]`` pair, as built by
            ``processImageDir``.
    """
    src = fList[0]
    dstFname = fList[1]
    # Close the source file deterministically; convert() returns a new,
    # independent image, so the open file is not needed afterwards.
    with Image.open(src) as source:
        img = source.convert('P', palette=Image.ADAPTIVE)
    palette_bytes, zeroidx = genPaletteAndGetZeroIdx(img)
    img2, w, padding = formatForL8(img, zeroidx)
    img2.save(dstFname, 'TGA', orientation=1)
    changeOffsetAndWidthL8(img, dstFname, w, padding, palette_bytes)
    print('\t- convert: %s %s %s' % (src, img.mode, img.palette.mode))

def _initWorker(align_pixel, alpha_blend):
    """Install the per-format conversion settings in a pool worker process."""
    global ALIGN_PIXEL
    global ALPHA_BLEND
    ALIGN_PIXEL = align_pixel
    ALPHA_BLEND = alpha_blend


def processImageDir(rootdir, outdir, imgformat):
    """Convert every ``.png`` under ``rootdir`` into ``outdir`` in parallel.

    The subdirectory layout of ``rootdir`` is mirrored under ``outdir``, and
    each output file keeps its source file name. Files not ending in
    ``.png`` are reported as skipped. An unrecognised ``imgformat`` prints
    the usage text and returns without converting anything.

    Args:
        rootdir: directory walked recursively for source images.
        outdir: directory that receives the converted files.
        imgformat: 'BITMAP_ARGB8888', 'L8_ARGB8888' or 'L8_ARGB8888_DAVE2D'
            (matched by substring).
    """
    print('Start processImageDir format:%s src:%s dst:%s' % (imgformat, rootdir, outdir))
    global ALIGN_PIXEL
    global ALPHA_BLEND
    # Substring checks: the DAVE2D variant must be tested before the plain
    # L8 format, whose name it contains.
    if 'BITMAP_ARGB8888' in imgformat:
        print('Convert images to BITMAP_ARGB8888...')
        ALIGN_PIXEL = 16
        ALPHA_BLEND = 1
        worker = processImageBitmap
    elif 'L8_ARGB8888_DAVE2D' in imgformat:
        print('Convert images to L8_ARGB8888_DAVE2D...')
        ALIGN_PIXEL = 1
        ALPHA_BLEND = 0
        worker = processImageL8
    elif 'L8_ARGB8888' in imgformat:
        print('Convert images to L8_ARGB8888...')
        ALIGN_PIXEL = 16
        ALPHA_BLEND = 1
        worker = processImageL8
    else:
        print('ERROR. image format is %s.' % (imgformat))
        help_png2vglite()
        print('processImageDir Abort.')
        return

    flist = []
    for root, _subdirs, files in os.walk(rootdir):
        for filename in files:
            src = os.path.join(root, filename)
            if src.endswith(".png"):
                path = os.path.join(outdir, os.path.relpath(root, rootdir))
                os.makedirs(path, exist_ok=True)
                flist.append([src, os.path.join(path, filename)])
            else:
                print('\t- skipped: %s' % (src))

    failed = 0
    # Pass the settings to the workers explicitly. Under the 'spawn' and
    # 'forkserver' start methods, workers re-import this module and would
    # otherwise see the default ALIGN_PIXEL/ALPHA_BLEND values.
    with concurrent.futures.ProcessPoolExecutor(
            initializer=_initWorker,
            initargs=(ALIGN_PIXEL, ALPHA_BLEND)) as executor:
        futures = {executor.submit(worker, item): item[0] for item in flist}
        # Collect every result so worker exceptions are reported instead of
        # being silently discarded.
        for future in concurrent.futures.as_completed(futures):
            error = future.exception()
            if error is not None:
                failed += 1
                print('\t- failed: %s (%r)' % (futures[future], error))
    if failed:
        print('processImageDir: %d of %d file(s) failed.' % (failed, len(flist)))
    print('processImageDir Finished.')

if __name__ == '__main__':
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        # Wrong argument count: show usage and exit with an error status.
        help_png2vglite()
        sys.exit(-1)

    input_dir, output_dir, image_format = cli_args
    processImageDir(input_dir, output_dir, image_format)