import numpy as np   
import matplotlib.pyplot as plt  
import matplotlib

def imread_int( filename ):
    """Read an image file and return its pixel values as integers in [0, 255].

    ``plt.imread`` returns floats in [0, 1] for PNG files, but integer
    arrays (typically uint8, already in [0, 255]) for formats it reads via
    Pillow, such as JPEG. Float images are scaled by 255 and rounded.
    Integer images are passed through unscaled.

    Args:
        filename: path of the image to read.

    Returns:
        int array with the same shape as the array ``plt.imread`` returns.
    """
    print("processing '" + filename + "' ...")
    img = np.asarray(plt.imread(filename))
    if np.issubdtype(img.dtype, np.floating):
        # Round instead of truncating: a float32 k/255 times 255 can come out
        # as k - epsilon, which a plain int cast would turn into k - 1.
        img = np.rint(img * 255)
    # Integer inputs are NOT multiplied by 255: uint8 * 255 would overflow.
    return np.array(img, dtype='int')

def histogram(grayscale, bins=256):
    """Count how often each gray level occurs in ``grayscale``.

    Args:
        grayscale: integer array of gray levels (any shape). Every value must
            lie in ``[0, bins)``, i.e. ``[0, 255]`` for the default 256 bins.
        bins: number of histogram bins.

    Returns:
        float array of length ``bins``; entry ``k`` is the number of pixels
        whose value is ``k``.

    Raises:
        ValueError: if any value lies outside ``[0, bins)``.
    """
    values = np.asarray(grayscale).ravel()
    # Validate explicitly: bincount would otherwise grow the result for large
    # values, and a per-pixel h[v] += 1 silently wraps negative values.
    if values.size and (values.min() < 0 or values.max() >= bins):
        raise ValueError("gray levels must lie in [0, %d)" % bins)
    return np.bincount(values, minlength=bins).astype(float)

def localHistogram(grayscale, pivot, radius, bins=256):
    """Histogram of the square neighbourhood centred on ``pivot``.

    The window covers rows ``pivot[0]-radius .. pivot[0]+radius`` and
    columns ``pivot[1]-radius .. pivot[1]+radius``, both ends inclusive.
    That is a ``(2*radius+1)``-wide square, clipped to the image borders.

    Args:
        grayscale: 2-D integer array with values in ``[0, bins)``.
        pivot: ``(row, col)`` of the window centre.
        radius: half-width of the window in pixels (0 = just the pivot).
        bins: number of histogram bins.

    Returns:
        float array of length ``bins`` with the per-level pixel counts of
        the window.

    Raises:
        ValueError: if a value in the window lies outside ``[0, bins)``.
    """
    row, col = pivot[0], pivot[1]
    # Clamp the start at 0 (a negative slice start would wrap around). The
    # stop may exceed the shape, because slicing clips it automatically.
    # The +1 is needed because slice stops are exclusive and the window must
    # include pivot + radius.
    top = max(row - radius, 0)
    left = max(col - radius, 0)
    window = np.asarray(grayscale)[top:row + radius + 1, left:col + radius + 1]
    values = window.ravel()
    if values.size and (values.min() < 0 or values.max() >= bins):
        raise ValueError("gray levels must lie in [0, %d)" % bins)
    return np.bincount(values, minlength=bins).astype(float)

def cumulativeDistr(histogram, slope_max = 0):
    """Normalised cumulative distribution of a histogram, optionally slope-limited.

    Args:
        histogram: 1-D array of non-negative bin counts.
        slope_max: if > 0, the largest allowed increase of the returned CDF
            from one bin to the next. Equivalently, it is the largest fraction
            of all counts that any single bin may contribute. Every bin,
            including bin 0, is clipped to this limit. The clipped mass is
            discarded, not redistributed, so the last entry can then be
            below 1. A value <= 0 disables the limit.

    Returns:
        float array ``d`` of the same length. ``d[k]`` is the (clipped)
        count in bins ``0..k`` divided by the total count. An all-zero
        histogram yields an all-zero array instead of NaNs.
    """
    # Force float so integer histograms are never floor-divided (Python 2).
    counts = np.asarray(histogram, dtype=float)
    total = counts.sum()
    if total == 0:
        # Nothing to normalise by; avoid 0/0 -> NaN.
        return np.zeros(counts.size)
    if slope_max > 0:
        # Clip in count units (slope_max * total) so the unclipped path below
        # still ends at exactly 1.0.
        counts = np.minimum(counts, slope_max * total)
    return np.cumsum(counts) / total


def equalizeHistogram(grayscale):
    """Globally equalize ``grayscale`` through the image's own CDF.

    Every pixel of level ``g`` is replaced by the fraction of pixels whose
    level is <= ``g``.

    Args:
        grayscale: 2-D integer array with values in [0, 255].

    Returns:
        float array with the same shape as ``grayscale``, values in [0, 1].
    """
    cdf = cumulativeDistr(histogram(grayscale))
    # Fancy indexing applies the lookup table to all pixels at once,
    # replacing the per-pixel Python double loop.
    return cdf[np.asarray(grayscale)]

