Group毎に定義されたラグ変数の作成

python
前処理
Author

Ryo Nakagami

Published

2025-02-19

問題設定: ラグ変数の作成

Exercise 1

次のようなpandas.DataFrameを考えます

Code
import pandas as pd
import polars as pl

# Toy panel data: entity A has timestamps that repeat across its two `state`
# values, and the rows are deliberately not in chronological order, so the
# lag must be computed after sorting inside each (entity_col, state) group.
df = pd.DataFrame.from_dict(
    {
        "entity_col": ["A"] * 6 + ["B"] * 3 + ["C"],
        "state": [1, 0, 1, 1, 1, 0, 0, 0, 0, 1],
        "time_col": pd.to_datetime(
            [
                "2021-01-02",
                "2021-01-03",
                "2021-01-04",
                "2021-01-01",
                "2021-01-03",
                "2021-01-04",
                "2021-02-02",
                "2021-02-03",
                "2021-02-10",
                "2021-01-02",
            ]
        ),
        "temp": [1, 2, 11, 13, 12, 14, 10, 9, 8, 0],
    }
)

df
entity_col state time_col temp
0 A 1 2021-01-02 1
1 A 0 2021-01-03 2
2 A 1 2021-01-04 11
3 A 1 2021-01-01 13
4 A 1 2021-01-03 12
5 A 0 2021-01-04 14
6 B 0 2021-02-02 10
7 B 0 2021-02-03 9
8 B 0 2021-02-10 8
9 C 1 2021-01-02 0

▶  実施したい処理

  • (entity_col, state) を Group Keys として, 各 Group 内で time_col の順序に従って temp カラムのラグ変数を作成したい
  • ラグ変数の名前はデフォルトでは f"{target_column}_{lag_size}lag"（lag_size=1 なら temp_1lag）とし, 指定があった場合はその名前に従う

前処理関数の実装

Solution: generate_lag_variable_with_group_key
from typing import Optional


def generate_lag_variable_with_group_key(
    df: pd.DataFrame | pl.DataFrame,
    target_column: str,
    sort_key: list[str],
    group_key: list[str],
    lag_size: int,
    ascending: list[bool] | bool = True,
    lagged_col_name: Optional[str] = None,
) -> pd.DataFrame | pl.DataFrame:
    """
    Generate a lagged variable within groups of a DataFrame.

    Rows are first ordered by ``sort_key`` (honouring ``ascending``). Then
    ``target_column`` is shifted by ``lag_size`` rows independently inside
    each group defined by ``group_key``, so a lag never crosses a group
    boundary. The result is returned sorted ascending by
    ``group_key + sort_key``. Both pandas and polars DataFrames are supported,
    and the input frame itself is not modified.

    Parameters:
        df (pd.DataFrame | pl.DataFrame):
            The input DataFrame, either pandas or polars.

        target_column (str):
            The name of the column to generate the lagged variable from.

        sort_key (list[str] | str):
            Column(s) that define the row order inside each group before
            shifting. Rows tied on these columns keep their input order.

        group_key (list[str] | str):
            Column(s) that define the groups. Rows with a missing key value
            form a group of their own in both backends.

        lag_size (int):
            The number of rows to shift. A negative value produces a lead.

        ascending (list[bool] | bool, optional):
            The sort order for each column in sort_key. Defaults to True.

        lagged_col_name (Optional[str], optional):
            The name of the new lagged column. If None, defaults to
            "{target_column}_{lag_size}lag".

    Returns:
        pd.DataFrame | pl.DataFrame: A new DataFrame of the same type with the
        lagged column appended. Missing lags are NaN for pandas and null for
        polars. The pandas result has a fresh RangeIndex.

    Raises:
        TypeError: If the input DataFrame is neither pandas.DataFrame nor polars.DataFrame.
    """
    # Accept a single column name as well as any sequence of names. Build
    # fresh lists so the caller's arguments are never touched.
    sort_key = [sort_key] if isinstance(sort_key, str) else list(sort_key)
    group_key = [group_key] if isinstance(group_key, str) else list(group_key)

    if lagged_col_name is None:
        lagged_col_name = f"{target_column}_{lag_size}lag"

    # Output order: rows of the same group together, then by sort_key.
    # This is always ascending, regardless of `ascending`.
    result_sort_key = group_key + sort_key

    if isinstance(df, pd.DataFrame):
        # Use a stable sort so rows that tie on sort_key keep their input order
        # and the lag assignment is deterministic. pandas' default quicksort is
        # not stable when sorting by a single column.
        df_sorted = df.sort_values(by=sort_key, ascending=ascending, kind="stable")

        # groupby().shift() returns a Series aligned on df_sorted's index, so
        # each lag lands on the right row. dropna=False keeps rows whose group
        # key is missing in a group of their own, matching polars' `.over()`.
        # With the default dropna=True, those rows would silently get NaN.
        df_sorted[lagged_col_name] = df_sorted.groupby(group_key, dropna=False)[
            target_column
        ].shift(lag_size)

        df_result = df_sorted.sort_values(
            by=result_sort_key, kind="stable"
        ).reset_index(drop=True)

    elif isinstance(df, pl.DataFrame):
        # polars expresses sort direction as `descending`, the negation of `ascending`.
        if isinstance(ascending, (list, tuple)):
            descending = [not flag for flag in ascending]
        else:
            descending = not ascending

        # maintain_order=True gives the same stable-sort guarantee as the pandas branch.
        df_sorted = df.sort(sort_key, descending=descending, maintain_order=True)

        # .over(group_key) evaluates the shift separately inside each group
        # while keeping the current (sorted) row order.
        df_sorted = df_sorted.with_columns(
            pl.col(target_column)
            .shift(lag_size)
            .over(group_key)
            .alias(lagged_col_name)
        )
        df_result = df_sorted.sort(result_sort_key, maintain_order=True)

    else:
        raise TypeError(
            f"type(df) is {type(df)}: df should be pandas.DataFrame or polars.DataFrame"
        )

    return df_result

挙動確認

▶  pandas.DataFrame

# Lag "temp" by one row inside each (entity_col, state) group, ordered by time_col.
generate_lag_variable_with_group_key(
    df=df,
    group_key=["entity_col", "state"],
    sort_key=["time_col"],
    ascending=[True],
    target_column="temp",
    lag_size=1,
)
entity_col state time_col temp temp_1lag
0 A 0 2021-01-03 2 NaN
1 A 0 2021-01-04 14 2.0
2 A 1 2021-01-01 13 NaN
3 A 1 2021-01-02 1 13.0
4 A 1 2021-01-03 12 1.0
5 A 1 2021-01-04 11 12.0
6 B 0 2021-02-02 10 NaN
7 B 0 2021-02-03 9 10.0
8 B 0 2021-02-10 8 9.0
9 C 1 2021-01-02 0 NaN

▶  polars.DataFrame

# Same transformation on a polars copy of the frame (default ascending=True).
df_polars = pl.DataFrame(df)
generate_lag_variable_with_group_key(
    df=df_polars,
    group_key=["entity_col", "state"],
    sort_key=["time_col"],
    target_column="temp",
    lag_size=1,
)
shape: (10, 5)
entity_col state time_col temp temp_1lag
str i64 datetime[ns] i64 i64
"A" 0 2021-01-03 00:00:00 2 null
"A" 0 2021-01-04 00:00:00 14 2
"A" 1 2021-01-01 00:00:00 13 null
"A" 1 2021-01-02 00:00:00 1 13
"A" 1 2021-01-03 00:00:00 12 1
"A" 1 2021-01-04 00:00:00 11 12
"B" 0 2021-02-02 00:00:00 10 null
"B" 0 2021-02-03 00:00:00 9 10
"B" 0 2021-02-10 00:00:00 8 9
"C" 1 2021-01-02 00:00:00 0 null

Unit test with pytest

📘 テスト方針

▶  Test Examples

import pytest
import pandas as pd
import polars as pl
import numpy as np

# Import the function to be tested
# from hogehoge import generate_lag_variable_with_group_key  # Replace with actual module name

@pytest.fixture
def sample_pandas_df():
    """Two entities (A, B), each observed on the same three days, as a pandas DataFrame."""
    # Parse the dates up front so the "time" column is datetime64 rather than str.
    timestamps = pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-03"] * 2)
    return pd.DataFrame(
        {
            "entity": ["A"] * 3 + ["B"] * 3,
            "time": timestamps,
            "value": [10, 20, 30, 100, 200, 300],
        }
    )

@pytest.fixture
def sample_polars_df():
    """Fixture to provide a sample Polars DataFrame (same rows as the pandas fixture)."""
    return pl.DataFrame({
        "entity": ["A", "A", "A", "B", "B", "B"],
        "time": ["2023-01-01", "2023-01-02", "2023-01-03",
                 "2023-01-01", "2023-01-02", "2023-01-03"],
        "value": [10, 20, 30, 100, 200, 300]
    }).with_columns(
        # Parses to pl.Date (not Datetime). Only the ordering matters for the lag.
        # An explicit format avoids relying on polars' format inference.
        pl.col("time").str.to_date(format="%Y-%m-%d")
    )

@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_generate_lag_variable_basic(df_type, sample_pandas_df, sample_polars_df):
    """A one-row lag is computed per entity, for both pandas and polars input."""
    frame = {"pandas": sample_pandas_df, "polars": sample_polars_df}[df_type]
    lagged = generate_lag_variable_with_group_key(
        df=frame,
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
    )

    assert "value_1lag" in lagged.columns, "Lagged column not found!"
    observed = lagged["value_1lag"].to_list()

    if df_type != "pandas":
        # polars keeps the integer dtype and marks the first row of each group with None.
        assert observed == [None, 10, 20, None, 100, 200]
        return

    # pandas upcasts to float and uses NaN, which never equals itself,
    # so compare with equal_nan=True.
    assert np.array_equal(
        observed, [np.nan, 10.0, 20.0, np.nan, 100.0, 200.0], equal_nan=True
    )

def test_generate_lag_variable_custom_column(sample_pandas_df):
    """A custom lagged_col_name replaces the default name and holds the lagged values."""
    result = generate_lag_variable_with_group_key(
        df=sample_pandas_df,
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
        lagged_col_name="custom_lag"
    )
    assert "custom_lag" in result.columns, "Custom lag column name not applied!"
    # The default name must not be produced in addition to the custom one.
    assert "value_1lag" not in result.columns, "Default lag column should not be created!"
    # The custom column must carry the same values a default-named one would.
    assert np.array_equal(
        result["custom_lag"].to_list(),
        [np.nan, 10.0, 20.0, np.nan, 100.0, 200.0],
        equal_nan=True,
    )

@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_generate_lag_variable_with_descending_order(df_type, sample_pandas_df, sample_polars_df):
    """With ascending=False the lag refers to the later row; checked for both backends."""
    df = sample_pandas_df if df_type == "pandas" else sample_polars_df
    result = generate_lag_variable_with_group_key(
        df=df,
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
        ascending=False  # Reverse sorting order
    )
    # The output is re-sorted ascending by (entity, time), so each row's lag is
    # the value at the next (later) timestamp, and the latest row has no lag.
    if df_type == "pandas":
        expected_values = [20.0, 30.0, np.nan, 200.0, 300.0, np.nan]
        assert np.array_equal(result["value_1lag"].to_list(), expected_values, equal_nan=True)
    else:
        assert result["value_1lag"].to_list() == [20, 30, None, 200, 300, None]

def test_generate_lag_variable_invalid_input():
    """Anything other than a pandas/polars DataFrame is rejected with TypeError."""
    bad_call_kwargs = dict(
        df="not_a_dataframe",  # a plain str is neither pandas nor polars
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
    )
    with pytest.raises(TypeError):
        generate_lag_variable_with_group_key(**bad_call_kwargs)

注意点 !

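  • np.nan != np.nan であるため, np.nan を含むリストを == で直接比較しても期待通りには機能しません。pandas の to_list() が返す NaN は np.nan とは別のオブジェクトなので, リスト比較での同一性チェックにも頼れません。そのため上のテストでは np.array_equal(..., equal_nan=True) を使い, NaN 同士を等しいとみなして比較しています
  • NaN は IEEE 754 で「自分自身を含むどの値とも等しくない」と定義されているため, NaN は自身と等しくない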