# -------------------------------------------------------------------------
#     This file is part of mMass - the spectrum analysis tool for MS.
#     Copyright (C) 2005-07 Martin Strohalm <mmass@biographics.cz>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 2 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE in the
#     main directory of the program
# -------------------------------------------------------------------------

# Function: Load and parse data from mMass format.

# load libs
import os.path
import xml.dom.minidom
import base64
import zlib
import struct


class mMassDoc:
    """ Get and format data from mMass document.

    Parsed values are collected in the 'data' dictionary:
        docType       - always 'mSD'
        scanID        - initialised to '' and not modified by this class
        date, operator, contact, institution, instrument
                      - 'value' attributes of the matching description tags
        notes         - text of the <notes> element
        peaklist      - sorted list of [mass, intensity, annotation, 0]
        seqtitle      - text of the sequence <title> element
        sequence      - text of the <seq> element (one-letter amino codes)
        modifications - list of [index or amino, formula, name]
        spectrum      - list of [m/z, intensity] pairs

    The same dictionary is reused by every call on one instance, so values
    parsed by an earlier call stay in place until they are overwritten.
    """

    # one-letter codes accepted in <seq> element
    _AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'

    # tags inside <description> whose data are stored in 'value' attribute
    _DESCRIPTION_TAGS = ('date', 'operator', 'contact', 'institution', 'instrument')

    # ----
    def __init__(self):
        self.data = {
                    'docType': 'mSD',
                    'scanID': '',
                    'date': '',
                    'operator': '',
                    'contact': '',
                    'institution': '',
                    'instrument': '',
                    'notes': '',
                    'peaklist': [],
                    'seqtitle': '',
                    'sequence': '',
                    'modifications': [],
                    'spectrum': [],
                    }
    # ----


    # ----
    def getDocument(self, path):
        """ Read and parse all data from document.

        path: file name or file-like object handed to xml.dom.minidom.parse.

        Returns the 'data' dictionary, or False if the document cannot be
        parsed or contains neither <mSD> nor <mMassDoc> element.
        Sections which fail to parse are skipped; their handler results are
        intentionally not checked here (unlike in getElement).
        """

        document = self._parseDocument(path)
        if document is None:
            return False

        # parse each known section if present
        sections = (
            ('description', self.handleDescription),
            ('peaklist', self.handlePeaklist),
            ('sequences', self.handleSequences),
            ('spectrum', self.handleSpectrum),
        )
        for tagName, handler in sections:
            nodes = document.getElementsByTagName(tagName)
            if nodes:
                handler(nodes[0])

        return self.data
    # ----


    # ----
    def getElement(self, name, path):
        """ Read and parse selected elements' data from document.

        name: 'description', 'peaklist', 'sequences' or 'spectrum'.
        path: file name or file-like object handed to xml.dom.minidom.parse.

        Returns the 'data' dictionary, or False if the document cannot be
        parsed, has unknown type or the selected section is invalid.
        If the section is missing or the name is unknown, the 'data'
        dictionary is returned unchanged.
        """

        document = self._parseDocument(path)
        if document is None:
            return False

        handlers = {
            'description': self.handleDescription,
            'peaklist': self.handlePeaklist,
            'sequences': self.handleSequences,
            'spectrum': self.handleSpectrum,
        }

        # get data
        nodes = document.getElementsByTagName(name)
        handler = handlers.get(name)
        if nodes and handler is not None:
            if not handler(nodes[0]):
                return False

        return self.data
    # ----


    # ----
    def _parseDocument(self, path):
        """ Parse XML and check document type; return DOM document or None. """

        # parse XML
        # any parsing or I/O problem means "not a readable document", but
        # KeyboardInterrupt and SystemExit must not be swallowed
        try:
            document = xml.dom.minidom.parse(path)
        except Exception:
            return None

        # check document type
        for rootName in ('mSD', 'mMassDoc'):
            if document.getElementsByTagName(rootName):
                return document

        return None
    # ----


    # ----
    def handleDescription(self, elements):
        """ Get document description from <description> element.

        Always returns True; missing tags leave the current values untouched.
        """

        # get values stored in 'value' attribute
        for tagName in self._DESCRIPTION_TAGS:
            nodes = elements.getElementsByTagName(tagName)
            if nodes:
                self.data[tagName] = nodes[0].getAttribute('value')

        # get notes (stored as element text)
        nodes = elements.getElementsByTagName('notes')
        if nodes:
            self.data['notes'] = self.getNodeText(nodes[0].childNodes)

        return True
    # ----


    # ----
    def handlePeaklist(self, elements):
        """ Get peaks from <peaklist> element.

        Only direct <peak> children are read. Returns True on success or
        False if any peak has non-numeric mass or intensity; in that case
        the stored peaklist is left unchanged.
        """

        # get peaks
        peaklist = []
        for node in elements.childNodes:
            if node.nodeName != 'peak':
                continue

            # check mass and intensity
            try:
                mass = float(node.getAttribute('mass'))
                intens = float(node.getAttribute('intens'))
            except ValueError:
                return False

            # add peak to peaklist
            peaklist.append([mass, intens, node.getAttribute('annots'), 0])

        peaklist.sort()
        self.data['peaklist'] = peaklist

        return True
    # ----


    # ----
    def handleSequences(self, elements):
        """ Get sequence and modifications from <sequences> element.

        Only the first <sequence> element is used. Returns False if there is
        no <sequence> element or the sequence contains a character which is
        not one of the accepted one-letter amino codes. The result of
        modifications parsing is not checked.
        """

        # get sequence element
        nodes = elements.getElementsByTagName('sequence')
        if not nodes:
            return False
        sequence = nodes[0]

        # get title
        nodes = sequence.getElementsByTagName('title')
        if nodes:
            self.data['seqtitle'] = self.getNodeText(nodes[0].childNodes)

        # get sequence
        nodes = sequence.getElementsByTagName('seq')
        if nodes:
            seq = self.getNodeText(nodes[0].childNodes)

            # check sequence
            for amino in seq:
                if amino not in self._AMINO_ACIDS:
                    return False

            self.data['sequence'] = seq

        # get modifications
        nodes = sequence.getElementsByTagName('modifications')
        if nodes:
            self.handleModifications(nodes[0])

        return True
    # ----


    # ----
    def handleModifications(self, elements):
        """ Get modifications from <modifications> element.

        Each direct <modification> child is stored as
        [index, formula, name] for type 'residual' (index is the 'position'
        attribute converted to zero-based) or [amino, formula, name] for
        type 'global'. Formula is the 'gain' attribute followed by
        '-' + 'loss' if loss is given.

        Returns True on success or False if any modification is invalid;
        in that case the stored modifications are left unchanged.
        """

        # get modifications
        modifications = []
        for node in elements.childNodes:
            if node.nodeName != 'modification':
                continue

            # get attributes
            name = node.getAttribute('name')
            modType = node.getAttribute('type')
            gain = node.getAttribute('gain')
            loss = node.getAttribute('loss')

            # check data
            if not name or not modType or (not gain and not loss):
                return False

            # format loss
            if loss:
                loss = '-' + loss
            formula = gain + loss

            # residual modifications
            if modType == 'residual':
                try:
                    # convert position to index
                    index = int(node.getAttribute('position')) - 1
                except (ValueError, TypeError):
                    return False
                modifications.append([index, formula, name])

            # global modifications
            elif modType == 'global':
                amino = node.getAttribute('amino')
                if not amino:
                    return False
                modifications.append([amino, formula, name])

            # bad modType
            else:
                return False

        self.data['modifications'] = modifications

        return True
    # ----


    # ----
    def handleSpectrum(self, elements):
        """ Get spectrum data from <spectrum> element.

        Reads <mzArray> and <intArray> elements and stores the spectrum as
        a list of [m/z, intensity] pairs. Returns False if either array is
        missing, empty, cannot be decoded or the lengths differ.
        """

        mzData = self._readBinaryArray(elements, 'mzArray')
        intData = self._readBinaryArray(elements, 'intArray')

        # check data
        if not mzData or not intData or (len(mzData) != len(intData)):
            return False

        # "zip" mzData and intData
        # real list is built explicitly, since zip/map are lazy in Python 3
        self.data['spectrum'] = [[mz, intens] for mz, intens in zip(mzData, intData)]

        return True
    # ----


    # ----
    def _readBinaryArray(self, elements, tagName):
        """ Decode first 'tagName' element to tuple of floats or None on error. """

        nodes = elements.getElementsByTagName(tagName)
        if not nodes:
            return None
        node = nodes[0]

        # get data
        encoded = self.getNodeText(node.childNodes)
        if not encoded:
            return None

        # decode data
        # bad base64 raises TypeError in Python 2 and ValueError in Python 3
        try:
            raw = base64.b64decode(encoded)
        except (TypeError, ValueError):
            return None

        # decompress data
        if node.getAttribute('compression') == 'gz':
            try:
                raw = zlib.decompress(raw)
            except zlib.error:
                return None

        # get endian (little is default)
        endian = '<'
        if node.getAttribute('endian') == 'big':
            endian = '>'

        # convert from binary format
        return self.convertFromBinary(raw, endian)
    # ----


    # ----
    def convertFromBinary(self, data, endian):
        """ Convert binary data to the list of values.

        data: packed single-precision floats.
        endian: struct byte-order prefix ('<' or '>').

        Returns tuple of floats or None if data cannot be unpacked
        (e.g. the length is not a multiple of the float size).
        """

        try:
            # floor division keeps the count integral under true division
            pointsCount = len(data) // struct.calcsize(endian + 'f')
            return struct.unpack(endian + 'f' * pointsCount, data)
        except (struct.error, TypeError):
            return None
    # ----


    # ----
    def getNodeText(self, nodelist):
        """ Get text from node list.

        Concatenates data of the direct text nodes only; nested elements
        are ignored.
        """

        return ''.join([node.data for node in nodelist
                        if node.nodeType == node.TEXT_NODE])
    # ----