def adaptiveHistogramEqualize(grayscale, radius = 8):
    """Adaptive histogram equalization (AHE), brute force.

    Each pixel is mapped through the CDF of its neighbourhood, as built by
    ``localHistogram`` with the given ``radius``. This is exactly CLAHE with
    the contrast limit disabled (``slope_max=0``), so the work is delegated
    to ``contrastLimitedAdaptiveHistEqual``. That function also prints
    progress every 10 rows.

    Args:
        grayscale: 2-D integer array with values in [0, 255].
        radius: neighbourhood radius in pixels.

    Returns:
        float array with the same shape as ``grayscale``.
    """
    return contrastLimitedAdaptiveHistEqual(grayscale, 0, radius)

def contrastLimitedAdaptiveHistEqual(grayscale, slope_max, radius = 8 ):
    """Contrast-limited adaptive histogram equalization (CLAHE), brute force.

    For every pixel, the histogram of its neighbourhood is rebuilt from
    scratch (``localHistogram``) and turned into a slope-limited CDF
    (``cumulativeDistr``). The pixel's new value is looked up in that CDF.
    A progress percentage is printed every 10 rows.

    Args:
        grayscale: 2-D integer array with values in [0, 255].
        slope_max: maximum per-bin increase of each local CDF; <= 0 disables
            the limit.
        radius: neighbourhood radius in pixels.

    Returns:
        float array with the same shape as ``grayscale``.
    """
    n_rows, n_cols = grayscale.shape[0], grayscale.shape[1]
    out = np.zeros(grayscale.shape)
    for row in range(n_rows):
        for col in range(n_cols):
            local_hist = localHistogram(grayscale, (row, col), radius)
            local_cdf = cumulativeDistr(local_hist, slope_max)
            out[row][col] = local_cdf[grayscale[row][col]]
        if row % 10 == 0:
            # integer percentage, same as Python 2 '/' on ints
            print(str(row * 100 // n_rows) + "%")
    return out



def exc3a(paths=('u2Images/linearize1.png', 'u2Images/linearize2.png',
                 'u2Images/linearize3.png', 'u2Images/car256.png')):
    """Exercise 3a: show each image next to its globally equalized version.

    Draws one row per image: the image as read by ``plt.imread`` on the
    left and ``equalizeHistogram`` of it on the right, in gray scale.

    Args:
        paths: sequence of image files. Defaults to the four exercise images.
    """
    n_rows = len(paths)
    # Set the gray colormap once; the original re-applied it per row.
    plt.gray()
    for k, path in enumerate(paths):
        plt.subplot(n_rows, 2, 2 * k + 1)
        plt.imshow(plt.imread(path))
        plt.subplot(n_rows, 2, 2 * k + 2)
        plt.imshow(equalizeHistogram(imread_int(path)))
    plt.show()
    
def exc3b(imgpath):
    """Exercise 3b: display the adaptive histogram equalization of one image.

    Args:
        imgpath: path of the image file to process.
    """
    equalized = adaptiveHistogramEqualize(imread_int(imgpath))
    plt.imshow(equalized)
    plt.gray()
    plt.show()
    
def exc3c(imgpath, slope_max=5.0 / 255, radius=8):
    """Exercise 3c: compare CLAHE output with the original image.

    Draws a 2x2 figure. The top row shows the CLAHE result and its
    histogram. The bottom row shows the original image (as read by
    ``plt.imread``) and its histogram.

    Args:
        imgpath: path of the image file.
        slope_max: contrast limit passed to ``contrastLimitedAdaptiveHistEqual``
            (default 5/255, the value previously hard-coded).
        radius: neighbourhood radius passed to
            ``contrastLimitedAdaptiveHistEqual``.
    """
    plt.gray()

    plt.subplot(221)
    imgclahe = contrastLimitedAdaptiveHistEqual(imread_int(imgpath), slope_max, radius)
    plt.imshow(imgclahe)
    plt.subplot(222)
    # ravel() gives a flat 1-D array, so hist treats all pixels as one dataset.
    plt.hist(imgclahe.ravel(), bins=1024, color='black')

    plt.subplot(223)
    img = plt.imread(imgpath)
    plt.imshow(img)
    plt.subplot(224)
    plt.hist(img.ravel(), bins=1024, color='black')

    plt.show()

if __name__ == "__main__":
    # Run only when executed as a script, not on import. Uncomment the
    # exercise(s) to run; each opens a matplotlib window.
    #exc3a()

    #exc3b('u2Images/linearize1.png')
    #exc3b('u2Images/linearize2s.png')
    #exc3b('u2Images/linearize3s.png')
    #exc3b('u2Images/car256.png')

    #exc3c('u2Images/linearize1.png')
    #exc3c('u2Images/linearize2s.png')
    #exc3c('u2Images/linearize3s.png')
    exc3c('u2Images/car256.png')