import os, time, threading
import psutil
from collections import namedtuple
import system_info_aix
import whatap.util.process_util as process_util
import whatap.util.thread_util as thread_util
import whatap.util.logging_util as logging_util
import whatap.agent.conf.configure as conf
from StringIO import StringIO
from multiprocessing import Pool, TimeoutError as _TimeoutError

from . import _DiskPerf

# Filesystem types (psutil ``fstype`` values) that updateDiskPerf() skips
# entirely; procfs is the /proc pseudo-filesystem, not a real disk.
filesystem_blacklist = ["procfs"]
def iterateLines(lines, skip=0):
    for i, l in enumerate(lines.split('\n')):
        l = l.strip()
        if i < skip or not l:
            continue
        
        yield l.split()

# Cached mount point -> [device id, ...] map built by parseFromMountPointEx();
# None until the first successful build.
_mountPointCacheEx = None
def parseFromMountPointEx(mountPoint):
    """Return the physical volumes (device ids from ``lspv``) backing *mountPoint*.

    The map is derived from ``lspv`` (device id -> volume group) and
    ``lsvg -l <vg>`` (logical volumes and their mount points).  Every device
    of a volume group is attributed to each mount point in that group.  The
    map is cached in ``_mountPointCacheEx`` and rebuilt from scratch whenever
    a mount point that is not in the cache is requested.

    :param mountPoint: mount point path to resolve.
    :return: list of device ids, or None if the mount point is unknown,
        ``lspv`` produced no output, or the lookup raised.
    """
    global _mountPointCacheEx

    if _mountPointCacheEx is not None and mountPoint in _mountPointCacheEx:
        return _mountPointCacheEx[mountPoint]

    try:
        lspv = process_util.executeCommandShellWithTimeout("lspv", timeout = 1)
        if lspv == "":
            return None

        # Volume group name -> its device ids (lspv order), plus the order in
        # which volume groups were first seen (Python 2 dicts are unordered).
        vgDevices = {}
        vgOrder = []
        for l in lspv.split('\n'):
            l = l.strip()
            fields = l.split()
            if len(fields) != 4:
                logging_util.info("lspv skip line: ", l)
                continue
            deviceId, _serial, vgname, _status = fields
            if 'none' == vgname.lower():
                continue
            if vgname not in vgDevices:
                vgDevices[vgname] = []
                vgOrder.append(vgname)
            if deviceId not in vgDevices[vgname]:
                vgDevices[vgname].append(deviceId)

        # Build into a fresh dict: rebuilding in place used to append the same
        # device ids again on every rebuild, and a failure part-way through
        # would leave a half-built cache behind.
        cache = {}
        for vgname in vgOrder:
            # One `lsvg -l` per volume group (its output does not depend on
            # which device of the group we came from).
            lvlist = process_util.executeCommandShellWithTimeout('lsvg -l %s'%(vgname), timeout = 1)
            if lvlist == "":
                continue

            # The first two lines of `lsvg -l` output are headers.
            for tokens in iterateLines(lvlist, skip=2):
                if len(tokens) != 7:
                    logging_util.info("lsvg skip line: ", " ".join(tokens))
                    continue
                mountpoint = tokens[6]
                if 'N/A' == mountpoint:
                    continue
                devices = cache.setdefault(mountpoint, [])
                for deviceId in vgDevices[vgname]:
                    if deviceId not in devices:
                        devices.append(deviceId)

        _mountPointCacheEx = cache
        logging_util.info("Init Disk LV Info : {}".format(str(_mountPointCacheEx)))
    except Exception as e:
        logging_util.error("Disk Mount Pointer Check Error:", e)
        return None

    return _mountPointCacheEx.get(mountPoint)

# Aggregated per-mount-point I/O counters returned by _DiskIoPerfCache.get();
# populateIoPerf() turns the ``time`` field into ioPercent.
sdiskio = namedtuple(
    'sdiskio',
    'read_count write_count read_bytes write_bytes read_time write_time time'
)

