#!/usr/bin/env python

import sys
import threading
import Image
import ImageChops
import ImageOps
import gc 

import pygtk; pygtk.require("2.0")
import gtk
gtk.gdk.threads_init()
import gobject

import pygst; pygst.require("0.10")
import gst

# Black colour constant (not referenced elsewhere in this file).
BLACK = gtk.gdk.Color(0, 0, 0)

# Selectable outputs, in combo-box order: the two decoder elements being
# compared, followed by their computed difference image.
outputs = ("flumpeg2vdec", "mpeg2dec", "Difference")

# Shared GUI state; populated by main().
decoders = {}           # output name -> Decoder
cur_out = status_label = draw_area = None

class Decoder:
    def __init__(self, draw_area, file = None, my_name = None):
	self.buffer = None
	self.pipe = None

	self.draw_area = draw_area

	self.cond = threading.Condition()
	self.target_frame_no = 0
	self.cur_frame_no = 0

	self.filename = file
	self.name = my_name

	self.started = False
        self.flushing = False

	self.width = 200
	self.height = 200
	self.data = None
	self.status = ""

	if file is not None:
	    self.make_pipeline (file, my_name)

    def make_pipeline (self, file, my_name):
	# Build a GStreamer pipeline to decode file to fakesink, 
	# attach a handoff handler to catch each frame
	# Set up a condition to wait on after each frame, fired from the 
	# next_frame item
	p = gst.Pipeline()
	f = gst.element_factory_make ("filesrc")
	f.set_property ("location", file)

	if my_name == "flumpeg2vdec":
	    m = gst.element_factory_make ("flumpeg2vdec")
	else:
	    m = gst.element_factory_make ("mpeg2dec")
	ffc = gst.element_factory_make ("ffmpegcolorspace")
	fs = gst.element_factory_make ("fakesink")
	
	p.add (f, m, ffc, fs)
	f.link (m)
	m.link (ffc)
	ffc.link (fs, gst.caps_from_string ("video/x-raw-rgb,depth=24,bpp=24,red_mask=16711680"))

	# Pipeline built. Attach to signal.
	fs.set_property ("signal-handoffs", True)
	fs.connect ("handoff", self.got_frame)

	self.pipe = p

    def start(self):
	if not self.started:
	    print "Starting pipeline for", self.name
	    self.flushing = False
	    self.pipe.set_state (gst.STATE_PLAYING)
	    self.started = True
	# start
    def reset(self):
	if self.started:

	    # Release any pending wait
	    self.cond.acquire() 
	    try:
		self.flushing = True
		self.cond.notify()
	    finally:
		self.cond.release()

	    if self.pipe.set_state (gst.STATE_READY) == gst.STATE_CHANGE_ASYNC:
		print "Going to READY ASYNC"
	    self.flushing = False

	    self.cur_frame_no = 0
	    self.target_frame_no = 0
	    self.pipe.set_state (gst.STATE_PLAYING)
    def go_to_frame(self, frame_no):
	if frame_no < 1:
	    frame_no = 1

	cur_frame = 0

	self.cond.acquire() 
	cur_frame = self.cur_frame_no
	self.cond.release() 

	if frame_no < cur_frame:
	    # Need to go backward
	    self.reset()

	self.cond.acquire() 
	try:
	    cur_frame = self.cur_frame_no
	    # Now set the target frame
	    self.target_frame_no = frame_no 
	    if frame_no != cur_frame:
		if self.pipe is not None:
		    print "%s going to frame %d" % (self.name, frame_no)
		    self.cond.notify()
		    self.start()
		    self.status = "Waiting..."
	finally:
	    self.cond.release()

    def calc_difference (self, one, two):
	# Compute the difference of the 2 outputs
	self.width = max (one.width, two.width)
	self.height = max (one.height, two.height)
	i1 = Image.fromstring ("RGB", (one.width, one.height), one.data)
	i2 = Image.fromstring ("RGB", (two.width, two.height), two.data)

	self.buffer = ImageChops.difference (i1, i2)
	self.buffer = ImageOps.autocontrast (self.buffer)
	self.data = self.buffer.tostring()
	self.cur_frame_no = min (one.cur_frame_no, two.cur_frame_no)

    def got_frame (self, element, buffer, pad):
	self.cond.acquire()

	try:
	    # Drop frames if there's more than one pending...
	    if self.cur_frame_no + 1 < self.target_frame_no or self.flushing:
		self.cur_frame_no += 1
		print "Dropping buffer %d on pad %s for %s" % (self.cur_frame_no, pad.get_name(), self.name)
	    else:
		if self.target_frame_no <= self.cur_frame_no: 
		    # No buffer requested yet, pend
		    print "Pending on cond for", self.name
		    self.cond.wait()
		    print "Finished pend on cond for", self.name
		if self.cur_frame_no + 1 < self.target_frame_no or self.flushing:
		    # Exit if we woke up too late
		    self.cur_frame_no += 1
		    return True

		# Handle the buffer
		gtk.gdk.threads_enter()
		try:
		    self.cur_frame_no += 1

		    caps = buffer.get_caps ()
		    if caps is not None and len (caps):
			self.width = caps[0]["width"]
			self.height = caps[0]["height"]

		    print "Got buffer %d on pad %s for %s" % (self.cur_frame_no, pad.get_name(), self.name), "wxh", self.width, self.height
		
		    if self.width > 0 and self.height > 0:
			# Paint the buffer into the window
			self.status = ""
			self.buffer = buffer
			self.data = buffer.data
    
			update_output (self)
		finally:
		    gtk.gdk.threads_leave()
		gc.collect()
	finally:
	    self.cond.release()
	# got_frame
	return True

def expose_event (widget, event):
    """Draw the selected output's current image, centred in ``widget``.

    Returns True (stop handling) when there is nothing to draw yet.
    """
    # .get() so an unknown selection reaches the None guard instead of
    # raising KeyError.
    dec = decoders.get (cur_out)
    if dec is None or dec.data is None:
        return True

    # Centre the image when the allocation is larger than it.
    _x, _y, width, height = widget.allocation
    x_dest_offset = 0
    y_dest_offset = 0
    if dec.width < width:
        x_dest_offset = (width - dec.width) // 2
    if dec.height < height:
        y_dest_offset = (height - dec.height) // 2

    # Named draw_gc so it doesn't shadow the gc module.
    draw_gc = widget.get_style().fg_gc[gtk.STATE_NORMAL]
    widget.window.draw_rgb_image (draw_gc,
        x_dest_offset, y_dest_offset, dec.width, dec.height,
        gtk.gdk.RGB_DITHER_NONE, dec.data, dec.width * 3)

    return False

def next_frame(widget=None, dir=1):
    """Move every decoder ``dir`` frames from the first decoder's target.

    The gdk lock (held by the caller, see main) is released while decoders
    are repositioned, because got_frame() needs it; it is re-taken after.
    """
    gtk.gdk.threads_leave()
    try:
        all_decoders = list (decoders.values())
        wanted = all_decoders[0].target_frame_no + dir
        for dec in all_decoders:
            dec.go_to_frame (wanted)
    finally:
        gtk.gdk.threads_enter()

def update_output (decoder):
    """Refresh the view after ``decoder`` produced a new image.

    When a source decoder updates and both sources are on the same frame,
    the difference image is recomputed.  The drawing area and status label
    are refreshed only when the output currently shown has changed.
    """
    global draw_area, decoders, cur_out, status_label
    one = decoders[outputs[0]]
    two = decoders[outputs[1]]
    diff = decoders[outputs[2]]
    shown = decoders[cur_out]
    # Previously initialised to True, which redrew on every update.
    forced = False

    # Compare by identity: the difference decoder may have no name.
    if decoder is not diff:
        # Either source image is updated
        if one.cur_frame_no == two.cur_frame_no and diff.cur_frame_no != one.cur_frame_no:
            diff.calc_difference (one, two)
            forced = shown is diff

    if decoder is shown or forced:
        # Size and status come from the output on screen, not the caller.
        draw_area.queue_draw()
        draw_area.set_size_request(shown.width, shown.height)
        status_label.set_text ("Frame %d      %s" % (shown.cur_frame_no, shown.status))

def output_select (combo):
    """Combo-box "changed" handler: show the newly selected output."""
    global cur_out
    cur_out = combo.get_active_text()
    selected = decoders[cur_out]
    update_output (selected)

def main(file):
    """Build the comparison window for ``file`` and run the GTK main loop."""
    global draw_area, cur_out, status_label, decoders

    win = gtk.Window()
    win.set_name("Video Diff")
    win.connect("destroy", lambda w: gtk.main_quit())

    vbox = gtk.VBox()
    vbox.set_spacing (3)
    win.add(vbox)

    cur_out = outputs[0]

    # Drop down box selecting the displayed output
    selector_row = gtk.HBox()
    vbox.pack_start (selector_row, expand=False, fill=False)

    caption = gtk.Label ("Current Output:")
    selector_row.pack_start (caption, expand=False, fill=False)

    combo = gtk.combo_box_new_text ()
    for name in outputs:
        combo.append_text (name)
    combo.set_active (0)
    combo.connect ("changed", output_select)
    selector_row.pack_start (combo, expand=True, fill=True)

    # Video window
    draw_area = gtk.DrawingArea()
    draw_area.set_size_request(200, 200)
    draw_area.connect("expose_event", expose_event)
    vbox.pack_start (draw_area, expand=True, fill=True)

    status_label = gtk.Label("")
    vbox.pack_start (status_label, expand=False, fill=False)

    # Video outputs.  The difference output has no pipeline, but it is
    # given its output name so name checks against "Difference"/cur_out
    # (e.g. in update_output) match it.
    decoders[outputs[0]] = Decoder (draw_area, file, outputs[0])
    decoders[outputs[1]] = Decoder (draw_area, file, outputs[1])
    decoders[outputs[2]] = Decoder (draw_area, None, outputs[2])

    buttons = gtk.HBox()
    button = gtk.Button(stock=gtk.STOCK_QUIT)
    button.connect("clicked", lambda widget, win=win: win.destroy())
    buttons.pack_end(button, expand=False, fill=False)

    for text, step in (("Prev Frame", -1), ("Next Frame", 1)):
        button = gtk.Button(label=text)
        button.connect("clicked", next_frame, step)
        buttons.pack_start(button, expand=False, fill=False)

    vbox.pack_end(buttons, expand=False, fill=False)

    win.show_all()
    gtk.gdk.threads_enter()
    try:
        # Show frame 1, then hand control to GTK.
        next_frame()
        gtk.main()
    finally:
        gtk.gdk.threads_leave()

if __name__ == '__main__':
	if len(sys.argv) < 2:
	    print "Usage: mpeg-diff filename.mpg"
	    sys.exit(1)

	main(sys.argv[1])
