From 9090ab7a8f205a6fbb58178113e999ba65b89fa9 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Mon, 11 Mar 2024 17:12:58 +0100
Subject: [PATCH] [msh2vtu] allow multiple materialIDs per element

---
 ogstools/msh2vtu/__init__.py |  5 +--
 tests/test_msh2vtu.py        | 69 ++++++++++++++++++++++++++++++++++--
 2 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/ogstools/msh2vtu/__init__.py b/ogstools/msh2vtu/__init__.py
index b24c5ef82..db5df4011 100644
--- a/ogstools/msh2vtu/__init__.py
+++ b/ogstools/msh2vtu/__init__.py
@@ -595,7 +595,6 @@ def msh2vtu(
         return 0
 
     for name, data in field_data.items():
-        ph_id = data[ph_index]  # selection by physical id (user defined)
         subdomain_dim = data[geo_index]  # 0 or 1 or 2 or 3
         if dim0 <= subdomain_dim and subdomain_dim <= dim3:
             subdomain_cell_types = existing_cell_types.intersection(
@@ -640,9 +639,7 @@ def msh2vtu(
 
         for cell_type in subdomain_cell_types:
             # cells
-            selection_index = (
-                cell_data_dict[gmsh_physical_cell_data_key][cell_type] == ph_id
-            )
+            selection_index = mesh.cell_sets_dict[name][cell_type]
             selection_cells_values = cells_dict[cell_type][selection_index]
             if len(selection_cells_values):  # if there are some data
                 selection_cells_block = (cell_type, selection_cells_values)
diff --git a/tests/test_msh2vtu.py b/tests/test_msh2vtu.py
index b51f9060e..a69c0f6c6 100644
--- a/tests/test_msh2vtu.py
+++ b/tests/test_msh2vtu.py
@@ -9,7 +9,10 @@ from itertools import product
 from pathlib import Path
 from unittest.mock import patch
 
+import gmsh
 import meshio
+import numpy as np
+import pyvista as pv
 
 from ogstools.meshlib import gmsh_meshing
 from ogstools.msh2vtu import msh2vtu
@@ -20,8 +23,70 @@ def test_cli():
     subprocess.run(["msh2vtu", "--help"], check=True)
 
 
+def test_multiple_groups_per_element(tmp_path: Path):
+    """Test correct conversion, if element are assigned to multiple groups."""
+    gmsh.initialize()
+    # gmsh.option.setNumber("General.Terminal", 1)
+    gmsh.model.add("multiple_groups_per_element")
+
+    gmsh.model.geo.addPoint(-5, -5, 0, 1)
+    gmsh.model.geo.addPoint(-2, -5, 0, 1)
+    gmsh.model.geo.addPoint(2, -5, 0, 1)
+    gmsh.model.geo.addPoint(5, -5, 0, 1)
+    gmsh.model.geo.addPoint(5, 5, 0, 1)
+    gmsh.model.geo.addPoint(-5, 5, 0, 1)
+
+    for i in range(1, 7):
+        gmsh.model.geo.addLine(i, i % 6 + 1, i)
+
+    gmsh.model.geo.addCurveLoop([1, 2, 3, 4, 5, 6], 1)
+    gmsh.model.geo.addPlaneSurface([1], 1)
+    gmsh.model.geo.mesh.setTransfiniteCurve(2, 20)
+    gmsh.model.geo.addPhysicalGroup(dim=2, tags=[1], name="domain")
+    gmsh.model.geo.addPhysicalGroup(dim=1, tags=[2], name="bottom_center")
+    gmsh.model.geo.addPhysicalGroup(dim=1, tags=[1, 2, 3], name="bottom")
+    gmsh.model.geo.addPhysicalGroup(dim=1, tags=[4], name="right")
+    gmsh.model.geo.addPhysicalGroup(dim=1, tags=[5], name="top")
+    gmsh.model.geo.addPhysicalGroup(dim=1, tags=[6], name="left")
+    gmsh.model.geo.addPhysicalGroup(
+        dim=1, tags=[1, 2, 3, 4, 5, 6], name="boundaries"
+    )
+
+    gmsh.model.geo.synchronize()
+    gmsh.model.mesh.generate(2)
+    model_name = "multiple_groups_per_element"
+    msh_file = Path(tmp_path, model_name + ".msh")
+    gmsh.write(str(msh_file))
+    gmsh.finalize()
+
+    assert msh2vtu(msh_file, tmp_path, output_prefix=model_name) == 0
+
+    def number_of_elements(boundary_name: str) -> int:
+        file = f"{model_name}_physical_group_{boundary_name}.vtu"
+        return pv.read(str(Path(tmp_path, file))).number_of_cells
+
+    assert number_of_elements("boundaries") == sum(
+        [
+            number_of_elements(name)
+            for name in ["left", "right", "top", "bottom"]
+        ]
+    )
+    bottom = pv.read(
+        str(Path(tmp_path, f"{model_name}_physical_group_bottom.vtu"))
+    )
+    bottom_center = pv.read(
+        str(Path(tmp_path, f"{model_name}_physical_group_bottom_center.vtu"))
+    )
+    assert np.all(
+        np.in1d(bottom_center["bulk_node_ids"], bottom["bulk_node_ids"])
+    )
+    assert np.all(
+        np.in1d(bottom_center["bulk_elem_ids"], bottom["bulk_elem_ids"])
+    )
+
+
 def test_rect(tmp_path: Path):
-    """Create rectangular gmsh meshes andconvert with msh2vtu."""
+    """Create rectangular gmsh meshes and convert with msh2vtu."""
     msh_file = Path(tmp_path, "rect.msh")
     permutations = product([1.0, 2.0], [1, 2], [True, False], [1, 2])
     for edge_length, n_edge_cells, structured, order in permutations:
@@ -36,7 +101,7 @@ def test_rect(tmp_path: Path):
 
 
 def test_cuboid(tmp_path: Path):
-    """Create rectangular gmsh meshes andconvert with msh2vtu."""
+    """Create rectangular gmsh meshes and convert with msh2vtu."""
     msh_file = Path(tmp_path, "cuboid.msh")
     permutations = product([1.0, 2.0], [1, 2], [True, False], [1, 2])
     for edge_length, n_edge_cells, structured, order in permutations:
-- 
GitLab