#!/usr/bin/python
# licensed under the GPL, by Adrian Likins <alikins@redhat.com>
#
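#  usage:
#           ./graphgen.py > foo.dot
#           dot -Tps -o foo.ps foo.dot
#
#   --dbpath PATH         path to a rpmdb
#   --color               enable color output
#   --packages            read the names of the packages to graph from stdin
#   --basepackages        read the names of the base packages from stdin
#   --skippackages LIST   comma separated list of packages to leave out
#   --subgraphs LIST      comma separated list of packages whose dependents
#                         are clustered together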

import rpm
import sys
import getopt
import string
import os

def showHelp():
    """Print the command line usage of this script to stdout."""
    helpstring = """
usage: graphgen.py [options] > foo.dot

A .dot format file for use with graphviz will be
printed to stdout.

To use graphviz to generate a ps file of the
graph:

    dot -Tps -o foo.ps foo.dot

--help              show this help and exit
--dbpath PATH       path to a rpmdb (default: /var/lib/rpm)
--color             enable color output
--skippackages LIST comma separated list of packages to skip
--packages          read the list of packages to graph from stdin
--basepackages      read the list of base packages from stdin; only
                    dependencies on base packages are drawn
--subgraphs LIST    comma separated list of packages; the packages that
                    depend on them are clustered together

--packages and --basepackages both read stdin, so use only one of them.
"""
    # Single-argument print(): identical output on Python 2 and Python 3.
    print(helpstring)

def listPackages(db, packages):
    """Return the headers of the packages stored in an rpm database.

    Two database handle styles are supported: one exposing
    firstkey()/nextkey(key) plus item access, and one exposing match(0),
    whose iterator's next() returns headers until it returns a false value.

    db       -- open rpm database handle (see getDb())
    packages -- optional list of package names; when non-empty, only
                headers whose 'name' is in it are returned

    Returns the matching headers in database order.
    """
    def _advance(iterator):
        # The loop below stops on a false header; also stop cleanly if the
        # iterator signals exhaustion with StopIteration instead.
        try:
            return iterator.next()
        except StopIteration:
            return None

    headerList = []
    keyed = hasattr(db, "firstkey")
    if keyed:
        key = db.firstkey()
        # An empty database has no first key; don't index it with None.
        if key is not None:
            h = db[key]
        else:
            h = None
    else:
        iterator = db.match(0)
        h = _advance(iterator)

    while h:
        if not packages or h['name'] in packages:
            headerList.append(h)
        if keyed:
            key = db.nextkey(key)
            if key:
                h = db[key]
            else:
                h = None
        else:
            h = _advance(iterator)

    return headerList

def getAllPackages(db):
    """Return the names of every package in the rpm database *db*.

    Names are returned in database order. Nothing is written to stdout,
    so this can be called while the .dot graph is being printed there.
    """
    return [header['name'] for header in listPackages(db, None)]

def findDepLocal(db, dep):
    """Find the installed package that satisfies a single requirement.

    db  -- rpm database handle providing findbyfile(), findbyname() and
           findbyprovides() (each returning a sequence of indices) plus
           item access by those indices
    dep -- requirement string, as found in a header's 'requirename'

    Returns (header, type) for the first match, where type is
        0 -- dep is a path (starts with '/') found via findbyfile()
        1 -- dep matched a package name via findbyname()
        2 -- dep matched a provide (e.g. a soname) via findbyprovides()
    or (None, None) when nothing matches. A lookup that raises is
    treated as "no match".
    """
    def _first(method_name):
        # Header for the first index reported by db.<method_name>(dep),
        # or None when the lookup fails or finds nothing.
        try:
            matches = getattr(db, method_name)(dep)
        except Exception:
            return None
        for index in matches or ():
            return db[index]
        return None

    if dep.startswith('/'):
        # Treat it as a file dependency.
        lookups = [('findbyfile', 0)]
    else:
        # Try it first as a package name, then as a provide (e.g. soname).
        lookups = [('findbyname', 1), ('findbyprovides', 2)]

    for method_name, dep_type in lookups:
        header = _first(method_name)
        if header is not None:
            return (header, dep_type)
    return (None, None)

def getBasePackages(stream=None):
    """Read "base" package names, one per line.

    stream -- file-like object to read from; defaults to sys.stdin
              (looked up at call time).

    Surrounding whitespace is stripped and blank lines are skipped. A
    trailing empty line therefore does not add an empty name, which would
    otherwise make the base list non-empty and filter out every dependency.
    """
    if stream is None:
        stream = sys.stdin
    names = []
    for line in stream.readlines():
        line = line.strip()
        if line:
            names.append(line)
    return names

def getPackages(stream=None):
    """Read the names of the packages to graph, one per line.

    stream -- file-like object to read from; defaults to sys.stdin
              (looked up at call time).

    Surrounding whitespace is stripped and blank lines are skipped, so a
    trailing newline does not produce an empty package name.
    """
    if stream is None:
        stream = sys.stdin
    names = []
    for line in stream.readlines():
        line = line.strip()
        if line:
            names.append(line)
    return names

def getDb(dbpath):
    """Open and return the rpm database found under *dbpath*.

    The "_dbpath" rpm macro is set to *dbpath* before rpm.opendb() is
    called (presumably so the open uses that location).
    """
    rpm.addMacro("_dbpath", dbpath)
    return rpm.opendb()

def getHeaders(db, packages=None):
    """Return the package headers in *db*, optionally limited to *packages*.

    Thin wrapper around listPackages(), which applies the name filter.
    """
    return listPackages(db, packages)

def getRequires(headerList, packages, base_packages):
    """Map each package name to the raw requirement names from its header.

    headerList    -- rpm headers (indexable by 'name' and 'requirename')
    packages      -- optional list of package names; when non-empty,
                     headers for other packages are ignored
    base_packages -- accepted for interface compatibility; unused here

    Returns {name: header['requirename']}. When several headers share a
    name, the last one wins.
    """
    requires = {}
    for header in headerList:
        name = header['name']
        if packages and name not in packages:
            continue
        requires[name] = header['requirename']
    return requires

def getPackageRequires(db, requires, packages, base_packages):
    """Resolve raw requirement names into the packages that satisfy them.

    db            -- rpm database handle passed to findDepLocal()
    requires      -- {package name: requirement names}, as built by
                     getRequires(); a false value means "no requires"
    packages      -- accepted for interface compatibility; unused here
    base_packages -- when non-empty, only dependencies satisfied by one
                     of these packages are kept

    Returns {package name: [(required package name, dep type), ...]},
    where dep type is the code returned by findDepLocal(). Each pair
    appears at most once per package, in first-seen order, so the output
    is deterministic.
    """
    # Many packages share the same requirement strings; resolve each one
    # against the database only once.
    resolved = {}
    package_requires = {}
    for package_name in requires.keys():
        found = []
        seen = {}
        for dep in requires[package_name] or []:
            if dep not in resolved:
                resolved[dep] = findDepLocal(db, dep)
            pkg, dep_type = resolved[dep]
            if not pkg:
                continue
            pkgname = pkg['name']
            # With a base package list, keep only dependencies on base packages.
            if base_packages and pkgname not in base_packages:
                continue
            entry = (pkgname, dep_type)
            if entry not in seen:
                seen[entry] = True
                found.append(entry)
        package_requires[package_name] = found
    return package_requires

def printHead():
    """Write the opening of the DOT digraph and its global attributes to stdout.

    The caller is responsible for printing the closing brace.
    """
    preamble = [
        "digraph redhat {",
        "\trankdir=LR;",
        "\tsize=\"8.5,11\"",
        "\tratio=fill",
        "\tnode [fontsize=96];",
    ]
    for text in preamble:
        print(text)
    # Disabled option, kept for reference: "\t concentrate=true"

def _parseOptions(argv):
    """Parse command line arguments into a settings dict.

    Prints help and exits for -h/--help; reports getopt errors on stderr
    and exits with status 1. -b/--basepackages and -p/--packages both read
    sys.stdin, so the second one processed will typically see no input.
    """
    try:
        optlist, _ = getopt.getopt(argv,
                                   'hd:bcps:',
                                   ['help',
                                    'dbpath=',
                                    'basepackages',
                                    'color',
                                    'packages',
                                    'skippackages=',
                                    'subgraphs='])
    except getopt.error:
        # stdout carries the .dot output, so report problems on stderr.
        sys.stderr.write("Error parsing command line arguments: %s\n"
                         % sys.exc_info()[1])
        sys.exit(1)

    settings = {
        'dbpath': "/var/lib/rpm",
        'base_packages': [],
        'color': False,
        # Packages left out of the graph, along with edges pointing at them.
        'skip_packages': ['libstdc++', 'compat-libstdc++', 'compat-glibc',
                          'glibc', 'bash'],
        # All packages that depend on one of these are clustered together.
        'subgraph_list': [],
        'packages': None,
    }

    for flag, value in optlist:
        if flag in ('-h', '--help'):
            showHelp()
            sys.exit()
        elif flag in ('-d', '--dbpath'):
            settings['dbpath'] = os.path.abspath(os.path.expanduser(value))
        elif flag in ('-b', '--basepackages'):
            settings['base_packages'] = getBasePackages()
        elif flag in ('-c', '--color'):
            settings['color'] = True
        elif flag in ('-s', '--skippackages'):
            settings['skip_packages'] = value.split(',')
        elif flag in ('-p', '--packages'):
            settings['packages'] = getPackages()
        elif flag == '--subgraphs':
            settings['subgraph_list'] = value.split(',')
    return settings


def _countLinks(package_requires, skip_packages):
    """Count, for each required package, how many packages require it."""
    link_pkgs = {}
    for package, reqs in package_requires.items():
        if package in skip_packages:
            continue
        for req in reqs:
            link_pkgs[req[0]] = link_pkgs.get(req[0], 0) + 1
    return link_pkgs


def _edgeAttributes(dep_type, color):
    """Return the DOT attributes (without brackets) for one dependency edge."""
    styles = {
        0: ('color=green,layer="files"', 'weight = 1'),     # file requirement
        1: ('color=red,layer="packages"', 'weight = 10'),   # package name
        2: ('color=blue,layer="libs"', 'weight = 25'),      # provide / library
    }
    # Unknown types shouldn't happen, so draw them light grey.
    edge_color, weight = styles.get(dep_type,
                                    ('color="0.3 0.3 0.3"', 'weight = 1'))
    attrs = []
    if color:
        attrs.append(edge_color)
    attrs.append(weight)
    # Join only present attributes: a leading ',' is invalid DOT syntax.
    return ",".join(attrs)


def _printEdges(package_requires, settings):
    """Print one DOT edge per dependency; return {subgraph: member packages}."""
    skip_packages = settings['skip_packages']
    concentrators = {}
    for sub in settings['subgraph_list']:
        concentrators[sub] = []

    for package, reqs in package_requires.items():
        if package in skip_packages:
            continue
        for req_pkg, dep_type in reqs:
            if req_pkg in skip_packages:
                continue
            # Cluster membership does not depend on --color.
            if req_pkg in concentrators:
                concentrators[req_pkg].append(package)
            print('\t"%s" -> "%s" [%s]; '
                  % (package, req_pkg, _edgeAttributes(dep_type, settings['color'])))
    return concentrators


def _printNodes(link_pkgs, settings):
    """Print one DOT node per required package, sized by how often it is required."""
    skip_packages = settings['skip_packages']
    counted = [pkg for pkg in link_pkgs.keys() if pkg not in skip_packages]

    # Every counted package has at least one link, so max_links >= 1 here.
    max_links = 0
    for pkg in counted:
        max_links = max(max_links, link_pkgs[pkg])

    node_color = ""  # placeholder for a per-node colour; currently always empty
    for pkg in counted:
        links = link_pkgs[pkg]
        if pkg in settings['base_packages']:
            style = "[shape=diamond,style=bold]"
        else:
            style = "[arrowhead=normal]"
        ratio = float(links) / float(max_links)
        fsize = 96 + (links * 10)
        fontsize = min(max(fsize, 96), 2048)
        print('\t"%s" [fontsize=%s] %s %s; // fsize: %s links_pkgs %s max_links: %s foobar: %s'
              % (pkg, fontsize, style, node_color, fsize, links, max_links, ratio))


def _printSubgraphs(concentrators):
    """Print one DOT cluster per --subgraphs entry containing its dependents."""
    for sub in concentrators.keys():
        opening = ('\tsubgraph "cluster_%s" { label="%s"; fontsize=512;style="setlinewidth(64)";'
                   % (sub, sub))
        members = ['"%s" ;' % name for name in concentrators[sub]]
        print(" ".join([opening] + members + ["}"]))


def main():
    """Print a graphviz .dot dependency graph of the rpm database to stdout.

    Each package gets an edge to every installed package that satisfies
    one of its requirements; edge weight (and colour, with --color)
    depends on whether the requirement was a file, a package name or a
    provide. Each required package gets a node whose font size grows with
    the number of packages requiring it.
    """
    settings = _parseOptions(sys.argv[1:])

    db = getDb(settings['dbpath'])
    headerList = getHeaders(db, settings['packages'])
    requires = getRequires(headerList, settings['packages'],
                           settings['base_packages'])

    # package name -> [(required package name, dep type), ...]
    package_requires = getPackageRequires(db, requires, settings['packages'],
                                          settings['base_packages'])

    printHead()
    link_pkgs = _countLinks(package_requires, settings['skip_packages'])
    concentrators = _printEdges(package_requires, settings)
    _printNodes(link_pkgs, settings)
    _printSubgraphs(concentrators)
    print("}")


if __name__ == "__main__":
    main()