class _DiskIoPerfCache:
    """Timestamped snapshot of per-device I/O counters, queried by mount point.

    ``iocounters`` is the device id -> counter record mapping this module
    obtains from ``psutil.disk_io_counters(perdisk=True)``.  ``timestamp`` is
    the ``time.time()`` value taken when the snapshot object was created.
    Mount points are resolved to device ids through parseFromMountPointEx().
    """

    def __init__(self, iocounters=None):
        self.iocounters = iocounters
        self.timestamp = time.time()

    def _devices(self, mountPoint):
        """Return the distinct device ids backing *mountPoint*, in order (may be empty)."""
        unique = []
        for deviceId in parseFromMountPointEx(mountPoint) or []:
            # A device listed twice must not be counted twice.
            if deviceId not in unique:
                unique.append(deviceId)
        return unique

    def contains(self, mountPoint):
        """Return True if at least one device backing *mountPoint* is in this snapshot."""
        counters = self.iocounters or {}
        return any(deviceId in counters for deviceId in self._devices(mountPoint))

    def get(self, mountPoint):
        """Aggregate the counters of the devices backing *mountPoint* into an sdiskio.

        Count, byte and read/write time fields are summed over the devices
        present in this snapshot.  ``time`` is averaged over those devices.
        Devices missing from the snapshot are skipped; previously they
        raised AttributeError on ``None``.  With no matching device every
        field is 0.
        """
        counters = self.iocounters or {}
        found = [c for c in (counters.get(d) for d in self._devices(mountPoint)) if c is not None]

        iotime = 0
        if found:
            # float() avoids Python 2 integer division truncating the average.
            iotime = float(sum(c.time for c in found)) / len(found)

        return sdiskio(time=iotime,
                       read_count=sum(c.read_count for c in found),
                       write_count=sum(c.write_count for c in found),
                       read_bytes=sum(c.read_bytes for c in found),
                       write_bytes=sum(c.write_bytes for c in found),
                       read_time=sum(c.read_time for c in found),
                       write_time=sum(c.write_time for c in found))

# Baseline counter snapshot for the next populateIoPerf() call (taken at import).
iocountersLastTime = _DiskIoPerfCache(psutil.disk_io_counters(perdisk=True))

def populateIoPerf(diskQuotas, init):
    """Attach I/O rates to each disk entry and return the entries as a new list.

    Some entries have mount points backed by devices present in both the
    current and the previous counter snapshot.  For each of those, this sets:

    - readIops/writeIops: operations per second
    - readBps/writeBps: bytes per second
    - ioPercent: busy share of the interval, capped at 100

    The current snapshot then becomes the baseline for the next call.

    :param diskQuotas: _DiskPerf objects with a ``mountPoint`` attribute;
        rate attributes are set on them in place.
    :param init: when true, log mount points that have no I/O counters.
    :return: a new list holding the same objects as *diskQuotas*.
    """
    global iocountersLastTime

    iocounters = _DiskIoPerfCache(psutil.disk_io_counters(perdisk=True))
    previous = iocountersLastTime
    # The interval is the same for every mount point: compute it once.
    timediff = float(iocounters.timestamp - previous.timestamp)

    diskQuotasThisTime = list(diskQuotas)
    for diskQuota in diskQuotasThisTime:
        mountPoint = diskQuota.mountPoint

        if not iocounters.contains(mountPoint):
            if init:
                logging_util.info("Init Disk IO Info : {} not found iocounter".format(mountPoint))
            continue

        # Without a positive interval (clock did not advance or stepped back)
        # rates would divide by zero or go negative; leave them untouched.
        if timediff <= 0 or not previous.contains(mountPoint):
            continue

        current = iocounters.get(mountPoint)
        last = previous.get(mountPoint)

        diskQuota.readIops = float(current.read_count - last.read_count) / timediff
        diskQuota.writeIops = float(current.write_count - last.write_count) / timediff
        diskQuota.readBps = float(current.read_bytes - last.read_bytes) / timediff
        diskQuota.writeBps = float(current.write_bytes - last.write_bytes) / timediff
        # The formula treats the `time` counter as milliseconds (hence * 1000).
        # Cap at 100 like GetDiskIO does, since sampling jitter can overshoot.
        busyPercent = float(100) * float(current.time - last.time) / (timediff * 1000)
        diskQuota.ioPercent = min(busyPercent, float(100))

    iocountersLastTime = iocounters

    return diskQuotasThisTime


# Per-mount-point "statvfs in flight" flags: True while a worker thread started
# by runStatvfsThread() has not yet returned from os.statvfs().
mountpointStatus = {}
# Serializes the check-and-set of mountpointStatus in runStatvfsThread().
lock = threading.Lock()

class StatVfsError(Exception):
    """Raised when os.statvfs() fails or yields no result."""

class StatVfsTimeout(Exception):
    """Raised when os.statvfs() is still running from an earlier call or does not finish in time."""

def statvfs(mountpoint, result):
    """Thread worker: run os.statvfs(mountpoint) and report through *result*.

    On success the stat result is stored under ``result['statvfs']``.  If an
    Exception is raised, it is stored under ``result['error']``.  Either way
    the mount point's in-flight flag in ``mountpointStatus`` is cleared, so
    later callers stop treating it as hung.
    """
    try:
        outcome = os.statvfs(mountpoint)
    except Exception as err:
        result['error'] = err
    else:
        result['statvfs'] = outcome
    finally:
        mountpointStatus[mountpoint] = False


def runStatvfsThread(mountpoint, timeout=1):
    """Call os.statvfs(mountpoint) on a daemon thread and wait for its result.

    Only one statvfs call per mount point may be in flight.  If an earlier
    call is still running (tracked in ``mountpointStatus``), this raises at
    once rather than starting another thread that could block too.

    :param mountpoint: path passed to os.statvfs().
    :param timeout: seconds to wait for the worker thread (default 1, the
        previously hard-coded value).
    :return: the os.statvfs() result.
    :raises StatVfsTimeout: an earlier call is still running, or this call did
        not finish within *timeout*.  In the latter case the flag stays set
        until the worker eventually returns.
    :raises StatVfsError: os.statvfs() raised, or no result was produced.
    """
    result = {}
    with lock:
        if mountpointStatus.get(mountpoint):
            raise StatVfsTimeout("statvfs hang : {}".format(mountpoint))
        mountpointStatus[mountpoint] = True

    th = threading.Thread(target=statvfs, args=(mountpoint, result))
    th.daemon = True
    try:
        th.start()
    except Exception:
        # The worker never ran, so nothing would ever clear the flag; without
        # this the mount point would be reported as hung forever.
        with lock:
            mountpointStatus[mountpoint] = False
        raise
    th.join(timeout)

    if th.is_alive():
        raise StatVfsTimeout("statvfs timeout: {}".format(mountpoint))
    if 'error' in result:
        raise StatVfsError("{}".format(result['error']))
    if 'statvfs' in result:
        return result['statvfs']

    raise StatVfsError("statvfs unknown error")


def updateDiskPerf(init=False):
    """Refresh the module-level ``lastDiskPerf`` snapshot.

    Each partition from ``psutil.disk_partitions(all=False)`` is processed
    unless its filesystem type is in ``filesystem_blacklist``.  A
    ``_DiskPerf`` is built with space and inode usage from statvfs, then
    populateIoPerf() attaches I/O rates.  Partitions whose statvfs fails or
    times out are logged and left out.

    :param init: when true, log every partition found (first-run diagnostics).
    """
    global lastDiskPerf

    diskQuotas = []
    for part in psutil.disk_partitions(all=False):
        if init:
            logging_util.info("Init Disk Partiton Info : device - {}, mountpoint - {}, fstype - {}".format(part.device, part.mountpoint, part.fstype))
        if part.fstype in filesystem_blacklist:
            continue
        p = _DiskPerf()
        p.mountPoint = part.mountpoint
        if not p.mountPoint:
            # No mount point reported: key the entry by its device instead.
            p.mountPoint = part.device
        p.deviceId = part.device
        p.fileSystem = part.fstype

        try:
            # statvfs runs on a helper thread with a timeout so a stuck
            # filesystem cannot block the whole collection.
            f_bsize, f_frsize, f_blocks, f_bfree, f_bavail, f_files, f_ffree, f_favail, f_flag, f_namemax = runStatvfsThread(p.mountPoint)
        except Exception as e:
            logging_util.error("Disk statvfs error : {} - ".format(p.mountPoint), e)
            continue

        usedBlocks = f_blocks - f_bfree
        p.totalSpace = int(f_frsize * f_blocks)
        p.usedSpace = int(f_frsize * usedBlocks)
        # A filesystem reporting zero blocks or zero inodes used to raise an
        # uncaught ZeroDivisionError here, aborting the refresh for every
        # partition.  Report 0% for it instead.
        p.usedPercent = 100.0 * usedBlocks / f_blocks if f_blocks else 0.0
        p.freeSpace = p.totalSpace - p.usedSpace
        p.freePercent = 100.0 * f_bfree / f_blocks if f_blocks else 0.0
        p.blksize = int(f_frsize)
        p.inodePercent = 100.0 * (f_files - f_ffree) / f_files if f_files else 0.0

        diskQuotas.append(p)

    lastDiskPerf = populateIoPerf(diskQuotas, init)

# Latest list of _DiskPerf entries written by updateDiskPerf(); None until then.
lastDiskPerf = None
# True until getDisk() has performed its one-time start-up.
first = True
def getDisk():
    """Return the most recent per-mount-point disk snapshot.

    The first call does a one-time start-up and returns None:

    - runs an initial updateDiskPerf(init=True) collection;
    - hands updateDiskPerf to thread_util.asyncLoop with the configured
      DiskstatInterval (presumably a periodic background refresh; confirm
      against thread_util).

    Later calls return ``lastDiskPerf``, which is None if no collection has
    completed yet.
    """
    global first, lastDiskPerf
    if first:
        # Kept in separate try blocks: a failing initial collection must not
        # also prevent the refresh loop from being started, or disk metrics
        # would never be collected for the life of the process.
        try:
            updateDiskPerf(init=True)
        except Exception as e:
            logging_util.error("disk perf aix error:", e)
        try:
            thread_util.asyncLoop(updateDiskPerf, conf.GetConfig().DiskstatInterval)
        except Exception as e:
            logging_util.error("disk perf aix error:", e)
        first = False
        return None
    return lastDiskPerf



