#!/usr/bin/env python
#
# Copyright 2003,2004 Free Software Foundation, Inc.
# 
# This file is part of GNU Radio
# 
# GNU Radio is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
# 
# GNU Radio is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with GNU Radio; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
# Boston, MA 02111-1307, USA.
# 

from gnuradio import gr, grutil, eng_notation
from gnuradio.wxgui import stdgui
from gnuradio import audio
import wx
import wx.lib.plot as plot
import Numeric
import os
import threading
import struct

# FIXME this should be rewritten to use hierarchical modules (when they're ready)

# ========================================================================
# returns (block, win).
#   block accepts N input streams of float
#   win is a subclass of wxWindow

def make_scope_sink_f (fg, parent, label, input_rate):
    """
    Build a float-input oscilloscope sink plus the window that displays it.

    The sink writes sample frames into the write end of a fresh pipe; the
    window reads them from the read end.  Returns (block, win).  The fg
    argument is accepted but not used here.
    """
    read_end, write_end = os.pipe ()
    sink = gr.oscope_sink_f (input_rate, write_end)
    window_state = win_info (read_end, input_rate, sink, label)
    return (sink, scope_window (window_state, parent))

# ========================================================================
# returns (block, win).
#   block accepts N input streams of gr_complex
#   win is a subclass of wxWindow

def make_scope_sink_c (fg, parent, label, input_rate):
    """
    Build a complex-input oscilloscope sink plus the window that displays it.

    The sink writes sample frames into the write end of a fresh pipe; the
    window reads them from the read end.  Returns (block, win).  The fg
    argument is accepted but not used here.
    """
    read_end, write_end = os.pipe ()
    sink = gr.oscope_sink_c (input_rate, write_end)
    window_state = win_info (read_end, input_rate, sink, label)
    return (sink, scope_window (window_state, parent))

# ========================================================================


# Selectable horizontal scales in seconds per division, ascending from
# 100 ns/div to 50 ms/div; each decade is split into 1 / 2.5 / 5 steps.
time_base_list = [
    1.0e-7, 2.5e-7, 5.0e-7,     # 100ns .. 500ns / div
    1.0e-6, 2.5e-6, 5.0e-6,     #   1us ..   5us / div
    1.0e-5, 2.5e-5, 5.0e-5,     #  10us ..  50us / div
    1.0e-4, 2.5e-4, 5.0e-4,     # 100us .. 500us / div
    1.0e-3, 2.5e-3, 5.0e-3,     #   1ms ..   5ms / div
    1.0e-2, 2.5e-2, 5.0e-2,     #  10ms ..  50ms / div
    ]

# Private event type used by input_watcher to post sample frames (DataEvent)
# to a graph_window.
wxDATA_EVENT = wx.NewEventType()

def EVT_DATA_EVENT(win, func):
    """Connect handler func to wxDATA_EVENT events delivered to win."""
    any_id = -1
    win.Connect(any_id, any_id, wxDATA_EVENT, func)

class DataEvent(wx.PyEvent):
    """
    wx event carrying one frame of oscilloscope data.

    data is whatever the poster supplies; input_watcher passes a list with
    one sample array per channel.
    """

    def __init__(self, data):
        wx.PyEvent.__init__(self)
        self.SetEventType (wxDATA_EVENT)
        self.data = data

    def Clone (self):
        """Return a new DataEvent carrying the same payload."""
        # The previous version built the copy but never returned it, and
        # passed the event id where the payload belongs.
        return self.__class__ (self.data)


class win_info (object):
    """
    State shared between a scope_window, its graph_window and the oscope
    sink block: pipe descriptor, sample rate, title, time-base cursor and
    display options (marker style, X:Y mode, run/stop).
    """
    __slots__ = ['file_descriptor', 'sample_rate', 'scopesink', 'title',
                 'time_scale_cursor', 'marker', 'xy', 'running']

    def __init__ (self, file_descriptor, sample_rate, scopesink, title = "Oscilloscope"):
        # read end of the pipe the sink writes frames into
        self.file_descriptor = file_descriptor
        self.sample_rate = sample_rate
        self.scopesink = scopesink
        self.title = title
        # cursor over time_base_list, stepped back once from its start position
        cursor = grutil.seq_with_cursor (time_base_list)
        cursor.prev ()
        self.time_scale_cursor = cursor
        # display options: marker style, X:Y mode off, running
        self.marker, self.xy, self.running = 'line', False, True

    def get_time_per_div (self):
        """Return the currently selected seconds per division."""
        cursor = self.time_scale_cursor
        return cursor.current ()

    def get_sample_rate (self):
        """Return the input sample rate given at construction."""
        return self.sample_rate

    def get_decimation_rate (self):
        """Return the decimation factor; this implementation never decimates."""
        return 1.0

    def set_marker (self, s):
        """Record the plot marker style ('line', 'plus' or 'dot')."""
        self.marker = s

    def get_marker (self):
        """Return the current plot marker style."""
        return self.marker

        
