diff --git a/pre/BCs/inhomogeneous-mass-flux/prepare-inhomogeneous-mass-flux-BC.py b/pre/BCs/inhomogeneous-mass-flux/prepare-inhomogeneous-mass-flux-BC.py
index 4d32204582e1024299ed0401638f185e1b088196..d23de9fc8cbe3155f008154106181007d95db6dd 100755
--- a/pre/BCs/inhomogeneous-mass-flux/prepare-inhomogeneous-mass-flux-BC.py
+++ b/pre/BCs/inhomogeneous-mass-flux/prepare-inhomogeneous-mass-flux-BC.py
@@ -1,12 +1,19 @@
 #!/usr/bin/env pvpython
 # -*- coding: utf-8 -*-
 
+from __future__ import print_function
+
 import inspect
 
 import argparse
 parser = argparse.ArgumentParser()
 parser.add_argument("--input", type=str, required=True)
 parser.add_argument("--output", type=str, required=True)
+parser.add_argument("--profile", type=str, required=True)
+parser.add_argument("--total-flux", type=float, required=True)
+parser.add_argument("--slice-normal", type=float, nargs=3, required=True)
+parser.add_argument("--slice-origin", type=float, nargs=3, required=True)
+
 args = parser.parse_args()
 
 #### import the simple module from the paraview
@@ -62,8 +69,8 @@ slice1.Triangulatetheslice = 0
 slice1.SliceOffsetValues = [0.0]
 
 # init the 'Plane' selected for 'SliceType'
-slice1.SliceType.Origin = [0.0044, 0.03, 0.0]
-slice1.SliceType.Normal = [0.0, 1.0, 0.0]
+slice1.SliceType.Origin = args.slice_origin
+slice1.SliceType.Normal = args.slice_normal
 
 
 # Trick from http://www.vtk.org/Wiki/VTK/Examples/Cxx/PolyData/PolyDataToUnstructuredGrid
@@ -74,6 +81,8 @@ appendDatasets1 = AppendDatasets(Input=slice1)
 def do_compute_mass_flux():
     import numpy as np
     from scipy.interpolate import interp1d
+    from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk
+
     data = self.GetInputDataObject(0, 0)
 
     ### compute predefined mass flux
@@ -81,21 +90,27 @@ def do_compute_mass_flux():
     mu  = 21.90e-6 # Pa s
     rho = 0.9333 # kg/m³
 
-    csv = "smooth_velocity_profile.csv"
+    csv = args.profile
     rs, velocities = np.loadtxt(csv, unpack=True, usecols=(0,1))
     velocity_fct = interp1d(rs, velocities)
 
-    points = data.GetPoints()
+    # compute total flux as "seen" by the discrete grid
+    rs_grid = vtk_to_numpy(data.GetPoints().GetData())[:,0]
+    vs_grid = velocity_fct(rs_grid)
 
-    mass_flux = vtk.vtkDoubleArray()
-    mass_flux.SetName("mass_flux")
-    mass_flux.SetNumberOfComponents(1)
-    N = points.GetNumberOfPoints()
-    mass_flux.SetNumberOfTuples(N)
+    mass_flux = vs_grid * rho
 
-    for i in range(N):
-        r, z, phi = points.GetPoint(i)
-        mass_flux.SetComponent(i, 0, velocity_fct(r) * rho)
+    total_flux = 2*np.pi * np.trapz(y=rs_grid * mass_flux, x=rs_grid)
+
+    scale_factor = args.total_flux / total_flux
+
+    print("total flux from given profile:", total_flux, "[e.g. kg/s]")
+    print("requested total flux:", args.total_flux, "kg/s")
+    print("scale factor:", scale_factor)
+
+    mass_flux *= scale_factor
+    mass_flux = numpy_to_vtk(mass_flux, 1)
+    mass_flux.SetName("mass_flux")
 
     self.GetOutputDataObject(0).GetPointData().AddArray(mass_flux)