|
4 | 4 | import unittest |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | | -from comp_sys import CompSys, IsNoPBC |
| 7 | +from comp_sys import CompLabeledSys, CompSys, IsNoPBC |
8 | 8 | from context import dpdata |
9 | 9 |
|
| 10 | +try: |
| 11 | + from ase.io import read, write |
| 12 | +except ModuleNotFoundError: |
| 13 | + skip_ase = True |
| 14 | +else: |
| 15 | + skip_ase = False |
| 16 | + |
10 | 17 |
|
11 | 18 | class TestToXYZ(unittest.TestCase): |
12 | 19 | def test_to_xyz(self): |
@@ -44,3 +51,138 @@ def setUp(self): |
44 | 51 | with tempfile.NamedTemporaryFile("r") as f_xyz: |
45 | 52 | self.system_1.to("xyz", f_xyz.name) |
46 | 53 | self.system_2 = dpdata.System(f_xyz.name, fmt="xyz") |
| 54 | + |
| 55 | + |
| 56 | +@unittest.skipIf(skip_ase, "skip ASE related test. install ASE to fix") |
| 57 | +class TestExtXYZASECrossCompatibility(unittest.TestCase): |
| 58 | + """Test cross-compatibility between dpdata extxyz and ASE extxyz.""" |
| 59 | + |
| 60 | + def test_extxyz_format_compatibility_with_ase_read(self): |
| 61 | + """Test that dpdata's extxyz format can be read by ASE.""" |
| 62 | + # Use existing test data that's known to work with dpdata extxyz parser |
| 63 | + test_file = "xyz/xyz_unittest.xyz" |
| 64 | + |
| 65 | + # First verify dpdata can read it |
| 66 | + multi_systems = dpdata.MultiSystems.from_file(test_file, fmt="extxyz") |
| 67 | + self.assertIsInstance(multi_systems, dpdata.MultiSystems) |
| 68 | + self.assertTrue(len(multi_systems.systems) > 0) |
| 69 | + |
| 70 | + # Test that ASE can also read the same file |
| 71 | + atoms_list = read(test_file, index=":", format="extxyz") |
| 72 | + self.assertIsInstance(atoms_list, list) |
| 73 | + self.assertTrue(len(atoms_list) > 0) |
| 74 | + |
| 75 | + # Check basic structure of first frame |
| 76 | + atoms = atoms_list[0] |
| 77 | + self.assertTrue(len(atoms) > 0) |
| 78 | + self.assertTrue(hasattr(atoms, "get_chemical_symbols")) |
| 79 | + |
| 80 | + def test_manual_extxyz_ase_to_dpdata(self): |
| 81 | + """Test cross-compatibility with a manually created compatible extxyz.""" |
| 82 | + # Create a manually written extxyz content that should work with both |
| 83 | + extxyz_content = """2 |
| 84 | +energy=-10.5 Lattice="5.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 5.0" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 |
| 85 | +C 0.0 0.0 0.0 6 0.1 0.1 0.1 |
| 86 | +O 1.0 1.0 1.0 8 -0.1 -0.1 -0.1 |
| 87 | +""" |
| 88 | + |
| 89 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: |
| 90 | + f.write(extxyz_content) |
| 91 | + f.flush() |
| 92 | + |
| 93 | + # Test with dpdata |
| 94 | + multi_systems = dpdata.MultiSystems.from_file(f.name, fmt="extxyz") |
| 95 | + self.assertIsInstance(multi_systems, dpdata.MultiSystems) |
| 96 | + self.assertTrue(len(multi_systems.systems) > 0) |
| 97 | + |
| 98 | + system_key = list(multi_systems.systems.keys())[0] |
| 99 | + system = multi_systems.systems[system_key] |
| 100 | + self.assertEqual(system.get_nframes(), 1) |
| 101 | + |
| 102 | + # Test with ASE (basic read) |
| 103 | + atoms = read(f.name, format="extxyz") |
| 104 | + self.assertEqual(len(atoms), 2) |
| 105 | + self.assertEqual(atoms.get_chemical_symbols(), ["C", "O"]) |
| 106 | + |
| 107 | + def test_dpdata_xyz_to_ase_basic(self): |
| 108 | + """Test basic xyz reading between dpdata and ASE (simple compatibility check).""" |
| 109 | + # Create a simple xyz file using dpdata's basic xyz format |
| 110 | + simple_system = dpdata.System( |
| 111 | + data={ |
| 112 | + "atom_names": ["C", "O"], |
| 113 | + "atom_numbs": [1, 1], |
| 114 | + "atom_types": np.array([0, 1]), |
| 115 | + "coords": np.array([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]]), |
| 116 | + "cells": np.zeros((1, 3, 3)), |
| 117 | + "orig": np.zeros(3), |
| 118 | + "nopbc": True, |
| 119 | + } |
| 120 | + ) |
| 121 | + |
| 122 | + with tempfile.NamedTemporaryFile(suffix=".xyz", mode="w+") as f: |
| 123 | + # Write basic xyz using dpdata |
| 124 | + simple_system.to("xyz", f.name) |
| 125 | + |
| 126 | + # Read with ASE |
| 127 | + atoms = read(f.name, format="xyz") |
| 128 | + |
| 129 | + # Verify basic structure |
| 130 | + self.assertEqual(len(atoms), 2) |
| 131 | + self.assertEqual(atoms.get_chemical_symbols(), ["C", "O"]) |
| 132 | + |
| 133 | + # Check positions |
| 134 | + np.testing.assert_allclose( |
| 135 | + atoms.get_positions(), [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], rtol=1e-6 |
| 136 | + ) |
| 137 | + |
| 138 | + |
| 139 | +@unittest.skipIf(skip_ase, "skip ASE related test. install ASE to fix") |
| 140 | +class TestExtXYZEnergyForceCompatibility(unittest.TestCase, CompLabeledSys): |
| 141 | + """Test energy and force preservation between dpdata and ASE using CompLabeledSys.""" |
| 142 | + |
| 143 | + def setUp(self): |
| 144 | + # Set precision for CompLabeledSys |
| 145 | + self.places = 6 |
| 146 | + self.e_places = 6 |
| 147 | + self.f_places = 6 |
| 148 | + self.v_places = 4 |
| 149 | + |
| 150 | + # Create a manually written extxyz content with known energies and forces |
| 151 | + extxyz_content = """2 |
| 152 | +energy=-10.5 Lattice="5.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 5.0" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 |
| 153 | +C 0.0 1.0 2.0 6 0.1 0.1 0.1 |
| 154 | +O 3.0 4.0 5.0 8 -0.1 -0.1 -0.1 |
| 155 | +""" |
| 156 | + |
| 157 | + # Write the extxyz content to a file |
| 158 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: |
| 159 | + f.write(extxyz_content) |
| 160 | + f.flush() |
| 161 | + self.temp_file = f.name |
| 162 | + |
| 163 | + # Read with dpdata - this is our reference system |
| 164 | + multi_systems = dpdata.MultiSystems.from_file(self.temp_file, fmt="extxyz") |
| 165 | + system_key = list(multi_systems.systems.keys())[0] |
| 166 | + self.system_1 = multi_systems.systems[system_key] |
| 167 | + |
| 168 | + # Read with ASE |
| 169 | + atoms = read(self.temp_file, format="extxyz") |
| 170 | + |
| 171 | + # Write back to extxyz with ASE |
| 172 | + with tempfile.NamedTemporaryFile(suffix=".xyz", mode="w+", delete=False) as f2: |
| 173 | + self.temp_file2 = f2.name |
| 174 | + write(f2.name, atoms, format="extxyz") |
| 175 | + |
| 176 | + # Read back the ASE-written file with dpdata |
| 177 | + roundtrip_ms = dpdata.MultiSystems.from_file(self.temp_file2, fmt="extxyz") |
| 178 | + system_key = list(roundtrip_ms.systems.keys())[0] |
| 179 | + self.system_2 = roundtrip_ms.systems[system_key] |
| 180 | + |
| 181 | + def tearDown(self): |
| 182 | + import os |
| 183 | + |
| 184 | + try: |
| 185 | + os.unlink(self.temp_file) |
| 186 | + os.unlink(self.temp_file2) |
| 187 | + except (OSError, AttributeError): |
| 188 | + pass |
0 commit comments