class input_watcher (threading.Thread):
    """
    Daemon thread that reads sample frames from the oscope sink's pipe and
    posts each complete frame to event_receiver as a DataEvent.

    Frame layout on the pipe: a header of two native ints
    (nchan, nsamples), followed by nchan blocks of nsamples float32 values.
    The thread is started by the constructor and stops at end-of-file or
    when keep_running is cleared.
    """

    def __init__ (self, file_descriptor, event_receiver, **kwds):
        threading.Thread.__init__ (self, **kwds)
        self.setDaemon (1)
        self.file_descriptor = file_descriptor
        self.event_receiver = event_receiver
        self.keep_running = True
        self.start ()

    def _read_exactly (self, nbytes):
        """
        Read exactly nbytes from the pipe, looping over short reads.

        A single os.read on a pipe may return fewer bytes than requested
        (e.g. when a frame is larger than the pipe buffer).  Returns the
        bytes read, or None if end-of-file arrives first.
        """
        data = os.read (self.file_descriptor, nbytes)
        while len (data) < nbytes:
            more = os.read (self.file_descriptor, nbytes - len (data))
            if not more:
                return None             # EOF before the requested bytes arrived
            data = data + more
        return data

    def run (self):
        header_size = 2 * gr.sizeof_int
        while self.keep_running:
            header = self._read_exactly (header_size)
            if header is None:
                break                   # writer closed the pipe

            nchan, nsamples = struct.unpack ('ii', header)
            if nchan <= 0 or nsamples <= 0:
                continue                # no payload follows; nothing to draw

            nbytes = gr.sizeof_float * nsamples
            records = []
            for ch in range (nchan):
                s = self._read_exactly (nbytes)
                if s is None:
                    break
                records.append (Numeric.fromstring (s, Numeric.Float32))

            if len (records) != nchan:
                break                   # EOF mid-frame: drop the partial frame

            wx.PostEvent (self.event_receiver, DataEvent (records))

        self.keep_running = False
    

