#!/usr/bin/python

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
# See the COPYING file for license information.
#
# Copyright (c) 2011 Mellanox Technologies. All rights reserved

# Dump Mellanox devices multicast gid table using a Netlink interface.
# Both hardware and software tables are available.

# Author: Nir Muchtar <nirmu@mellanox.com>

import sys
import os
if os.path.exists('/usr/share/pyshared'):
    sys.path.append('/usr/share/pyshared')
if os.path.exists('/usr/lib/python2.7/site-packages'):
    sys.path.append('/usr/lib/python2.7/site-packages')
from netlink import Connection, NETLINK_GENERIC, U32Attr, NLM_F_DUMP
from genetlink import Controller, GeNlMessage

import struct
from socket import ntohl
from optparse import OptionParser
from collections import defaultdict

# Table indices below this value are treated by the dump code as hash-table
# (bucket) entries; indices at or above it are printed beneath a bucket.
HASH_TABLE_SIZE = 4096
# Command number placed in the generic-netlink request for the table dump.
MLX4_MCG_CMD_HW_MGHT_GET = 1

# Netlink attribute ids, numbered consecutively starting at 0.
# NOTE(review): presumably mirrors the attribute enum of the kernel side of
# the "mlx4_mcg" family, so the order must stay as is -- confirm.
(
    ATTR_DEV_UNSPEC,
    ATTR_DEV_IDX,
    ATTR_MCG_IDX,
    ATTR_DOMAIN,
    ATTR_PRIO,
    ATTR_QP_COUNT,
    ATTR_QPS,
    ATTR_QP_IS_PROMISC,
    ATTR_PROT,
    ATTR_MGID,
    ATTR_ETHERTYPE,
    ATTR_VEP,
    ATTR_PORT,
    ATTR_BASE_FLAGS,
    ATTR_VLAN_ID,
    ATTR_MAC,
    ATTR_EXT_FLAGS,
    ATTR_DST_QPN,
    ATTR_SRC_IP,
    ATTR_DST_IP,
    ATTR_L4_PROTOCOL,
    ATTR_L4_SRC_PORT,
    ATTR_L4_DST_PORT,
    ATTR_MCG_NX_IDX,
) = range(24)

# Display names for attributes that carry a "mapping" entry in attr_map; the
# first unpacked value of the attribute selects the name.
prot_map = ["IB_IPv6", "ETH", "IB_IPv4", "FCOE"]
l4_prot_map = {
    0: "None",
    3: "Other",
    5: "UDP",
    6: "TCP",
}
domain_map = ["Ethtool", "RDMA CM", "Uverbs", "RFS", "mlx4_en"]

# Keys of the per-attribute descriptor dicts stored in attr_map:
#   name    - column header (a descriptor without it adds no header text)
#   fmt     - struct format string used to unpack the attribute payload
#   pr      - %-format applied to the unpacked value(s)
#   length  - column width in characters
#   mapping - optional table translating the first unpacked value to text
name, fmt, pr, length, mapping = 1, 2, 3, 4, 5

# Attribute id -> descriptor (see the key constants above).
attr_map = {
    ATTR_DEV_IDX: {name: "ATTR_DEV_IDX", fmt: "I", pr: "%d", length: 6},
    ATTR_MCG_IDX: {name: "INDEX", fmt: "I", pr: "%d", length: 6},
    ATTR_MCG_NX_IDX: {name: "NXIDX", fmt: "I", pr: "%d", length: 9},
    ATTR_DOMAIN: {name: "DOMAIN", fmt: "I", pr: "%s", length: 8,
                  mapping: domain_map},
    ATTR_PRIO: {name: "PRIO", fmt: "I", pr: "%X", length: 5},
    ATTR_QP_COUNT: {name: "QPCNT", fmt: "I", pr: "%d", length: 6},
    ATTR_PROT: {name: "PROTO", fmt: "I", pr: "%s", length: 8,
                mapping: prot_map},
    ATTR_MGID: {fmt: ">HHHHHHHH",
                pr: "%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x",
                length: 50},
    ATTR_ETHERTYPE: {name: "ETYPE", fmt: ">H", pr: "%d", length: 6},
    ATTR_VEP: {name: "VEP", fmt: "B", pr: "%d", length: 6},
    ATTR_PORT: {name: "PORT", fmt: "B", pr: "%d", length: 6},
    ATTR_BASE_FLAGS: {name: "FLAGS", fmt: "B", pr: "%d", length: 6},
    ATTR_VLAN_ID: {name: "VLAN", fmt: ">H", pr: "%d", length: 6},
    ATTR_MAC: {name: "MAC", fmt: "BBBBBB",
               pr: "%02X:%02X:%02X:%02X:%02X:%02X", length: 20},
    ATTR_EXT_FLAGS: {name: "EFLAGS", fmt: "B", pr: "%d", length: 8},
    ATTR_DST_QPN: {name: "DQPN", fmt: ">I", pr: "0x%X", length: 10},
    ATTR_SRC_IP: {name: "SIP", fmt: "BBBB", pr: "%d.%d.%d.%d", length: 15},
    ATTR_DST_IP: {name: "DIP", fmt: "BBBB", pr: "%d.%d.%d.%d", length: 15},
    ATTR_L4_PROTOCOL: {name: "L4", fmt: "B", pr: "%s", length: 6,
                       mapping: l4_prot_map},
    ATTR_L4_SRC_PORT: {name: "SPORT", fmt: ">H", pr: "%d", length: 6},
    ATTR_L4_DST_PORT: {name: "DPORT", fmt: ">H", pr: "%d", length: 6},
}

