Skip to content

Commit 0eaf60a

Browse files
Improve subsampling (#355)
* Move to pyproject and fix subsampling * Remove pyproject toml and add versioning * Make deterministic
1 parent cdc3f0d commit 0eaf60a

File tree

3 files changed

+77
-18
lines changed

3 files changed

+77
-18
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ docs/_build/
1212
.env
1313
.venv
1414
.vscode
15-
venv
15+
venv
16+
.uv/

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Subsampling logic.

policyengine_core/simulations/simulation.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,8 +1501,33 @@ def to_input_dataframe(
15011501

15021502
return df
15031503

1504+
def to_input_dict(self) -> dict:
    """Export the simulation's known values as a nested dictionary.

    Walks every variable name in ``self.tax_benefit_system.variables`` and,
    for each period its holder reports as known, calculates the variable
    with ``map_to="person"`` and stores the result as a plain list. The
    output is intended to be loaded back into a new Simulation to
    reproduce the same results.

    Returns:
        dict: Mapping of variable name to ``{str(period): list_of_values}``.
            Variables that end up with no exported periods are omitted.
    """
    exported = {}

    for name in self.tax_benefit_system.variables:
        by_period = {}
        for known_period in self.get_holder(name).get_known_periods():
            result = self.calculate(name, known_period, map_to="person")
            if result is None:
                # Nothing to export for this period.
                continue
            by_period[str(known_period)] = result.tolist()

        # Only keep variables that actually produced at least one period.
        if by_period:
            exported[name] = by_period

    return exported
1523+
15041524
def subsample(
1505-
self, n=None, frac=None, seed=None, time_period=None
1525+
self,
1526+
n=None,
1527+
frac=None,
1528+
seed=None,
1529+
time_period=None,
1530+
quantize_weights: bool = True,
15061531
) -> "Simulation":
15071532
"""Quantize the simulation to a smaller size by sampling households.
15081533
@@ -1515,6 +1540,7 @@ def subsample(
15151540
Returns:
15161541
Simulation: The quantized simulation.
15171542
"""
1543+
default_calculation_period = self.default_calculation_period
15181544
# Set default key if not provided
15191545
if seed is None:
15201546
seed = self.dataset.name
@@ -1529,6 +1555,7 @@ def subsample(
15291555
# Extract time period from DataFrame columns
15301556
df_time_period = df.columns.values[0].split("__")[1]
15311557
df_household_id_column = f"household_id__{df_time_period}"
1558+
df_person_id_column = f"person_id__{df_time_period}"
15321559

15331560
# Determine the appropriate household weight column
15341561
if f"household_weight__{time_period}" in df.columns:
@@ -1545,34 +1572,59 @@ def subsample(
15451572
n = int(len(h_ids) * frac)
15461573
h_weights = pd.Series(h_df[household_weight_column].values)
15471574

1548-
if n > len(h_weights):
1549-
# Don't need to subsample!
1550-
return self
1575+
frac = n / len(h_ids)
15511576

15521577
# Seed the random number generators for reproducibility
15531578
random.seed(str(seed))
15541579
state = random.randint(0, 2**32 - 1)
15551580
np.random.seed(state)
15561581

1582+
h_ids = h_ids[h_weights > 0]
1583+
h_weights = h_weights[h_weights > 0]
1584+
15571585
# Sample household IDs based on their weights
1558-
chosen_household_ids = np.random.choice(
1559-
h_ids,
1560-
n,
1561-
p=h_weights.values / h_weights.values.sum(),
1562-
replace=False,
1586+
chosen_household_ids = pd.Series(
1587+
np.random.choice(
1588+
h_ids,
1589+
n,
1590+
p=(
1591+
h_weights.values / h_weights.values.sum()
1592+
if quantize_weights
1593+
else None
1594+
),
1595+
replace=True,
1596+
)
15631597
)
15641598

1565-
# Filter DataFrame to include only the chosen households
1566-
df = df[df[df_household_id_column].isin(chosen_household_ids)]
1599+
household_id_to_count = {}
1600+
for household_id in chosen_household_ids:
1601+
if household_id not in household_id_to_count:
1602+
household_id_to_count[household_id] = 0
1603+
household_id_to_count[household_id] += 1
15671604

1568-
# Adjust household weights to maintain the total weight
1569-
df[household_weight_column] *= (
1570-
h_weights.sum()
1571-
/ df.groupby(df_household_id_column)
1572-
.first()[household_weight_column]
1573-
.sum()
1605+
subset_df = df[
1606+
df[df_household_id_column].isin(chosen_household_ids)
1607+
].copy()
1608+
1609+
household_counts = subset_df[df_household_id_column].map(
1610+
lambda x: household_id_to_count.get(x, 0)
15741611
)
15751612

1613+
# Adjust household weights to maintain the total weight
1614+
1615+
for col in subset_df.columns:
1616+
if "weight__" in col:
1617+
target_total_weight = df[col].values.sum()
1618+
if not quantize_weights:
1619+
subset_df[col] *= household_counts.values
1620+
else:
1621+
subset_df[col] = household_counts.values
1622+
subset_df[col] *= (
1623+
target_total_weight / subset_df[col].values.sum()
1624+
)
1625+
1626+
df = subset_df
1627+
15761628
# Update the dataset and rebuild the simulation
15771629
self.dataset = Dataset.from_dataframe(df, self.dataset.time_period)
15781630
self.build_from_dataset()
@@ -1584,6 +1636,8 @@ def subsample(
15841636
].tax_benefit_system
15851637
self.branches["baseline"] = self.clone()
15861638
self.branches["tax_benefit_system"] = baseline_tax_benefit_system
1639+
1640+
self.default_calculation_period = default_calculation_period
15871641
return self
15881642

15891643

0 commit comments

Comments
 (0)