# [counters, timestamp] baseline from the previous GetDiskTotalIO() call.
oldIOTotal = None
def GetDiskTotalIO():
    """Return system-wide disk throughput and IOPS since the previous call.

    Uses ``psutil.disk_io_counters()``, i.e. all disks combined.  Returns None
    in these cases:

    - psutil reports no disk counters;
    - this is the first call (only a baseline is stored);
    - the clock has not advanced past the baseline (the baseline is reset to
      now so the next call measures a valid interval).

    :return: a _DiskPerf with readBps/writeBps (bytes per second) and
        readIops/writeIops (operations per second), or None.
    """
    global oldIOTotal

    disk = psutil.disk_io_counters()
    now = time.time()
    if disk is None:
        # psutil returns None when there are no disks to report.
        return None

    if not oldIOTotal:
        oldIOTotal = [disk, now]
        return None

    oldDisk, oldTime = oldIOTotal
    diffTime = float(now - oldTime)
    if diffTime <= 0:
        # Zero interval would divide by zero; a negative one gives negative rates.
        oldIOTotal = [disk, now]
        return None

    ret = _DiskPerf()

    ret.readBps = float(disk.read_bytes - oldDisk.read_bytes) / diffTime
    ret.writeBps = float(disk.write_bytes - oldDisk.write_bytes) / diffTime
    ret.readIops = float(disk.read_count - oldDisk.read_count) / diffTime
    ret.writeIops = float(disk.write_count - oldDisk.write_count) / diffTime

    oldIOTotal = [disk, now]
    return ret

# [per-device counters, timestamp] baseline from the previous GetDiskIO() call.
oldIO = None
def GetDiskIO():
    """Return per-device disk I/O rates since the previous call.

    Uses ``psutil.disk_io_counters(perdisk=True)``.  The first call only
    stores a baseline and returns None.  A call where the clock has not
    advanced past the baseline also returns None and resets the baseline to
    now.  Devices absent from the previous sample, and devices whose name
    starts with "cd", are skipped.

    Each returned _DiskPerf carries:

    - deviceId;
    - readBps/writeBps: bytes per second;
    - readIops/writeIops: operations per second;
    - readTime/writeTime: average time per operation, in the units of the
      counters' read_time/write_time, or 0 when there were no operations;
    - ioPercent: capped at 100.

    :return: list of _DiskPerf, or None.
    """
    global oldIO

    disk = psutil.disk_io_counters(perdisk=True)
    now = time.time()

    if not oldIO:
        oldIO = [disk, now]
        return None

    oldDisk, oldTime = oldIO
    diffTime = float(now - oldTime)
    if diffTime <= 0:
        # Zero interval would divide by zero; a negative one gives negative rates.
        oldIO = [disk, now]
        return None

    retArr = []
    for k, v in disk.items():
        # "cd*" devices are presumably CD-ROM drives (e.g. cd0) -- confirm.
        if k not in oldDisk or k.startswith("cd"):
            continue
        oldv = oldDisk[k]

        ret = _DiskPerf()
        ret.deviceId = k
        ret.readBps = float(v.read_bytes - oldv.read_bytes) / diffTime
        ret.writeBps = float(v.write_bytes - oldv.write_bytes) / diffTime
        readIoCount = float(v.read_count - oldv.read_count)
        ret.readIops = readIoCount / diffTime
        writeIoCount = float(v.write_count - oldv.write_count)
        ret.writeIops = writeIoCount / diffTime

        readTime = float(v.read_time - oldv.read_time)
        writeTime = float(v.write_time - oldv.write_time)
        # Average service time per operation; 0 when the device was idle.
        ret.readTime = readTime / readIoCount if readIoCount else 0
        ret.writeTime = writeTime / writeIoCount if writeIoCount else 0

        # The formula treats the `time` counter as milliseconds (hence * 1000).
        busyTime = float(v.time - oldv.time)
        ret.ioPercent = min(100 * busyTime / (diffTime * 1000), float(100))

        retArr.append(ret)

    oldIO = [disk, now]
    return retArr

        
        

def test():
    """Manual smoke test: print each per-mount-point disk entry.

    The first getDisk() call only performs start-up and returns None, which
    previously made ``for a in getDisk()`` raise TypeError.  That start-up
    runs an initial collection synchronously, so a second call returns its
    snapshot; ``or []`` covers the case where that collection failed.
    """
    getDisk()
    for a in getDisk() or []:
        print(a)

if __name__ == '__main__':
    test()