class scope_window (wx.Panel):
    """
    Oscilloscope panel: a graph_window plot area stacked above a row of
    controls (time base, trigger channel/mode, 50% level, run/stop,
    marker format, X:t vs X:Y display).

    info is the win_info shared with the graph; the control handlers write
    to it, and to the trigger settings of info.scopesink, directly.
    """

    _marker_choices = ['line', 'plus', 'dot']

    def __init__ (self, info, parent, id = -1,
                  pos = wx.DefaultPosition, size = wx.DefaultSize, name = ""):
        # id, pos and size are forwarded to wx.Panel (they were previously
        # ignored).  name is accepted for interface compatibility only.
        wx.Panel.__init__ (self, parent, id, pos, size)
        self.info = info

        vbox = wx.BoxSizer (wx.VERTICAL)

        self.graph = graph_window (info, self, -1)

        vbox.Add (self.graph, 1, wx.EXPAND)
        vbox.Add (self.make_control_box (), 0, wx.EXPAND)

        self.sizer = vbox
        self.SetSizer (self.sizer)
        self.SetAutoLayout (True)
        self.sizer.Fit (self)

    def make_control_box (self):
        """Create the control widgets, bind their handlers and return the sizer."""
        ctrlbox = wx.BoxSizer (wx.HORIZONTAL)

        tb_left = wx.Button (self, 1001, "<")
        tb_left.SetToolTipString ("Increase time base")
        wx.EVT_BUTTON (self, 1001, self.incr_timebase)

        tb_right = wx.Button (self, 1000, ">")
        tb_right.SetToolTipString ("Decrease time base")
        wx.EVT_BUTTON (self, 1000, self.decr_timebase)

        # This label handles no events.  It previously reused id 1002, which
        # belongs to marker_choice (the EVT_CHOICE binding below is keyed on it).
        self.time_base_label = wx.StaticText (self, -1, "")
        self.update_timebase_label ()

        ctrlbox.Add ((5,0) ,0)
        ctrlbox.Add (wx.StaticText (self, -1, "Horiz Scale: "), 0, wx.ALIGN_CENTER)
        ctrlbox.Add (tb_left, 0, wx.EXPAND)
        ctrlbox.Add (tb_right, 0, wx.EXPAND)
        ctrlbox.Add (self.time_base_label, 0, wx.ALIGN_CENTER)

        ctrlbox.Add ((10,0) ,1)            # stretchy space

        ctrlbox.Add (wx.StaticText (self, -1, "Trig: "), 0, wx.ALIGN_CENTER)
        self.trig_chan_choice = wx.Choice (self, 1004,
                                           choices = ['Ch1', 'Ch2', 'Ch3', 'Ch4'])
        self.trig_chan_choice.SetToolTipString ("Select channel for trigger")
        wx.EVT_CHOICE (self, 1004, self.trig_chan_choice_event)
        ctrlbox.Add (self.trig_chan_choice, 0, wx.ALIGN_CENTER)

        self.trig_mode_choice = wx.Choice (self, 1005,
                                           choices = ['Pos', 'Neg', 'Auto'])
        self.trig_mode_choice.SetToolTipString ("Select trigger slope or Auto (untriggered roll)")
        wx.EVT_CHOICE (self, 1005, self.trig_mode_choice_event)
        ctrlbox.Add (self.trig_mode_choice, 0, wx.ALIGN_CENTER)

        trig_level50 = wx.Button (self, 1006, "50%")
        trig_level50.SetToolTipString ("Set trigger level to 50%")
        wx.EVT_BUTTON (self, 1006, self.set_trig_level50)
        ctrlbox.Add (trig_level50, 0, wx.EXPAND)

        run_stop = wx.Button (self, 1007, "Run/Stop")
        run_stop.SetToolTipString ("Toggle Run/Stop mode")
        wx.EVT_BUTTON (self, 1007, self.run_stop)
        ctrlbox.Add (run_stop, 0, wx.EXPAND)

        ctrlbox.Add ((10, 0) ,1)            # stretchy space

        ctrlbox.Add (wx.StaticText (self, -1, "Fmt: "), 0, wx.ALIGN_CENTER)
        self.marker_choice = wx.Choice (self, 1002, choices = self._marker_choices)
        self.marker_choice.SetToolTipString ("Select plotting with lines, pluses or dots")
        wx.EVT_CHOICE (self, 1002, self.marker_choice_event)
        ctrlbox.Add (self.marker_choice, 0, wx.ALIGN_CENTER)
        # Show the current marker instead of an empty selection.
        self.set_marker (self.info.get_marker ())

        self.xy_choice = wx.Choice (self, 1003, choices = ['X:t', 'X:Y'])
        self.xy_choice.SetToolTipString ("Select X vs time or X vs Y display")
        wx.EVT_CHOICE (self, 1003, self.xy_choice_event)
        ctrlbox.Add (self.xy_choice, 0, wx.ALIGN_CENTER)
        # Show the current display mode instead of an empty selection.
        if self.info.xy:
            self.xy_choice.SetSelection (1)
        else:
            self.xy_choice.SetSelection (0)

        return ctrlbox

    def update_timebase_label (self):
        """Refresh the label from info's current seconds-per-division."""
        time_per_div = self.info.get_time_per_div ()
        s = ' ' + eng_notation.num_to_str (time_per_div) + 's/div'
        self.time_base_label.SetLabel (s)

    def decr_timebase (self, evt):
        """Button handler: step the time-base cursor back and relabel."""
        self.info.time_scale_cursor.prev ()
        self.update_timebase_label ()

    def incr_timebase (self, evt):
        """Button handler: step the time-base cursor forward and relabel."""
        self.info.time_scale_cursor.next ()
        self.update_timebase_label ()

    def marker_choice_event (self, evt):
        """Choice handler: apply the selected marker style."""
        s = evt.GetString ()
        self.set_marker (s)

    def set_marker (self, s):
        """Store marker style s in info and select it in the Fmt choice."""
        self.info.set_marker (s)        # set info for drawing routines
        i = self.marker_choice.FindString (s)
        assert i >= 0, "Hmmm, set_marker problem"
        self.marker_choice.SetSelection (i)

    def set_format_line (self):
        self.set_marker ('line')

    def set_format_dot (self):
        self.set_marker ('dot')

    def set_format_plus (self):
        self.set_marker ('plus')

    def xy_choice_event (self, evt):
        """Choice handler: switch between X:t and X:Y display."""
        s = evt.GetString ()
        self.info.xy = s == 'X:Y'

    def trig_chan_choice_event (self, evt):
        """Choice handler: 'ChN' selects zero-based trigger channel N-1."""
        s = evt.GetString ()
        ch = int (s[-1]) - 1
        self.info.scopesink.set_trigger_channel (ch)

    def trig_mode_choice_event (self, evt):
        """Choice handler: map 'Pos' / 'Neg' / 'Auto' to a sink trigger mode."""
        sink = self.info.scopesink
        s = evt.GetString ()
        if s == 'Pos':
            sink.set_trigger_mode (gr.gr_TRIG_POS_SLOPE)
        elif s == 'Neg':
            sink.set_trigger_mode (gr.gr_TRIG_NEG_SLOPE)
        elif s == 'Auto':
            sink.set_trigger_mode (gr.gr_TRIG_AUTO)
        else:
            assert 0, "Bad trig_mode_choice string"

    def set_trig_level50 (self, evt):
        """Button handler: ask the sink to pick its trigger level automatically."""
        self.info.scopesink.set_trigger_level_auto ()

    def run_stop (self, evt):
        """Button handler: toggle info.running (graph skips frames when stopped)."""
        self.info.running = not self.info.running
        

