Skip to content

Commit 7d3b6fe

Browse files
authored
Merge pull request #253 from SMTG-Bham/fix-bandstats
Fix band stats k-point labels
2 parents 9ec6740 + 5c0d625 commit 7d3b6fe

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

sumo/cli/bandplot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def bandplot(
295295
if code == "vasp":
296296
for vr_file in filenames:
297297
vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
298-
bs = vr.get_band_structure(line_mode=True)
298+
bs = vr.get_band_structure(line_mode=True, efermi="smart")
299299
bandstructures.append(bs)
300300
bs = get_reconstructed_band_structure(bandstructures)
301301
elif code == "castep":

sumo/cli/bandstats.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,24 +94,26 @@ def bandstats(
9494
bandstructures = []
9595
for vr_file in filenames:
9696
vr = BSVasprun(vr_file, parse_projected_eigen=False)
97-
bs = vr.get_band_structure(line_mode=True)
97+
bs = vr.get_band_structure(line_mode=True, efermi="smart")
9898
bandstructures.append(bs)
99-
bs = get_reconstructed_band_structure(bandstructures, force_kpath_branches=False)
99+
bs, kpt_mapping = get_reconstructed_band_structure(
100+
bandstructures, force_kpath_branches=True, return_forced_branch_kpt_map=True
101+
)
100102

101103
if bs.is_metal():
102104
logging.error("ERROR: System is metallic!")
103105
sys.exit()
104106

105-
_log_band_gap_information(bs)
107+
_log_band_gap_information(bs, kpt_mapping=kpt_mapping)
106108

107109
vbm_data = bs.get_vbm()
108110
cbm_data = bs.get_cbm()
109111

110112
logging.info("\nValence band maximum:")
111-
_log_band_edge_information(bs, vbm_data)
113+
_log_band_edge_information(bs, vbm_data, kpt_mapping=kpt_mapping)
112114

113115
logging.info("\nConduction band minimum:")
114-
_log_band_edge_information(bs, cbm_data)
116+
_log_band_edge_information(bs, cbm_data, kpt_mapping=kpt_mapping)
115117

116118
if parabolic:
117119
logging.info("\nUsing parabolic fitting of the band edges")
@@ -179,11 +181,14 @@ def bandstats(
179181
return {"hole_data": hole_data, "electron_data": elec_data}
180182

181183

182-
def _log_band_gap_information(bs):
184+
def _log_band_gap_information(bs, kpt_mapping=None):
183185
"""Log data about the direct and indirect band gaps.
184186
185187
Args:
186188
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
189+
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
190+
band structure with forced branches to the original band structure.
191+
187192
"""
188193
bg_data = bs.get_band_gap()
189194
if not bg_data["direct"]:
@@ -199,6 +204,7 @@ def _log_band_gap_information(bs):
199204
direct_kpoint = bs.kpoints[direct_kindex].frac_coords
200205
direct_kpoint = kpt_str.format(k=direct_kpoint)
201206
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
207+
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
202208
k_indices = ", ".join(map(str, eq_kpoints))
203209

204210
# add 1 to band indices to be consistent with VASP band numbers.
@@ -215,7 +221,9 @@ def _log_band_gap_information(bs):
215221

216222
direct_kindex = direct_data[Spin.up]["kpoint_index"]
217223
direct_kpoint = kpt_str.format(k=bs.kpoints[direct_kindex].frac_coords)
218-
k_indices = ", ".join(map(str, bs.get_equivalent_kpoints(direct_kindex)))
224+
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
225+
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
226+
k_indices = ", ".join(map(str, eq_kpoints))
219227
b_indices = ", ".join(
220228
[str(i + 1) for i in direct_data[Spin.up]["band_indices"]]
221229
)
@@ -225,14 +233,16 @@ def _log_band_gap_information(bs):
225233
logging.info(f" Band indices: {b_indices}")
226234

227235

228-
def _log_band_edge_information(bs, edge_data):
236+
def _log_band_edge_information(bs, edge_data, kpt_mapping=None):
229237
"""Log data about the valence band maximum or conduction band minimum.
230238
231239
Args:
232240
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
233241
The band structure.
234242
edge_data (dict): The :obj:`dict` from ``bs.get_vbm()`` or
235243
``bs.get_cbm()``
244+
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
245+
band structure with forced branches to the original band structure.
236246
"""
237247
if bs.is_spin_polarized:
238248
spins = edge_data["band_index"].keys()
@@ -247,7 +257,9 @@ def _log_band_edge_information(bs, edge_data):
247257

248258
kpoint = edge_data["kpoint"]
249259
kpoint_str = kpt_str.format(k=kpoint.frac_coords)
250-
k_indices = ", ".join(map(str, edge_data["kpoint_index"]))
260+
k_indices = ", ".join(
261+
map(str, _map_kpoints(edge_data["kpoint_index"], kpt_mapping))
262+
)
251263
k_degen = bs.get_kpoint_degeneracy(kpoint=kpoint.frac_coords)
252264

253265
if kpoint.label:
@@ -311,6 +323,13 @@ def _log_effective_mass_data(data, is_spin_polarized, mass_type="m_e"):
311323
logging.info(f" {mass_type}: {eff_mass:.3f} | {band_str} | {kpoint_str}")
312324

313325

326+
def _map_kpoints(kpt_idxs, kpt_mapping):
327+
"""Map k-point indices to the original band structure."""
328+
if not kpt_mapping:
329+
return kpt_idxs
330+
return sorted(set([kpt_mapping.get(k, k) for k in kpt_idxs]))
331+
332+
314333
def _get_parser():
315334
parser = argparse.ArgumentParser(
316335
description="""

sumo/electronic_structure/bandstructure.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def get_projections(bs, selection, normalise=None):
190190
return spec_proj
191191

192192

193-
def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=True):
193+
def get_reconstructed_band_structure(
194+
list_bs, efermi=None, force_kpath_branches=True, return_forced_branch_kpt_map=False
195+
):
194196
"""Combine a list of band structures into a single band structure.
195197
196198
This is typically very useful when you split non self consistent
@@ -210,12 +212,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
210212
across all band structures is used.
211213
force_kpath_branches (bool): Force a linemode band structure to contain
212214
branches by adding repeated high-symmetry k-points in the path.
215+
return_forced_branch_kpt_map (bool): If True, return a mapping of the
216+
the new k-points to the original k-points.
213217
214218
Returns:
215219
:obj:`pymatgen.electronic_structure.bandstructure.BandStructure` or \
216220
:obj:`pymatgen.electronic_structure.bandstructureBandStructureSymmLine`:
217221
A band structure object. The type depends on the type of the band
218222
structures in ``list_bs``.
223+
If return_forced_branch_kpt_map is True, then a tuple is returned
224+
containing the band structure and the mapping from the new k-points
225+
to the original k-points.
219226
"""
220227
if efermi is None:
221228
efermi = sum(b.efermi for b in list_bs) / len(list_bs)
@@ -244,13 +251,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
244251
structure=list_bs[0].structure,
245252
projections=projections,
246253
)
247-
if force_kpath_branches:
248-
return force_branches(bs)
249-
else:
250-
return bs
254+
branch_bs, mapping = force_branches(bs, return_mapping=True)
255+
if force_kpath_branches and return_forced_branch_kpt_map:
256+
return branch_bs, mapping
257+
elif force_kpath_branches:
258+
return branch_bs
259+
elif return_forced_branch_kpt_map:
260+
return bs, mapping
261+
return bs
251262

252263

253-
def force_branches(bandstructure):
264+
def force_branches(bandstructure, return_mapping=False):
254265
"""Force a linemode band structure to contain branches.
255266
256267
Branches give a specific portion of the path from one high-symmetry point
@@ -262,9 +273,14 @@ def force_branches(bandstructure):
262273
263274
Args:
264275
bandstructure: A band structure object.
276+
return_mapping: If True, return a mapping of the new k-points (with branches)
277+
to the original k-points.
265278
266279
Returns:
267-
A band structure with brnaches.
280+
A band structure with branches.
281+
If return_forced_branch_kpt_map is True, then a tuple is returned
282+
containing the band structure and the mapping from the new k-points
283+
to the original k-points.
268284
"""
269285
kpoints = np.array([k.frac_coords for k in bandstructure.kpoints])
270286
labels_dict = {k: v.frac_coords for k, v in bandstructure.labels_dict.items()}
@@ -275,6 +291,7 @@ def force_branches(bandstructure):
275291
# already.
276292
dup_ids = []
277293
high_sym_kpoints = tuple(map(tuple, labels_dict.values()))
294+
mapping = {}
278295
for i, k in enumerate(kpoints):
279296
dup_ids.append(i)
280297
if (
@@ -287,6 +304,7 @@ def force_branches(bandstructure):
287304
)
288305
):
289306
dup_ids.append(i)
307+
mapping[len(dup_ids) - 1] = i
290308

291309
kpoints = kpoints[dup_ids]
292310

@@ -297,7 +315,7 @@ def force_branches(bandstructure):
297315
if len(bandstructure.projections) != 0:
298316
projections[spin] = bandstructure.projections[spin][:, dup_ids]
299317

300-
return type(bandstructure)(
318+
bs = type(bandstructure)(
301319
kpoints,
302320
eigenvals,
303321
bandstructure.lattice_rec,
@@ -306,6 +324,9 @@ def force_branches(bandstructure):
306324
structure=bandstructure.structure,
307325
projections=projections,
308326
)
327+
if return_mapping:
328+
return bs, mapping
329+
return bs
309330

310331

311332
def string_to_spin(spin_string):

sumo/electronic_structure/dos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def load_dos(
9999
else:
100100
vr = vasprun
101101

102-
band = vr.get_band_structure()
102+
band = vr.get_band_structure(efermi="smart")
103103
dos = vr.complete_dos
104104

105105
dos, band = _scissor_dos(dos, band, scissor)

0 commit comments

Comments
 (0)