From 726a5faa3233bd5c7efb03f5b2e72e676274a8ba Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Date: Fri, 3 Jan 2025 11:08:06 +0100
Subject: [PATCH] TST (string): clean-up interpolate tests

---
 .../tests/frame/methods/test_interpolate.py   | 73 +++++++++++++------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/pandas/tests/frame/methods/test_interpolate.py b/pandas/tests/frame/methods/test_interpolate.py
index 09d1cc9a479b2..388000b3cce52 100644
--- a/pandas/tests/frame/methods/test_interpolate.py
+++ b/pandas/tests/frame/methods/test_interpolate.py
@@ -1,8 +1,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 import pandas.util._test_decorators as td
 
 from pandas import (
@@ -64,7 +62,7 @@ def test_interpolate_inplace(self, frame_or_series, request):
         assert np.shares_memory(orig, obj.values)
         assert orig.squeeze()[1] == 1.5
 
-    def test_interp_basic(self, using_infer_string):
+    def test_interp_with_non_numeric(self, using_infer_string):
         df = DataFrame(
             {
                 "A": [1, 2, np.nan, 4],
@@ -73,43 +71,74 @@ def test_interp_basic(self, using_infer_string):
                 "D": list("abcd"),
             }
         )
+        df_orig = df.copy()
+        expected = DataFrame(
+            {
+                "A": [1.0, 2.0, 3.0, 4.0],
+                "B": [1.0, 4.0, 9.0, 9.0],
+                "C": [1, 2, 3, 5],
+                "D": list("abcd"),
+            }
+        )
+
         dtype = "str" if using_infer_string else "object"
         msg = f"[Cc]annot interpolate with {dtype} dtype"
         with pytest.raises(TypeError, match=msg):
             df.interpolate()
+        tm.assert_frame_equal(df, df_orig)
 
-        cvalues = df["C"]._values
-        dvalues = df["D"].values
         with pytest.raises(TypeError, match=msg):
             df.interpolate(inplace=True)
+        # columns A and B already get interpolated before we hit the error
+        tm.assert_frame_equal(df, expected)
+
+    def test_interp_basic(self):
+        df = DataFrame(
+            {
+                "A": [1, 2, np.nan, 4],
+                "B": [1, 4, 9, np.nan],
+                "C": [1, 2, 3, 5],
+            }
+        )
+        df_orig = df.copy()
+        expected = DataFrame(
+            {
+                "A": [1.0, 2.0, 3.0, 4.0],
+                "B": [1.0, 4.0, 9.0, 9.0],
+                "C": [1, 2, 3, 5],
+            }
+        )
+        result = df.interpolate()
+        tm.assert_frame_equal(result, expected)
+
+        # check we didn't operate inplace GH#45791
+        tm.assert_frame_equal(df, df_orig)
+        bvalues = df["B"]._values
+        cvalues = df["C"]._values
+        assert not tm.shares_memory(bvalues, result["B"]._values)
+        assert tm.shares_memory(cvalues, result["C"]._values)
+
+        res = df.interpolate(inplace=True)
+        assert res is None
+        tm.assert_frame_equal(df, expected)
 
         # check we DID operate inplace
+        assert tm.shares_memory(df["B"]._values, bvalues)
         assert tm.shares_memory(df["C"]._values, cvalues)
-        assert tm.shares_memory(df["D"]._values, dvalues)
 
-    @pytest.mark.xfail(
-        using_string_dtype(), reason="interpolate doesn't work for string"
-    )
-    def test_interp_basic_with_non_range_index(self, using_infer_string):
+    def test_interp_basic_with_non_range_index(self):
         df = DataFrame(
             {
                 "A": [1, 2, np.nan, 4],
                 "B": [1, 4, 9, np.nan],
                 "C": [1, 2, 3, 5],
-                "D": list("abcd"),
             }
         )
-
-        msg = "DataFrame cannot interpolate with object dtype"
-        if not using_infer_string:
-            with pytest.raises(TypeError, match=msg):
-                df.set_index("C").interpolate()
-        else:
-            result = df.set_index("C").interpolate()
-            expected = df.set_index("C")
-            expected.loc[3, "A"] = 2.66667
-            expected.loc[5, "B"] = 9
-            tm.assert_frame_equal(result, expected)
+        result = df.set_index("C").interpolate()
+        expected = df.set_index("C")
+        expected.loc[3, "A"] = 2.66667
+        expected.loc[5, "B"] = 9
+        tm.assert_frame_equal(result, expected)
 
     def test_interp_empty(self):
         # https://github.com/pandas-dev/pandas/issues/35598