def attr_get(entry, i):
    """Decode attribute *i* of a received table entry.

    entry -- mapping from attribute id to an object whose ``data`` member
             holds the raw attribute payload
    i     -- attribute id; must be a key of attr_map

    The payload is unpacked with the struct format registered for the
    attribute.  If the attribute has a display mapping, the first unpacked
    value is translated and the resulting string is returned; a value that
    the mapping does not know is returned as its decimal text instead of
    aborting the dump with IndexError/KeyError.  Attributes without a
    mapping return the whole unpacked tuple.

    Raises KeyError if *i* is missing from attr_map or from *entry*, and
    struct.error if the payload size does not match the format.
    """
    desc = attr_map[i]
    values = struct.unpack(desc[fmt], entry[i].data)
    table = desc.get(mapping)
    if not table:
        return values
    code = values[0]
    try:
        return table[code]
    except (IndexError, KeyError):
        # Value not covered by the display table: show the raw number so
        # the rest of the table can still be printed.
        return str(code)

# Header line of the dump; it starts with the index column and grows by one
# column header for every named attribute registered through add_field().
title = "%-9s " % "INDEX"


def add_field(field):
    """Register attribute id *field* as an output column and return it.

    When the attribute's descriptor in attr_map has a column name, that
    name, left-justified to the descriptor's column width, is appended to
    the module-level ``title``.  Attributes without a name leave ``title``
    unchanged.
    """
    global title
    desc = attr_map[field]
    header = desc.get(name)
    if header:
        title += header.ljust(desc[length])
    return field

def get_fields(entry, fields):
    """Format the given attributes of one table entry as fixed-width columns.

    entry  -- received entry; its ``attrs`` member is handed to attr_get()
    fields -- iterable of attribute ids, in output order

    Each value is rendered with the attribute's ``pr`` format and padded
    with spaces to the attribute's column width.  A value that is as wide
    as (or wider than) its column is still followed by one space, so that
    neighbouring columns never run together (e.g. a 15-character IPv4
    address in a 15-character column).

    Returns the concatenated columns; an empty string for no fields.
    """
    cells = []
    for f in fields:
        desc = attr_map[f]
        text = desc[pr] % attr_get(entry.attrs, f)
        # Pad to the column width, but never emit less than one separator.
        cells.append(text.ljust(max(desc[length], len(text) + 1)))
    return "".join(cells)

def get_qps(entry):
    """Return the "QPs: ..." column listing the QP numbers of an entry.

    entry -- mapping from attribute id to attribute object (``data`` member)

    The number of QPs comes from ATTR_QP_COUNT; the ATTR_QPS payload is
    unpacked as that many big-endian 32-bit unsigned integers, each printed
    in hexadecimal, left-justified in six characters after a "0x" prefix.
    """
    count = attr_get(entry, ATTR_QP_COUNT)[0]
    qpns = struct.unpack(">%dI" % count, entry[ATTR_QPS].data)
    return "QPs: " + "".join("0x%-6X " % qpn for qpn in qpns)

class MCGNetlink(object):
    """Generic-netlink client for the "mlx4_mcg" family."""

    def __init__(self):
        """Open a NETLINK_GENERIC connection and look up the family id."""
        conn = Connection(NETLINK_GENERIC)
        self.connection = conn
        self.family_id = Controller(conn).get_family_id("mlx4_mcg")

    def send_mght_cmd(self, cmd, dev_idx):
        """Send dump request *cmd* for the device with index *dev_idx*.

        The device index travels as a 32-bit attribute with id 1 (the value
        of ATTR_DEV_IDX) and the message is flagged NLM_F_DUMP.
        """
        dev_attr = U32Attr(1, dev_idx)
        message = GeNlMessage(self.family_id,
                              cmd=cmd,
                              attrs=[dev_attr],
                              flags=NLM_F_DUMP)
        message.send(self.connection)

    def recv_mght(self):
        """Receive the next reply; returns what GeNlMessage.recv() yields."""
        reply = GeNlMessage.recv(self.connection)
        return reply


# Command line: which device to query and whether to list every entry
# separately (verbose) or to collapse identical rules.
parser = OptionParser(
    usage="%prog -d <mlx_if_idx> [options]",
    version="%prog 1.0"
)
parser.add_option(
    "-d", "--dev",
    type="int", dest="dev_idx", default=0, metavar="IDX",
    help="The index of the Mellanox device for which to dump the table (d)"
)
parser.add_option(
    "-v", "--verbose",
    action="store_true", dest="verbose", default=False,
    help="if given, duplicates will not be gathered"
)

options, args = parser.parse_args()

# Ask for the table of the selected device; the replies are consumed by the
# receive loop further down.
cmd = MLX4_MCG_CMD_HW_MGHT_GET
mcg_nl = MCGNetlink()
mcg_nl.send_mght_cmd(cmd, options.dev_idx)

# Column groups.  add_field() appends each named column to the title as a
# side effect, so the registration order below is the header order:
# common columns, then the Ethernet columns, then the (unnamed) MGID column.
base_fields = [add_field(attr) for attr in (ATTR_PRIO, ATTR_PROT)]
eth_fields = [
    add_field(attr)
    for attr in (
        ATTR_VEP,
        ATTR_PORT,
        ATTR_BASE_FLAGS,
        ATTR_VLAN_ID,
        ATTR_ETHERTYPE,
        ATTR_MAC,
    )
]
ib_fields = [add_field(ATTR_MGID)]
# Filled in by the receive loop once the first entry shows whether the
# extended attributes are present.
extended_fields = []

# Receive-loop state: `done` ends the loop, `first` marks the entry that
# triggers printing of the title, `lines` keeps the distinct rule texts in
# arrival order.
done = False
first = True
lines = []
class Rule:
    """One printed table rule: where it was seen and how often.

    bucket -- index of the hash-table entry the rule belongs to
    index  -- table index of the rule itself
    count  -- number of identical rules seen (maintained by the caller)
    """

    def __init__(self, bucket, index):
        self.bucket = bucket
        self.index = index
        self.count = 0

    def dump(self, rule, last_bucket=None):
        """Print the formatted text *rule* prefixed by its table position.

        A rule whose index lies below HASH_TABLE_SIZE is printed as a bucket
        line (">" followed by the bucket number).  Any other rule is printed
        under its own index; if *last_bucket* is given and differs from this
        rule's bucket, a bare bucket line is printed first so the rule is
        shown beneath the bucket it belongs to.

        last_bucket -- bucket of the previously printed rule, or None to
                       suppress the bucket line.
        """
        if self.index < HASH_TABLE_SIZE:
            print(">%-7d  %s" % (self.bucket, rule))
            return

        # Compare against the caller-supplied bucket rather than a
        # module-level variable.
        if last_bucket is not None and self.bucket != last_bucket:
            print(">%-7d  " % self.bucket)

        print(" %-7d  %s" % (self.index, rule))

# Non-verbose mode: formatted rule text -> Rule holding the position of its
# first occurrence and the number of occurrences.  Membership is always
# tested before indexing, so the default factory is never invoked.
rules = defaultdict(Rule)
# Index of the most recent entry below HASH_TABLE_SIZE.  Initialised up
# front so that an entry at or above HASH_TABLE_SIZE arriving before any
# such entry does not hit an unbound name.
bucket = -1
while not done:
    data = mcg_nl.recv_mght()
    for entry in data:
        if first:
            # The first entry decides whether the extended columns are
            # shown; the title is printed once, after they are registered.
            if entry.attrs.get(ATTR_SRC_IP):
                for ext_attr in (ATTR_EXT_FLAGS, ATTR_SRC_IP, ATTR_DST_IP,
                                 ATTR_L4_PROTOCOL, ATTR_L4_SRC_PORT,
                                 ATTR_L4_DST_PORT):
                    extended_fields.append(add_field(ext_attr))
            print(title)
            first = False
        if len(entry.attrs) == 0:
            # An entry without attributes ends the dump.
            done = True
            break

        # Common columns, then either the MGID column or the Ethernet
        # columns, then the extended columns when the entry carries them.
        line = get_fields(entry, base_fields)
        if entry.attrs.get(ATTR_MGID):
            line += get_fields(entry, ib_fields)
        else:
            line += get_fields(entry, eth_fields)
        if entry.attrs.get(ATTR_SRC_IP):
            line += get_fields(entry, extended_fields)

        index = int(attr_get(entry.attrs, ATTR_MCG_IDX)[0])
        if index < HASH_TABLE_SIZE:
            bucket = index

        line += get_qps(entry.attrs)

        if options.verbose:
            # Verbose: print every entry as it arrives.
            Rule(bucket, index).dump(line)
        else:
            # Otherwise remember the first position of each distinct rule
            # text and count how often it occurs.
            if line not in rules:
                rules[line] = Rule(bucket, index)
                lines.append(line)

            rules[line].count += 1

# Non-verbose report: rules seen exactly once are printed in arrival order;
# rules seen several times are collected and listed afterwards with their
# occurrence count.
if not options.verbose:
    # `bucket` holds the bucket of the most recently printed rule.
    bucket = -1
    dup_rows = []
    for rule in lines:
        r = rules[rule]
        if r.count != 1:
            dup_rows.append(" %-7d* %s\n" % (r.count, rule))
            continue

        r.dump(rule, last_bucket=bucket)
        bucket = r.bucket

    if dup_rows:
        print("Duplicated MCGS:")
        print("".join(dup_rows))