class graph_window (plot.PlotCanvas):
    """
    Plot canvas for the oscilloscope.

    Starts an input_watcher on info.file_descriptor and redraws on every
    DataEvent it receives.  It draws either amplitude vs. time (one trace per
    channel) or, when info.xy is set, channel 0 vs. channel 1.  Axis limits
    that the time base does not fix are exponentially smoothed across frames.
    """

    channel_colors = ['GREEN', 'RED', 'BLUE',
                      'CYAN', 'MAGENTA', 'YELLOW']

    # Weight given to the newest frame when smoothing axis limits.
    _range_alpha = 1.0 / 25

    def __init__ (self, info, parent, id = -1,
                  pos = wx.DefaultPosition, size = wx.DefaultSize,
                  style = wx.DEFAULT_FRAME_STYLE, name = ""):
        plot.PlotCanvas.__init__ (self, parent, id, pos, size, style, name)

        self.SetEnableGrid (True)
        self.SetEnableZoom (True)

        self.info = info
        self.y_range = None             # smoothed y interval; None until the first draw
        self.x_range = None             # smoothed x interval, updated in X:Y mode only
        self.avg_y_min = None           # running averages behind y_range / x_range;
        self.avg_y_max = None           # None means "not seeded yet"
        self.avg_x_min = None
        self.avg_x_max = None

        EVT_DATA_EVENT (self, self.format_data)

        self.input_watcher = input_watcher (info.file_descriptor,
                                            self)

    def channel_color (self, ch):
        """Return the trace colour for channel ch, cycling through channel_colors."""
        return self.channel_colors[ch % len(self.channel_colors)]

    def _make_trace (self, points, ch, marker):
        """Build a PolyLine for marker 'line', otherwise a PolyMarker."""
        colour = self.channel_color (ch)
        if marker == 'line':
            return plot.PolyLine (points, colour=colour)
        return plot.PolyMarker (points, marker=marker, colour=colour)

    def format_data (self, evt):
        """DataEvent handler: draw the frame unless stopped; defer to X:Y mode."""
        if not self.info.running:
            return

        if self.info.xy:
            self.format_xy_data (evt)
            return

        info = self.info
        records = evt.data
        # An empty frame has nothing to plot (and records[0] would fail).
        if not records or len (records[0]) == 0:
            return
        npoints = len (records[0])

        Ts = 1.0 / (info.get_sample_rate () / info.get_decimation_rate ())
        # Time axis centred on the middle of the record; // keeps the
        # integer division explicit (the results are used as slice bounds).
        x_vals = Ts * Numeric.arrayrange (-npoints // 2, npoints // 2)

        # Preliminary clipping based on the time axis, instead of in the
        # graphics code: keep the middle n samples of the 10-division window.
        time_per_div = info.get_time_per_div ()
        n = int (time_per_div * 10 / Ts + 0.5)
        n = n & ~0x1                    # make even
        n = max (2, min (n, npoints))
        lb = npoints // 2 - n // 2
        ub = npoints // 2 + n // 2
        x_window = x_vals[lb:ub]

        marker = info.get_marker ()
        objects = []
        for ch, r in enumerate (records):
            points = list (zip (x_window, r[lb:ub]))
            objects.append (self._make_trace (points, ch, marker))

        graphics = plot.PlotGraphics (objects,
                                      title=info.title,
                                      xLabel = '', yLabel = '')

        x_range = (-5.0 * time_per_div, 5.0 * time_per_div)
        self.Draw (graphics, xAxis=x_range, yAxis=self.y_range)
        self.update_y_range ()

    def format_xy_data (self, evt):
        """Draw channel 0 (x, labelled I) against channel 1 (y, labelled Q)."""
        records = evt.data
        # X:Y mode needs two channels; check before indexing records.
        if len (records) < 2:
            return

        points = list (zip (records[0], records[1]))
        trace = self._make_trace (points, 0, self.info.get_marker ())

        graphics = plot.PlotGraphics ([trace],
                                      title=self.info.title,
                                      xLabel = 'I', yLabel = 'Q')

        self.Draw (graphics, xAxis=self.x_range, yAxis=self.y_range)
        self.update_y_range ()
        self.update_x_range ()

    def _smooth (self, avg_min, avg_max, new_min, new_max):
        """Blend new extents into running averages; the first call seeds them."""
        # Test against None, not truthiness: 0.0 is a legitimate average.
        if avg_min is None:
            return new_min, new_max
        alpha = self._range_alpha
        return (new_min * alpha + avg_min * (1 - alpha),
                new_max * alpha + avg_max * (1 - alpha))

    def update_y_range (self):
        """Fold the last drawing's y extent into the smoothed y_range."""
        p1, p2 = self.last_draw[0].boundingBox ()     # min, max points of graphics
        self.avg_y_min, self.avg_y_max = self._smooth (self.avg_y_min, self.avg_y_max,
                                                       p1[1], p2[1])
        self.y_range = self._axisInterval ('auto', self.avg_y_min, self.avg_y_max)

    def update_x_range (self):
        """Fold the last drawing's x extent into the smoothed x_range."""
        p1, p2 = self.last_draw[0].boundingBox ()     # min, max points of graphics
        self.avg_x_min, self.avg_x_max = self._smooth (self.avg_x_min, self.avg_x_max,
                                                       p1[0], p2[0])
        self.x_range = self._axisInterval ('auto', self.avg_x_min, self.avg_x_max)


# ----------------------------------------------------------------
# Stand-alone test application
# ----------------------------------------------------------------

class test_app_flow_graph (stdgui.gui_flow_graph):
    """Demo flow graph: sound-card input feeding channel 0 of a float scope."""

    def __init__(self, frame, panel, vbox, argv):
        stdgui.gui_flow_graph.__init__ (self, frame, panel, vbox, argv)

        # build our flow graph
        # (the src0 line below was tab-indented among space-indented
        # lines, which Python 3 rejects)
        input_rate = 44100 #1e6
        src0 = audio.source (input_rate)
        # Alternatives: a synthetic source instead of the sound card, and a
        # second channel (create src1 and connect it to (block, 1)):
        #src0 = gr.sig_source_f (input_rate, gr.GR_SIN_WAVE, 25.1e3, 1e3)
        #src1 = gr.sig_source_f (input_rate, gr.GR_COS_WAVE, 25.1e3, 1e3)
        block, scope_win = make_scope_sink_f (self, panel, "Secret Data", input_rate)
        self.connect (src0, (block, 0))
        #self.connect (src1, (block, 1))
        vbox.Add (scope_win, 1, wx.EXPAND)


def main ():
    """Run the stand-alone oscilloscope test application."""
    stdgui.stdapp (test_app_flow_graph, "O'Scope Test App").MainLoop ()

if __name__ == '__main__':
    main ()

# ----------------------------------------------------------------

