import numpy as np import pytest from pandas.compat import HAS_PYARROW from pandas.core.dtypes.cast import find_common_type import pandas as pd import pandas._testing as tm from pandas.util.version import Version @pytest.mark.parametrize( "to_concat_dtypes, result_dtype", [ # same types ([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)), ([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)), ([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)), ([("python", np.nan), ("python", np.nan)], ("python", np.nan)), # pyarrow preference ([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)), # NA preference ([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)), ], ) def test_concat_series(request, to_concat_dtypes, result_dtype): if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW: pytest.skip("Could not import 'pyarrow'") ser_list = [ pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value)) for storage, na_value in to_concat_dtypes ] result = pd.concat(ser_list, ignore_index=True) expected = pd.Series( ["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype) ) tm.assert_series_equal(result, expected) # order doesn't matter for result result = pd.concat(ser_list[::1], ignore_index=True) tm.assert_series_equal(result, expected) def test_concat_with_object(string_dtype_arguments): # _get_common_dtype cannot inspect values, so object dtype with strings still # results in object dtype result = pd.concat( [ pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)), pd.Series(["a", "b", None], dtype=object), ] ) assert result.dtype == np.dtype("object") def test_concat_with_numpy(string_dtype_arguments): # common type with a numpy string dtype always preserves the pandas string dtype dtype = pd.StringDtype(*string_dtype_arguments) assert find_common_type([dtype, np.dtype("U")]) == dtype assert find_common_type([np.dtype("U"), dtype]) == dtype assert find_common_type([dtype, np.dtype("U10")]) == dtype assert find_common_type([np.dtype("U10"), dtype]) == dtype # with any other numpy dtype -> object assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object") assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object") if Version(np.__version__) >= Version("2"): assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype