#!/usr/bin/env python2
# SPDX-License-Identifier: GPL-2.0
# X-SPDX-Copyright-Text: (c) Solarflare Communications Inc

# Parse iptables to create onload's filter rules

import subprocess
import sys
import re
import os
import glob
import copy
from optparse import OptionParser

# Module-wide settings.  Functions that rebind one of these declare it with a
# 'global' statement of their own, so nothing needs declaring at this level.

program_version = "0.1"

# Policy taken from the most recently parsed chain header (None until one
# has been recorded), and the policy to fall back on while it is None.
chain_policy = None
default_chain_policy = "ACCEPT"

# When True, the first rule that cannot be parsed stops the run.
unparsable_is_fatal = True

# When True, input is expected in 'iptables -L -n -v' layout.
use_verbose_iptables = True

# Rule keywords whose value arrives as the following, separate token.
skips = [ 'reject-with', ]

def StripComments(line):
	''' Separate a c-style comment from the rest of a line.

	Returns a tuple ( comment, stripped ).  'comment' is the text found
	between '/*' and '*/' (the last such pair when the line holds several),
	or None when the line has no comment.  'stripped' is the line with every
	'/* ... */' span removed.
	'''
	matches = re.match(r'.*/\*(.*)\*/.*', line)
	if matches:
		matches = str(matches.group(1))
	# The flag must go to re.compile(): the fourth positional argument of
	# re.sub() is 'count', so passing re.S there limited the number of
	# replacements to 16 and never switched DOTALL on.
	comment_span = re.compile(r'/\*.*?\*/', re.S)
	return ( matches, comment_span.sub('', line) )

def GenerateIptables():
	''' Run iptables and yield each line of its listing.

	The command is 'iptables -L -n', with '-v' appended when
	use_verbose_iptables is set.  Lines are yielded as read, including their
	trailing newline.  Whatever subprocess.Popen raises (for example when
	the program cannot be started) propagates on the first iteration.
	'''
	global use_verbose_iptables
	cmdline = ['iptables','-L','-n']
	if ( use_verbose_iptables ):
		cmdline.append( '-v' )

	# universal_newlines makes the pipe a text stream, so the string tests
	# applied to these lines further down also work where pipes are bytes.
	p = subprocess.Popen( cmdline, stdout=subprocess.PIPE, universal_newlines=True )
	try:
		# Iterating the pipe blocks until the child closes it, so there is
		# no need to poll the return code while reading.
		for line in p.stdout:
			yield line
	finally:
		# Runs on exhaustion and when the generator is closed early: release
		# the pipe and reap the child so that it does not linger as a zombie.
		p.stdout.close()
		p.wait()

def StripEntry(x):
	''' Yield every entry of x with surrounding whitespace removed, dropping entries that end up empty '''
	for raw in x:
		trimmed = raw.strip()
		if not trimmed:
			# Nothing left once the whitespace is gone
			continue
		yield trimmed

def SetChainPolicy(line):
	''' Record the policy named in a chain header line.

	Turns "Chain INPUT (policy ACCEPT)" into "ACCEPT" and stores it in the
	module-level chain_policy.  When use_verbose_iptables is set only the
	first word after "(policy " is kept, which drops the packet and byte
	counters that follow it in that layout.

	A header without a "(policy " part resets chain_policy to None, so that
	ChainPolicy() falls back to the default instead of reporting the policy
	left behind by an earlier chain.
	'''
	global chain_policy
	global use_verbose_iptables

	policy_parts = line.split( "(policy " )
	if len(policy_parts) != 2:
		# No policy on this header: forget the previous chain's policy.
		chain_policy = None
		return

	chain_policy = policy_parts[1].strip(")")
	if use_verbose_iptables:
		chain_policy = chain_policy.split()[0]

def ChainPolicy():
	''' Return the recorded policy of the current chain, or the default policy when none has been recorded '''
	global chain_policy
	global default_chain_policy
	if chain_policy is None:
		return default_chain_policy
	return chain_policy

def GenerateDefaultPolicies(announce=False):
	''' Generate iptables-style rule lines that enforce the chain policy.

	Nothing is generated when ChainPolicy() is "ACCEPT", because accepting
	is the internal default.  Otherwise one catch-all line is yielded for udp
	and one for tcp, carrying the policy as their action.  When 'announce'
	is true a "# Default policy." marker is printed before the first line.

	The lines are laid out to match the input format in use: with
	use_verbose_iptables set they get the two leading counter columns and
	the two interface columns that GenerateBits() strips in that mode.
	Without them GenerateBits() cannot parse the lines in verbose mode.
	'''
	global use_verbose_iptables
	action = ChainPolicy()
	if action == "ACCEPT":
		# No need to write extra accept rules, accept is the internal default
		return

	if announce:
		print ( "# Default policy." )

	if use_verbose_iptables:
		template = "0 0 %s %s  --  * * 0.0.0.0/0.0.0.0 0.0.0.0/0.0.0.0 %s spts:0:65535 dpts:0:65535"
	else:
		template = "%s %s  --  0.0.0.0/0.0.0.0 0.0.0.0/0.0.0.0 %s spts:0:65535 dpts:0:65535"

	# One rule per protocol; the udp rule used to be emitted twice.
	for protocol in ( "udp", "tcp" ):
		yield template % ( action, protocol, protocol )
def FindChain(x, match="INPUT", announceDefaultPolicy=False):
	''' Yield all the lines for a chain, and handle the default chain policy.

	x is an iterable of lines.  A line containing "Chain " is treated as a
	chain header; the chain is selected when 'match' occurs in that header.
	Lines following a selected header are yielded until the next header.
	When a selected chain ends (at the next header or at the end of input)
	the lines from GenerateDefaultPolicies() are yielded for it.

	announceDefaultPolicy is passed through to GenerateDefaultPolicies().
	'''
	matching = False
	for line in x:
		if "Chain " in line:
			if matching:
				# A new chain starts: emit the policy rules of the chain
				# that just ended, while its policy is still the recorded one.
				for policy_rule in GenerateDefaultPolicies(announceDefaultPolicy):
					yield policy_rule
			matching = match in line
			SetChainPolicy( line )
		elif matching:
			yield line
	# Only a selected chain that ran up to the end of the input still needs
	# its policy rules.  Emitting them unconditionally here would apply the
	# policy of whichever chain happened to be listed last.
	if matching:
		for policy_rule in GenerateDefaultPolicies(announceDefaultPolicy):
			yield policy_rule

def SkipHeaders( x ):
	''' Yield the lines of x, leaving out column-header lines such as "target     prot opt source               destination". '''
	header_words = ( "target", "source", "destination" )
	for line in x:
		# A line counts as a header only when it contains all three words
		if not all( word in line for word in header_words ):
			yield line

def SkipFirst( x, skip=1 ):
	''' Yield the entries of 'x' that come after its first 'skip' entries. '''
	for position, line in enumerate(x):
		if position >= skip:
			yield line

def GenerateBits( line ):
	''' Split a rule line into a map of its columns.

	Takes a line like
	"DROP       tcp  --  anywhere             anywhere            tcp dpt:5205 flags:FIN,RST,URG/URG"
	and returns a map with the keys 'action', 'protocol', 'option',
	'source', 'dest', 'module', 'rule' (a list of the remaining fragments)
	and 'interface' (None unless a verbose line names one).

	When use_verbose_iptables is set the line must carry two extra leading
	columns, which are discarded, and two interface columns after the
	option column; the first of those that is neither '*' nor 'any' is
	stored as 'interface'.

	Raises ValueError when the line has too few columns.
	'''
	global use_verbose_iptables
	global skips

	bits = line.split()
	rval = { 'interface' : None }

	if use_verbose_iptables:
		if len(bits) < 7:
			raise ValueError( "Insuffcient arguments in rule." )
		# Ignore the packet and byte counts
		del bits[:2]
		if bits[3] != '*' and bits[3] != 'any':
			rval['interface'] = bits[3]
		elif bits[4] != '*' and bits[4] != 'any':
			rval['interface'] = bits[4]
		# Drop both interface columns so the layout matches the short format
		del bits[3:5]

	if len(bits) < 5:
		# Note: iptables rules that do not have a target are valid, but don't
		# affect the packet, so we can safely ignore them this way too.
		raise ValueError( "Insuffcient arguments in rule." )
	elif len(bits) < 7 or bits[5] in skips:
		# If there are few elements, or an ignored element, then we must not
		# have a source rule provided.  So assume 'all source ports' as a
		# default rule.
		rval['module'] = bits[1]
		rval['rule'] = [ "spts:0:65535" ]
	else:
		rval['module'] = bits[5]
		rval['rule'] = [ bits[6] ]

	# Everything past the seventh column is a further rule fragment.  The
	# slice is empty for shorter lines, and avoids xrange, which only exists
	# on Python 2.
	rval['rule'].extend( bits[7:] )

	rval['action'] 		= bits[0]
	rval['protocol']	= bits[1]
	rval['option']		= bits[2]
	rval['source']		= bits[3]
	rval['dest']		= bits[4]

	return rval

def RuleParts( d ):
	''' Split a rule into its ':'-separated bits, always returning at least two bits.

	An empty (or None) rule gives ['', ''].  A rule without any ':' is padded
	with an empty second bit, so callers can always index bits 0 and 1.
	'''
	if not d:
		return ['','']
	parts = d.split(':')
	if len(parts) < 2:
		# No separator present: pad to keep the two-bit promise
		parts.append('')
	return parts

def GenerateRulesPieces(bits):
	''' Yield, one at a time, the fragments stored under the 'rule' key of bits '''
	fragments = bits['rule']
	for fragment in fragments:
		yield fragment

def FilterValidPieces(g):
	''' Turn rule fragments into maps with 'prefix', 'start' and 'end' keys.

	A 'spt:'/'dpt:' fragment gives a map whose start and end are the same
	port; a 'spts:'/'dpts:' fragment gives the two ends of its range.  A
	fragment starting with one of the module-level 'skips' keywords yields
	nothing itself: the fragment that follows it is yielded as its value.
	Raises ValueError for any fragment that is not accounted for exactly once.
	'''
	global skips
	# Each entry: ( accepted prefixes, number of ':'-separated fields required )
	port_forms = (
		( [ 'spt:', 'dpt:', ], 2 ),
		( [ 'spts:', 'dpts:', ], 3 ),
	)
	pending_keyword = None
	for piece in g:
		fields = piece.split(':')
		accounted = 0
		if pending_keyword:
			# This fragment is the value of the keyword seen just before it
			yield { 'prefix':pending_keyword, 'start':piece, 'end':piece }
			accounted += 1
			pending_keyword = None
		for prefixes, wanted in port_forms:
			for prefix in prefixes:
				if len(fields) == wanted and piece.startswith(prefix):
					# fields[-1] is the port itself for a single port, and
					# the upper bound for a range
					yield { 'prefix':prefix, 'start':fields[1], 'end':fields[-1] }
					accounted += 1
		for prefix in skips:
			if len(fields) == 1 and piece.startswith(prefix):
				pending_keyword = prefix
				accounted += 1

		if accounted != 1:
			raise ValueError( "Could not parse rule fragment '%s' into a filter rule."%piece )

def MergePieces ( g ):
	''' Fold the fragment maps in g into one rule map.

	The result starts as local/remote minimum 0 and maximum None.  A fragment
	whose prefix contains 'spt' sets the remote range, one containing 'dpt'
	sets the local range, and one containing 'reject-with' adds a
	'reject-with' entry.  Later fragments overwrite earlier ones.
	'''
	merged = { 'local_min':0, 'local_max':None, 'remote_min':0, 'remote_max':None, }
	for fragment in g:
		kind = fragment['prefix']
		if 'spt' in kind:
			merged['remote_min'], merged['remote_max'] = fragment['start'], fragment['end']
		if 'dpt' in kind:
			merged['local_min'], merged['local_max'] = fragment['start'], fragment['end']
		if 'reject-with' in kind:
			merged['reject-with'] = fragment['start']
	return merged

def GetBits( bits ):
	''' Interpret the 'rule' fragments of a rule map and return a map giving its source and destination port ranges '''
	fragments = GenerateRulesPieces( bits )
	understood = FilterValidPieces( fragments )
	return MergePieces( understood )

def ComprehensibleIpRule( d, strict, swap ):
	''' Turn a map of iptables rule pieces into a map holding every field needed for enforcement.

	d is a map as produced by GenerateBits().  With 'strict' set, rules that
	match on the destination address or a remote port without a local port,
	and rules carrying 'reject-with', are refused.  With 'swap' set the
	local and remote addresses and port ranges are exchanged.

	Raises ValueError when the action or module is unsupported, when a rule
	fragment cannot be parsed, or when a strict check fails.
	'''
	global use_verbose_iptables
	understood = [ 'tcp', 'udp', '/*' ]
	# A missing action is the common mistake, and SysFsAction reports it more
	# clearly than the module check below would, so validate the action first.
	SysFsAction( d['action'] )

	module = d['module']
	if module not in understood:
		raise ValueError( "Unknown/Unsupported module '%s', expected one of %s"%(module,str(understood)) )

	rval = GetBits( d )
	if strict:
		remote_given = d['dest'] != "0.0.0.0/0.0.0.0" or rval['remote_max'] is not None
		if remote_given and rval['local_max'] is None:
			raise ValueError( "Strict Warning: Remote match specified, without local." )
		if rval.get('reject-with') is not None:
			raise ValueError( "Strict Warning: reject-with specified, cannot support." )

	rval['local_ip'] = d['dest']
	rval['remote_ip'] = d['source']
	rval['action'] = d['action']
	rval['protocol'] = d['protocol']

	if use_verbose_iptables:
		rval['interface'] = d['interface']

	# Ranges the module left open extend to the highest port
	for upper_bound in ( 'local_max', 'remote_max' ):
		if rval[upper_bound] is None:
			rval[upper_bound] = 65535

	# An outgoing rule is being turned into an incoming one: exchange the
	# local and remote halves.
	if swap:
		for local_key, remote_key in ( ('local_ip', 'remote_ip'), ('local_min', 'remote_min'), ('local_max', 'remote_max') ):
			rval[local_key], rval[remote_key] = rval[remote_key], rval[local_key]

	return rval

def SysFsAction( action ):
	''' Map an iptables action, case-insensitively, onto a filter action.

	'DROP' and 'REJECT' become 'DECELERATE', 'ACCEPT' stays 'ACCEPT', and
	anything else raises ValueError.
	'''
	wanted = action.upper()
	if wanted in ( "DROP", "REJECT" ):
		return "DECELERATE"
	if wanted != "ACCEPT":
		raise ValueError( "Unsupported action '%s'"%wanted )
	return "ACCEPT"

def GenerateInterfaces():
	''' Run 'ls /proc/sys/net/ipv4/conf/' and yield each line of its output.

	Lines are yielded as read, including their trailing newline.  Whatever
	subprocess.Popen raises propagates on the first iteration.
	'''
	# universal_newlines makes the pipe a text stream, so the lines can be
	# compared with the str names used by SplitInterfaces().
	p = subprocess.Popen( ['ls', '/proc/sys/net/ipv4/conf/'], stdout=subprocess.PIPE, universal_newlines=True )
	try:
		# Iterating the pipe blocks until the child closes it, so there is
		# no need to poll the return code while reading.
		for line in p.stdout:
			yield line
	finally:
		# Runs on exhaustion and when the generator is closed early: release
		# the pipe and reap the child so that it does not linger as a zombie.
		p.stdout.close()
		p.wait()
def SplitInterfaces(g):
	''' Yield every whitespace-separated name found in the lines of g, except 'lo', 'all' and 'default' '''
	ignored = ['lo', 'all', 'default']
	for line in g:
		for name in line.split():
			if name in ignored:
				continue
			yield name

def RuleFromInterface( g, rulemap ):
	''' Yield one filter rule string per interface in g that the rule applies to.

	With use_verbose_iptables set, an interface is left out when both it and
	the rule's own 'interface' entry are set and they differ.  In every other
	case a rule is produced for the interface.
	'''
	global use_verbose_iptables

	template = "if=%s protocol=%s local_ip=%s local_port=%s-%s remote_ip=%s remote_port=%s-%s action=%s"
	for interface in g:
		if use_verbose_iptables and interface and rulemap['interface'] and rulemap['interface'] != interface:
			# The rule is bound to a different interface
			continue
		fields = (
			interface,
			rulemap['protocol'],
			rulemap['local_ip'],
			rulemap['local_min'],
			rulemap['local_max'],
			rulemap['remote_ip'],
			rulemap['remote_min'],
			rulemap['remote_max'],
			SysFsAction( rulemap['action'] ),
		)
		yield template % fields

def GetInterfaces( options ):
	''' Return the interfaces to write rules for: every name discovered through GenerateInterfaces() when options.all is set, otherwise just options.interface '''
	if not options.all:
		return [ options.interface ]
	return SplitInterfaces( GenerateInterfaces() )

def SwapSrcDest( bits ):
	''' Exchange the 'source' and 'dest' entries of bits, in place, and return bits.

	The exchange is done in a single assignment: writing the two keys one
	after the other through an alias of the same map leaves both of them
	holding the destination.
	'''
	bits['source'], bits['dest'] = bits['dest'], bits['source']
	return bits

def MakeSysFsRulesFromLine( line, options ):
	''' Turn a line (from iptables -L) into SysFs format rules for each interface.

	Returns a tuple ( comment, rules ).  'comment' is the comment text found
	on the line, or None.  'rules' is an iterable of rule strings; it is the
	empty string when the line carries a comment and cannot be split into
	rule columns, since such a line may be nothing but a comment.

	Errors from parsing a line without a comment, and from interpreting any
	rule, propagate to the caller.
	'''
	comment,good = StripComments(line)
	try:
		bits = GenerateBits(good)
	except Exception:
		# Catching Exception rather than everything lets KeyboardInterrupt
		# and SystemExit through.
		if comment:
			# It's ok to not have a valid rule if it's potentially a comment rule.
			return( comment, "" )
		# Otherwise, we do need to propagate exceptions
		raise
	rulemap = ComprehensibleIpRule( bits, options.strict, options.treat_as_outgoing )
	interfaces = GetInterfaces( options )
	rules = RuleFromInterface( interfaces, rulemap )
	return ( comment, rules )

def PrintFrom(x, options):
	''' Print each input line followed by the rules generated from it.

	Parse errors are reported on stderr.  Returns 1 as soon as a line fails
	while unparsable_is_fatal is set; otherwise returns the number of lines
	that failed.
	'''
	global unparsable_is_fatal
	failures = 0
	for line in x:
		print ( line )
		try:
			comment, rules = MakeSysFsRulesFromLine(line, options)
			print ( "=> %s" % "\n".join( rules ) )
			if comment:
				print ( "=> /* %s */" %comment )
		except Exception:
			error = sys.exc_info()[1]
			sys.stderr.write( "=> Error parsing: %s\n"%str(error) )
			if unparsable_is_fatal:
				return 1
			# We are permitted to swallow errors; just keep count of them.
			failures += 1
	return failures


def PrintInstalledRules(filename=None):
	''' Print out the current rules.

	filename selects the entry under /proc/driver/sfc_resource/ whose
	'firewall_rules' file is printed; it is used as a glob pattern, and None
	means '*' (every entry).  Nothing is printed when no file matches.  If a
	matching file cannot be read, a message is printed and the program exits
	with status 1.
	'''
	if filename is None:
		filename = '*'
	pattern = os.path.join('/proc/driver/sfc_resource/', filename, 'firewall_rules')
	for path in glob.glob(pattern):
		try:
			fp = open(path, 'r')
			try:
				rules = fp.read()
			finally:
				# Close the handle even when the read itself fails
				fp.close()
		except IOError:
			# A single pre-formatted argument prints identically whether
			# print is a statement or a function.
			print ( "Cannot find the file %s" % path )
			sys.exit(1)
		print(rules)


def OutputFilterRulesFrom(x, options):
	''' Collect the rules for every line of x, so that the whole set can be failed if an error occurs.

	Returns ( rules, failures ): the list of generated rule strings and the
	number of lines that could not be parsed.  When unparsable_is_fatal is
	set the first failure returns ( "", -1 ) instead.  Comments and parse
	errors are reported on stderr.
	'''
	global unparsable_is_fatal
	collected = []
	failures = 0
	for line in x:
		try:
			comment, rules = MakeSysFsRulesFromLine(line, options)
			collected.extend( rules )
			if comment:
				# A comment is ALWAYS something we can skip
				sys.stderr.write( "Comment skipped: %s\n"%comment )
		except Exception:
			error = sys.exc_info()[1]
			sys.stderr.write( "Error parsing '%s': %s\n"%(line,str(error)) )
			if unparsable_is_fatal:
				return ( "", -1 )
			# The rule is dropped; only the count remembers it
			failures += 1
	return ( collected, failures )

def ParseCommandLineOptions():
	''' Parse the command line options.

	Returns ( options, args ) as produced by OptionParser.parse_args().  As
	a side effect the module-level unparsable_is_fatal and
	use_verbose_iptables are set from the parsed options; --verbose turns
	unparsable_is_fatal off.  Mutually exclusive option pairs are rejected
	through parser.error(), which exits the program.
	'''
	global unparsable_is_fatal
	global use_verbose_iptables
	parser = OptionParser( version=" ".join( ("%prog", "version", program_version) ),
		description="Turn iptables rules into Onload filter rules.\n"
		"Return value will be the number of lines that could not be parsed before execution stopped.\n" )
	parser.add_option("-v", "--verbose",
		action="store_true", dest="verbose", default=False,
		help="Output longer list of which iptables rules were parsed, and which were ignored.")
	parser.add_option("-c", "--continue",
		action="store_false", dest="unparsable_is_fatal", default=True,
		help="Continue even if rules are found which we cannot enforce (verbose sets this automatically).\nDefault behaviour is to output no rules, rather than a partial rule set.")
	parser.add_option("-f", "--file",
		action="store", type="string", dest="filename",
		help="Read iptables from the given file rather than by invoking iptables.\nInput is in the same format as that returned by iptables -L -n (-v), including chain names.",
		metavar="FILE" )
	parser.add_option("-s", "--strict",
		action="store_true", dest="strict", default=False,
		help="Do not insert rules which cannot be enforced exactly.")
	parser.add_option("--interactive",
		action="store_true", dest="interactive", default=False,
		help="Read iptables from stdin rather than by invoking iptables.\nInput is in the same format as that returned by iptables -L -n (-v), including chain names.\nCtrl-D twice to exit." )
	parser.add_option("-i", "--interface", action="store", dest="interface", help="Write out directly to /proc/driver/sfc_resource/ for this interface", metavar="[ethX]" )
	parser.add_option("-a", "--all", action="store_true", dest="all", default=False, help="Write out to /proc/driver/sfc_resource/ for all interfaces.  Note that this can include non-solarflare interfaces; but the rules here will not take effect (as onload only supports solarflare cards) until a solarflare card is given that name." )
	parser.add_option("-o", "--output", action="store", dest="output", help="Output to the given file.  Note: If you wish to later send this file to /proc/driver/sfc_resource/firewall_add then you will need to ensure that you do not exceed the /proc filesystem block limit (typically 1024 bytes); or that you flush the writes between lines, rather than risking a rule getting broken due to an unfortunate split.", metavar="FILE" )
	parser.add_option("--report", action="store_true", dest="report", help="Report the currently installed rules." )
	parser.add_option("--chain", action="store", dest="chain", default="INPUT",
		help="Which iptables chain to parse the rules from.  Default is INPUT." )
	parser.add_option("--use-extended", action="store_true", dest="use_verbose_iptables", default=False,
		help="Use extended IPtables (-L -n -v, rather than just -L -n) input, which includes which interface each rule is applied to.  Disabled by default, as the format may vary." )
	# The line break before "Note:" has to be spelled "\n".  A backslash
	# followed by a capital N starts a named-character escape, which is a
	# syntax error on Python 3 and shows a stray backslash on Python 2.
	parser.add_option("--treat-as-outgoing", action="store_true", dest="treat_as_outgoing", default=False,
		help="Treat the chain as an OUTPUT chain, and so treat its dest rules as remote hosts.\nNote: Onload can ONLY block incoming packets; but this is often suffcieint to prevent the connection." )
	(options, args) = parser.parse_args()
	unparsable_is_fatal = options.unparsable_is_fatal and not options.verbose
	use_verbose_iptables = options.use_verbose_iptables

	if options.output and options.verbose:
		# Verbose uses a different print method.  Standard unix redirection will work fine here anyway.
		parser.error("options --output and --verbose are mutually exclusive")
	if options.filename and options.interactive:
		parser.error("options --file and --interactive are mutually exclusive")
	if options.interface and options.all:
		parser.error("options --interface and --all are mutually exclusive")
	return options, args

def SelectSource(options):
	''' Pick the input according to the command line: the named file when one was given, stdin in interactive mode, and otherwise the output of iptables. '''
	if options.filename:
		return open( options.filename, "r" )
	if options.interactive:
		return sys.stdin
	return GenerateIptables()

def del_existing_rules( options ):
	''' Ask the driver to delete all installed rules for each selected interface.

	Writes "<interface> all" to /proc/driver/sfc_resource/firewall_del once
	per interface returned by GetInterfaces().  A failure for one interface
	is reported on stderr and does not stop the remaining interfaces.
	'''
	for interface in GetInterfaces(options):
		try:
			handle = open( "/proc/driver/sfc_resource/firewall_del", "w" )
			try:
				handle.write( "%s all\n"%interface )
				handle.flush()
			finally:
				# Release the handle even when the write is refused
				handle.close()
		except Exception:
			e = sys.exc_info()[1]
			# Newline-terminated so that several failures do not run together
			sys.stderr.write( "Error deleting firewall rules for %s: %s\n" % (str(interface),str(e)) )

def main():
	''' Main entry point.

	Returns the value to exit with: 0 after --report, -2 when the input
	cannot be opened, -3 when the output cannot be opened, the negative value
	from OutputFilterRulesFrom() when parsing was fatal, and otherwise the
	number of lines that could not be parsed.
	'''
	options, args = ParseCommandLineOptions()

	if options.report:
		PrintInstalledRules(options.interface)
		return 0

	try:
		source = SelectSource(options)
	except Exception:
		e = sys.exc_info()[1]
		sys.stderr.write( "Error: %s\n"%str(e) )
		return -2

	generator = SkipHeaders(FindChain(StripEntry(source), options.chain, options.verbose))

	if options.verbose:
		return PrintFrom( generator, options )

	# Build the complete rule set first, so that nothing is written when
	# parsing fails fatally.
	output, rval = OutputFilterRulesFrom( generator, options )
	if ( rval < 0 ):
		return rval

	out = sys.stdout
	out_needs_closing = False
	try:
		if options.output:
			out = open( options.output, "w" )
			out_needs_closing = True
		elif options.interface or options.all:
			# Writing straight to the driver: clear what is installed first
			del_existing_rules( options )
			out = open( "/proc/driver/sfc_resource/firewall_add", "w" )
			out_needs_closing = True
	except Exception:
		e = sys.exc_info()[1]
		sys.stderr.write( "Error: %s\n"%str(e) )
		return -3

	# There's a somewhat nasty limit of 1024 characters for a /proc file
	# To prevent us hitting that limit, make sure we flush after each line.
	for rule in output:
		try:
			out.write( "%s\n"%rule )
			out.flush()
		except Exception:
			e = sys.exc_info()[1]
			# Newline-terminated so that consecutive errors stay on separate lines
			sys.stderr.write( "Error %s when writing rule %s\n" % (str(e),str(rule)) )

	if out_needs_closing:
		try:
			out.flush()
			out.close()
		except Exception:
			e = sys.exc_info()[1]
			sys.stderr.write( "Error %s when closing file\n" % str(e) )
	return rval

if __name__ == "__main__":
	# Exit with main()'s result, limited to the range -127..128
	exit_code = main()
	if exit_code > 128:
		exit_code = 128
	elif exit_code < -127:
		exit_code = -127
	sys.exit( exit_code )

