#!/usr/bin/python
"""
(c) 2009 Raphael Wimmer
<raphael.wimmer@ifi.lmu.de>

this code is open-source 
under the MIT license 
"""
import sys
import numpy
import networkx as nx
import math
from opencv.cv import *
from opencv.highgui import *


# Global Variables
storage = cvCreateMemStorage(0)  # module-wide OpenCV memory storage
input_name = "0"                 # capture source: camera index or video path (argv[1] overrides)

# Order of the two frames in a pair (True or False). The main loop only tries
# to detect it automatically while switch == "undefined"; that detection does
# not work correctly, so the order is fixed here.
switch = True
frames_processed = 0             # number of completed frame pairs processed

# interactive modes, toggled with the keys 1, 2 and 3 in the main loop
detection_mode = False
normalization_mode = False
connection_mode = False

# region of interest (x, y, width, height) cut out of every camera frame
IMG_X, IMG_Y = 170, 80
IMG_W, IMG_H = 300, 300
IMG_SIZE = cvSize(IMG_W, IMG_H)

class Dot:
    """One calibration dot: its position and radius in image coordinates,
    its current and normalised brightness, and the dots it is connected to."""

    def compare_acc(dot1, dot2):
        """cmp-style ordering of two dots by acc; call as Dot.compare_acc(a, b)."""
        return dot1.acc - dot2.acc

    def __init__(self, x, y, r, acc=0, min = 255, max = 0):
        """Create a dot.

        x, y, r: centre and radius in pixels.
        acc: detection count (how often this dot was recognised).
        min, max: initial brightness range used for normalisation; the
            defaults are inverted so the first normalising measurement sets both.
        """
        self.x = x
        self.y = y
        self.r = r
        self.val = 0 # brightness
        self.norm_val = 0.0
        self.max = max
        self.min = min
        self.acc = acc # probability of correctly recognized dot
        self.realx = 0 # gets determined by graph layout mechanism
        self.realy = 0 # gets determined by graph layout mechanism
        self.neighbors = [] #neighbors without weighting

    def __str__(self):
        """Return "x y val" (space separated)."""
        return "%s %s %s" % (self.x, self.y, self.val)

    def determineBrightness(self, frame, normalize = True):
        '''Measure the mean brightness of a 5x5 window centred on the dot.

        frame: image to sample; channel 0 of the window average is used
            (the main loop passes a single-channel grey image).
        normalize: when True, widen the stored [min, max] range so that it
            includes the new value.
        Sets self.val and self.norm_val = (val - min) / (max - min), or 0.0
        while the range is empty, and returns self.val.
        NOTE(review): for a dot within 2 px of the frame border the window
        leaves the image; cvGetSubRect presumably fails then.
        '''
        W = 5
        half = W // 2  # integer offset on Python 2 and 3 alike
        hotspot = cvGetSubRect(frame, cvRect(self.x - half, self.y - half, W, W))
        self.val = int(cvAvg(hotspot)[0])
        if normalize:
            if self.val > self.max:
                self.max = self.val
            if self.val < self.min:
                self.min = self.val
        if self.max > self.min:
            self.norm_val = float(self.val - self.min) / float(self.max - self.min)
        else:
            self.norm_val = 0.0
        return self.val

    def addNeighbor(self, dot):
        """Link *dot* and self as mutual neighbours (no self-links, no duplicates)."""
        if dot is not self and dot not in self.neighbors:
            self.neighbors.append(dot)
            # make the relation symmetric; the membership test ends the recursion
            dot.addNeighbor(self)

    def addNeighbors(self, dots):
        """Call addNeighbor for every dot in *dots*."""
        for dot in dots:
            self.addNeighbor(dot)
 
class Calibration:
    """Calibration dots collected from repeated circle detections.

    Besides the list of Dot objects this keeps a networkx graph in which
    dots that were active together are connected (connectDots), and the
    bounding box of the latest spring layout of that graph (springLayout).
    """
    min_radius = 13          # smallest accepted radius (px); also the "too close to another dot" distance
    max_radius = 20          # largest accepted radius (px)
    max_dist_for_same = 5    # per-axis distance (px) below which a detection is merged into a stored dot

    def __init__(self):
        self.dots = []
        self.graph = nx.Graph()
        # bounding box of the last springLayout(); it always contains 0.0
        self.max_realx = 0.0
        self.min_realx = 0.0
        self.max_realy = 0.0
        self.min_realy = 0.0

    def checkDot(self, x, y, r):
        """Register one detected circle and return True if it became a new dot.

        Circles with a radius outside [min_radius, max_radius] are ignored.
        Stored dots are examined in order, and the first one that matches
        decides: within max_dist_for_same on both axes, the detection is
        merged into it (position and radius averaged, acc incremented);
        within min_radius on both axes, the detection is discarded as
        overlapping. Otherwise a new Dot with acc 0 is appended.
        """
        if r < Calibration.min_radius or r > Calibration.max_radius:
            return False
        for dot in self.dots:
            dx = abs(dot.x - x)
            dy = abs(dot.y - y)
            if dx < Calibration.max_dist_for_same and dy < Calibration.max_dist_for_same:
                # same dot detected again: refine its estimate and count the hit
                dot.x = int((dot.x + x) // 2)
                dot.y = int((dot.y + y) // 2)
                dot.r = int((dot.r + r) // 2)
                dot.acc += 1
                return False
            if dx < Calibration.min_radius and dy < Calibration.min_radius:
                # overlaps an existing dot without being the same one - discard it
                return False
        self.dots.append(Dot(int(x), int(y), int(r), 0))
        return True

    def checkDots(self, circles):
        """Feed every detected circle (x, y, radius) to checkDot and report new dots."""
        for circle in circles:
            x, y, r = cvRound(circle[0]), cvRound(circle[1]), cvRound(circle[2])
            if self.checkDot(x, y, r):
                print("new dot found (%d):  %s %s ( %s )" % (len(self.dots), x, y, r))

    def __str__(self):
        """List the reliably detected dots (acc > 5), one "[x, y, r,acc]," per line."""
        return "".join("[" + str(dot.x) + ", " + str(dot.y) + ", " + str(dot.r) +
                       "," + str(dot.acc) + "],\n"
                       for dot in self.dots if dot.acc > 5)

    def getDots(self):
        """Return the list of stored dots (the list itself, not a copy)."""
        return self.dots

    def getDotAt(self, x, y):
        """Return the stored dot whose centre is nearest to (x, y).

        Returns None if there are no dots or the nearest one is not closer
        than min_radius.
        """
        min_dist = 99999
        hit = None
        for dot in self.dots:
            dist = math.hypot(dot.x - x, dot.y - y)
            if dist < min_dist:
                min_dist = dist
                hit = dot
        if min_dist < self.min_radius:
            return hit
        return None

    def getDotsAt(self, circles):
        """Map each circle (x, y, ...) to its stored dot; unmatched circles are reported and skipped."""
        dots = []
        for circle in circles:
            dot = self.getDotAt(cvRound(circle[0]), cvRound(circle[1]))
            if dot:
                dots.append(dot)
            else:
                print("oh, dot not recognized")
        return dots

    def cleanUp(self, limit = 0, threshold=5):
        """Sort the dots by acc (highest first) and drop unreliable ones.

        limit > 0: keep only the *limit* dots with the highest acc.
        otherwise: drop every dot with acc < threshold, printing each decision.
        Neighbour lists of the remaining dots and self.graph are not updated.
        """
        # key-based sort; equal elements keep their order, as with the former cmp sort
        sorted_dots = sorted(self.dots, key=lambda dot: dot.acc, reverse=True)
        if limit > 0:
            self.dots = sorted_dots[:limit]
            return
        # build a new list instead of removing from the one being iterated,
        # which used to skip the element following every removed dot
        kept = []
        for dot in sorted_dots:
            if dot.acc < threshold:
                print("removed dot %s %s ( %s )" % (dot.x, dot.y, dot.acc))
            else:
                kept.append(dot)
                print("kept dot %s %s ( %s )" % (dot.x, dot.y, dot.acc))
        self.dots = kept

    def resetNeighbors(self):
        """Clear every dot's neighbour list (self.graph is left unchanged)."""
        for dot in self.dots:
            dot.neighbors = []

    def resetNormVals(self):
        """Restart brightness normalisation: reset each dot's min/max to 255/0."""
        for dot in self.dots:
            dot.min = 255
            dot.max = 0

    def connectDots(self, threshold = 0.5):
        """Connect all currently active dots with each other.

        threshold < 1.0 is compared with each dot's normalised brightness
        (norm_val); threshold >= 1.0 with its raw brightness (val). Every
        active dot becomes a neighbour of every other active dot and the
        corresponding edges are added to self.graph.
        """
        if threshold < 1.0:
            active_dots = [dot for dot in self.dots if dot.norm_val > threshold]
        else:
            active_dots = [dot for dot in self.dots if dot.val > threshold]
        for dot1 in active_dots:
            dot1.addNeighbors(active_dots)
            # keeps a lone active dot in the graph so it still gets a layout position
            self.graph.add_node(dot1)
            for dot2 in active_dots:
                if dot2 is not dot1:  # no self-loops, consistent with Dot.addNeighbor
                    self.graph.add_edge(dot1, dot2)

    def springLayout(self, iterations = 50, dimensions = 2):
        """Lay out self.graph with networkx's spring layout.

        Stores each node's position in dot.realx / dot.realy and recomputes
        the bounding box (min/max_realx/y, which always include 0.0).
        NOTE(review): assumes the graph nodes are Dot objects as added by
        connectDots; loadNeighbors adds "x,y" strings, which would fail here.
        """
        # keyword arguments: in networkx >= 1.x the second positional parameter is not `iterations`
        layout = nx.spring_layout(self.graph, iterations=iterations, dim=dimensions)
        # start a fresh bounding box so extents of an earlier layout do not linger
        self.min_realx = 0.0
        self.min_realy = 0.0
        self.max_realx = 0.0
        self.max_realy = 0.0
        for dot, position in layout.items():
            dot.realx = float(position[0])
            dot.realy = float(position[1])
            self.min_realx = min(self.min_realx, dot.realx)
            self.min_realy = min(self.min_realy, dot.realy)
            self.max_realx = max(self.max_realx, dot.realx)
            self.max_realy = max(self.max_realy, dot.realy)
            print(" ".join(str(v) for v in ("Dot", dot.x, dot.y, dot.realx, dot.realy)))

    def save(self, filename="flyeye.cal"):
        """Write the dots to *filename* and their neighbours to *filename*.neighbors.

        Dot lines: "x;y;radius;acc;min;max". Neighbour lines: "x,y" of the
        dot followed by ";x,y" for each neighbour.
        """
        f = open(filename, "w")
        try:
            f.write("# calibration data for FlyEye dots\n")
            f.write("# x ; y ; radius ; acc ; min ; max\n")
            for dot in self.dots:
                f.write(";".join(str(v) for v in
                                 (dot.x, dot.y, dot.r, dot.acc, dot.min, dot.max)) + "\n")
        finally:
            f.close()
        f = open(filename + ".neighbors", "w")
        try:
            f.write("# neighbor data for FlyEye dots\n")
            f.write("# x, y ; x1, y1 ; x2, y2; ... \n")
            for dot in self.dots:
                fields = [str(dot.x) + "," + str(dot.y)]
                fields.extend(str(n.x) + "," + str(n.y) for n in dot.neighbors)
                f.write(";".join(fields) + "\n")
        finally:
            f.close()

    def load(self, filename="flyeye.cal"):
        """Replace the stored dots with those read from a file written by save().

        Comment lines (starting with '#') and blank lines are skipped. The
        current dots are replaced only after the whole file was parsed, so a
        missing or malformed file leaves them untouched (the error still
        propagates). Neighbours and self.graph are not touched.
        """
        dots = []
        f = open(filename, "r")
        try:
            for line in f:
                if line.startswith("#") or not line.strip():
                    continue
                x, y, r, acc, lo, hi = line.split(";")
                dots.append(Dot(int(x), int(y), int(r), int(acc), int(lo), int(hi)))
        finally:
            f.close()
        self.dots = dots

    def loadNeighbors(self, graph, filename="flyeye.cal.neighbors"):
        """Add the edges stored by save() to self.graph.

        graph: unused; edges always go into self.graph (kept for interface
            compatibility).
        Nodes are the "x,y" coordinate strings from the file, not Dot objects.
        TODO: map coordinates back to Dots and restore Dot.neighbors as well.
        """
        f = open(filename, "r")
        try:
            for line in f:
                if line.startswith("#") or not line.strip():
                    continue
                edges = line.split(";")
                point1 = edges.pop(0).strip("\n")
                x, y = point1.split(",")  # validates the "x,y" format; values not used yet
                for point2 in edges:
                    self.graph.add_edge(point1, point2.strip("\n"))
        finally:
            f.close()

            
# Hough parameters are module-level globals rather than HoughCircleExtractor
# attributes because the trackbar callbacks below are plain functions; per the
# original author, OpenCV did not accept bound object methods as callbacks.
hough_bins = hough_param1 = hough_param2 = 1
hough_min_rad = hough_max_rad = 1
hough_min_dist = 1


def cb_set(variable_wrapped_in_list, value):
    """Store a positive *value* in slot 0 of the given list; ignore 0 and negatives."""
    if value <= 0:
        return
    variable_wrapped_in_list[0] = value


def cb_hough_bins(val):
    """Trackbar callback: store a positive *val* in hough_bins (cvHoughCircles' dp argument)."""
    global hough_bins
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_bins = val
        
def cb_hough_param1(val):
    """Trackbar callback: store a positive *val* in hough_param1 (cvHoughCircles' param1)."""
    global hough_param1
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_param1 = val
        
def cb_hough_param2(val):
    """Trackbar callback: store a positive *val* in hough_param2 (cvHoughCircles' param2)."""
    global hough_param2
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_param2 = val

def cb_hough_min_rad(val):
    """Trackbar callback: store a positive *val* in hough_min_rad (smallest circle radius searched)."""
    global hough_min_rad
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_min_rad = val

def cb_hough_max_rad(val):
    """Trackbar callback: store a positive *val* in hough_max_rad (largest circle radius searched)."""
    global hough_max_rad
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_max_rad = val

def cb_hough_min_dist(val):
    """Trackbar callback: store a positive *val* in hough_min_dist (cvHoughCircles' min_dist argument)."""
    global hough_min_dist
    if val <= 0:
        return  # non-positive trackbar positions are ignored
    hough_min_dist = val


class HoughCircleExtractor:
    """Circle detector built on cvHoughCircles with live-tunable parameters.

    Every instance opens its own window ("Hough Circle Extractor <n>") with
    trackbars bound to the module-level hough_* parameters. All instances
    share those parameters.
    """

    count = 0  # instances created so far; gives each window a unique name

    def __init__(self, bins=1, p1=10, p2=5, min_rad=12, max_rad=20, min_dist=10):
        """Set the shared Hough parameters and create the control window.

        The arguments become the initial values of hough_bins, hough_param1,
        hough_param2, hough_min_rad, hough_max_rad and hough_min_dist. See
        extract() for how each one is passed to cvHoughCircles.
        """
        self.id = HoughCircleExtractor.count
        self.name = "Hough Circle Extractor " + str(self.id)
        HoughCircleExtractor.count += 1

        global hough_bins, hough_param1, hough_param2, hough_min_rad,hough_max_rad,hough_min_dist

        hough_bins = bins
        hough_param1 = p1
        hough_param2 = p2
        hough_min_rad = min_rad
        hough_max_rad = max_rad
        hough_min_dist = min_dist
        self.storage = cvCreateMemStorage(0)  # holds the circle sequences returned by extract()
        self.debug = True  # draw detected circles into the input frame

        self.frame = None # don't output a new image, only modify one passed to filter()

        cvNamedWindow( self.name, CV_WINDOW_AUTOSIZE );
        cvCreateTrackbar("bins", self.name, hough_bins, 8, cb_hough_bins)
        cvCreateTrackbar("p1", self.name, hough_param1, 100, cb_hough_param1)
        cvCreateTrackbar("p2", self.name, hough_param2, 100, cb_hough_param2)
        cvCreateTrackbar("min_rad", self.name, hough_min_rad, 30, cb_hough_min_rad)
        cvCreateTrackbar("max_rad", self.name, hough_max_rad, 30, cb_hough_max_rad)
        cvCreateTrackbar("min_dist", self.name, hough_min_dist, 30, cb_hough_min_dist)

    def extract(self,frame):
        """Detect circles in *frame* and return cvHoughCircles' result sequence.

        The arguments are passed in cvHoughCircles' order:
        dp = hough_bins, min_dist = hough_min_dist, param1 = hough_param1,
        param2 = hough_param2, min_radius = hough_min_rad,
        max_radius = hough_max_rad. Each returned circle is indexed as
        (x, y, radius). When self.debug is set, the circles are drawn into
        *frame* itself. *frame* is then shown in this extractor's window.
        The returned circles stay valid only until the next call.
        """
        # The result sequence is allocated inside self.storage. Clear it first,
        # otherwise each call leaves another sequence behind and memory grows
        # with every processed frame.
        cvClearMemStorage(self.storage)
        circles = cvHoughCircles(frame, self.storage,
                                 CV_HOUGH_GRADIENT,
                                 hough_bins,
                                 hough_min_dist,
                                 hough_param1,
                                 hough_param2,
                                 hough_min_rad,
                                 hough_max_rad)
        if self.debug:
            for circle in circles:
                cvCircle(frame, cvPoint(cvRound(circle[0]), cvRound(circle[1])), cvRound(circle[2]), cvScalar(255))
        cvShowImage( self.name, frame )
        return circles

#threshold_param1 = 8
#threshold_param2 = 255

def cb_threshold_param1(val):
    """Trackbar callback: store *val* (any value, including 0) in threshold_param1.

    The main loop uses threshold_param1 as the cvThreshold level.
    """
    global threshold_param1
    threshold_param1 = val
    
def cb_threshold_param2(val):
    """Trackbar callback: store *val* (any value, including 0) in threshold_param2.

    The main loop passes threshold_param2 to cvThreshold as the value given
    to pixels above the threshold.
    """
    global threshold_param2
    threshold_param2 = val


def callBack(val):
    """Debug callback: print the received value."""
    print(val)

# MAIN STARTS HERE #

# Script body only: sys.exit raises SystemExit when this file is imported as a module.
if __name__ != '__main__':
    sys.exit(0)

# argv[1]: camera index (all digits) or the path of a video file
if len(sys.argv) > 1:
    input_name = sys.argv[1]

if input_name.isdigit():
    capture = cvCreateCameraCapture( int(input_name) )
else:
    # a non-numeric argument names a recorded video instead of being ignored
    capture = cvCreateFileCapture( input_name )


cvNamedWindow( "win1", 1 );
cvNamedWindow( "win2", 1 );
cvNamedWindow( "win_sub", 1 );
cvNamedWindow( "win_threshold", 1 );
cvNamedWindow( "win_final", 1 );
cvNamedWindow( "win_real", 1 );


# set up threshold parameters: p1 is the threshold level (adjustable via the
# trackbar), p2 the value assigned to pixels above it (fixed, no trackbar)
threshold_param1 = 32
threshold_param2 = 255
cvCreateTrackbar("thresh p1", "win_threshold", threshold_param1, 255, cb_threshold_param1)

if( capture ):
    cvSetCaptureProperty(capture, CV_CAP_PROP_FPS, 30.0)
    new_set = False
    
    cal = Calibration()
    hough = HoughCircleExtractor(1,10,5,13,16,20)
    
    frame_small = None
    frame_copy1 = None
    frame_copy2 = None
    frame_sub = None
    frame_grey = None
    frame_sobel = None
    frame_smooth = None
    frame_threshold = None
    frame_hough = None
    frame_final = None
    frame_real = None
    
    while True:
        
        frame = cvQueryFrame( capture )
        frame_small = cvGetSubRect(frame, cvRect(IMG_X,IMG_Y, IMG_W, IMG_H))
        
        if( not frame_copy1 ):
            frame_copy1 = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, frame.nChannels );
        
        if( not frame_copy2 ):
            frame_copy2 = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, frame.nChannels );
            
        if( not frame_final ):
            frame_final = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, frame.nChannels );

        if( not frame_real ):
            frame_real = cvCreateImage(cvSize(400,400), IPL_DEPTH_8U, 3 );
        
        # with the timestamp embedded into the camera image, successive images have an even or odd last bit.
        # the dedicated frame counter bits seem to not work with the FireFly MV.
        if (int(cvGet1D(frame, 3)[0]) % 2) == 0:
            cvCopy( frame_small, frame_copy1 );
            new_set = False
        else:
            cvCopy( frame_small, frame_copy2 );
            new_set = True
        
        
        if( not frame_sub ):
            frame_sub = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, frame.nChannels );
            
        if( not frame_sobel ):
            frame_sobel = cvCreateImage(IMG_SIZE, IPL_DEPTH_16U, 1 );
            
        if( not frame_hough ):
            frame_hough = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, 1 );
            
        if( not frame_grey ):
            frame_grey = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, 1 );
            
        if( not frame_smooth ):
            frame_smooth = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, 1 );
            
        if( not frame_threshold ):
            frame_threshold = cvCreateImage(IMG_SIZE, IPL_DEPTH_8U, 1 );
        
        cvShowImage( "win1", frame_copy1 )
        cvShowImage( "win2", frame_copy2 )
        
        if new_set:
            # order needs to be determined:        
            if switch == "undefined":
                avg1 = cvAvg(frame_copy1)
                avg2 = cvAvg(frame_copy2)
                print "frame 1:", avg1, "frame 2:", avg2
                if avg1[0] > avg2[0]:
                    switch = False
                else:
                    switch = True
            
            if switch:
                cvSub(frame_copy2, frame_copy1, frame_sub)
            else:
                cvSub(frame_copy1, frame_copy2, frame_sub)
                
            cvCvtColor(frame_sub, frame_grey, CV_RGB2GRAY)
            #cvConvertScale(frame_grey, frame_grey, 2.0, 30)
            cvSmooth(frame_grey, frame_grey, CV_GAUSSIAN, 5, 5)
            cvShowImage( "win_sub", frame_grey )
            #cvThreshold(frame_smooth, frame_threshold, threshold_param1, threshold_param2, CV_THRESH_BINARY)
            cvThreshold(frame_grey, frame_threshold, threshold_param1, threshold_param2, CV_THRESH_BINARY)
            #cvAdaptiveThreshold(frame_smooth, frame_threshold, 255, CV_ADAPTIVE_THRESH_MEAN_C, CV_THRESH_BINARY, (threshold_param1 * 2 + 3), float(threshold_param2))
       
            cvCvtColor(frame_grey, frame_final, CV_GRAY2RGB)
            #cvDilate(frame_threshold, frame_threshold)
            #cvErode(frame_threshold, frame_threshold)
            cvSmooth(frame_threshold, frame_threshold, CV_GAUSSIAN, 3, 3)
            cvShowImage( "win_threshold", frame_threshold )

            #cvSobel(frame_threshold, frame_sobel, 1, 1, 1)
            #cvDilate(frame_sobel, frame_sobel)
            #cvShowImage( "win_sobel", frame_sobel)
            
            #cvConvertScale(frame_sobel, frame_hough)
           
            cvConvertScale(frame_threshold, frame_hough)

            if detection_mode:
                circles = hough.extract(frame_hough)
                cvRectangle(frame_final, cvPoint(0,0), cvPoint(20,20), cvScalar(100,100,255), -1)
            else:
                circles = []
            cal.checkDots(circles)

            if normalization_mode:
                cvRectangle(frame_final, cvPoint(30,0), cvPoint(50,20), cvScalar(100,255, 100), -1)

            if connection_mode:
                cal.connectDots(threshold_param1)
                cvRectangle(frame_final, cvPoint(60,0), cvPoint(80,20), cvScalar(255,100,100), -1)

            cvRectangle(frame_real, cvPoint(0,0), cvPoint(frame_real.width,frame_real.height), cvScalar(0,100,100), -1)
            

            for dot in cal.getDots():
                brt = dot.determineBrightness(frame_grey, (normalization_mode == True))
                cvCircle(frame_final, cvPoint(dot.x, dot.y), dot.r, cvScalar(int(dot.norm_val * 255),int(dot.norm_val * 255), int(dot.norm_val * 255)), -1) # filled
                if dot.acc > 5:
                    cvCircle(frame_final, cvPoint(dot.x, dot.y), dot.r, cvScalar(255,0,0))
                else:
                    cvCircle(frame_final, cvPoint(dot.x, dot.y), dot.r, cvScalar(0,0,255))
                if dot.val > threshold_param1:
                    for neighbor in dot.neighbors:
                        cvLine(frame_final, cvPoint(dot.x, dot.y), cvPoint(neighbor.x, neighbor.y), cvScalar(0,255,0))

                if cal.max_realx > 0 and cal.max_realy > 0:
                    x = (dot.realx - cal.min_realx) * frame_real.width / (cal.max_realx - cal.min_realx)
                    y = (dot.realy - cal.min_realy) * frame_real.height / (cal.max_realy - cal.min_realy)
                    cvCircle(frame_real,
                             cvPoint(int(x), int(y)),
                             dot.r,
                             cvScalar(int(dot.norm_val * 255),int(dot.norm_val * 255), int(dot.norm_val * 255)),
                             -1) # filled

            
            for dot in cal.getDotsAt(circles):
                cvCircle(frame_final, cvPoint(dot.x, dot.y), dot.r, cvScalar(0,255,0))



            
            cvShowImage( "win_final", frame_final )
            cvShowImage( "win_real", frame_real )
            frames_processed += 1

        # with the timestamp embedded into the camera image, successive images have an even or odd last bit.
        # the dedicated frame counter bits seem to not work with the FireFly MV.
        #if (int(cvGet1D(frame, 3)[0]) % 2) == 0:

        key = cvWaitKey( 5 )
        if (key >= 0):
            if ord(key) == 27:
                break
            elif key == 'p':
                print cal
            elif key == 's':
                cal.save()
            elif key == 'l':
                cal.load()
            elif key == 'r':
                cal.resetNormVals()
            elif key == 't':
                cal.resetNeighbors()
            elif key == 'o':
                cal.springLayout()
            elif key == 'c':
                cal.cleanUp()
            elif key == '1':
                detection_mode = not detection_mode
            elif key == '2':
                normalization_mode = not normalization_mode
            elif key == '3':
                connection_mode = not connection_mode
            

cvDestroyWindow("win1")
cvDestroyWindow("win2")
