import pytest
import pandas as pd
import polars as pl
import numpy as np
# Import the function under test: uncomment the line below and replace `hogehoge` with the real module name
# from hogehoge import generate_lag_variable_with_group_key # Replace with actual module name
@pytest.fixture
def sample_pandas_df():
    """Provide a Pandas frame with two entities (A, B), three daily rows each.

    Rows are ordered by entity, then by date. The ``time`` column is parsed
    from ISO date strings into a datetime64 column.
    """
    # Parse the timestamps once, before building the frame.
    parsed_times = pd.to_datetime(
        ["2023-01-01", "2023-01-02", "2023-01-03",
         "2023-01-01", "2023-01-02", "2023-01-03"]
    )
    columns = {
        "entity": ["A", "A", "A", "B", "B", "B"],
        "time": parsed_times,
        "value": [10, 20, 30, 100, 200, 300],
    }
    return pd.DataFrame(columns)
@pytest.fixture
def sample_polars_df():
    """Provide a Polars frame with two entities (A, B), three daily rows each.

    Rows are ordered by entity, then by date. The ``time`` column is parsed
    from ISO strings into a ``pl.Date`` column (a calendar date, not a
    datetime).
    """
    raw = pl.DataFrame(
        {
            "entity": ["A", "A", "A", "B", "B", "B"],
            "time": ["2023-01-01", "2023-01-02", "2023-01-03",
                     "2023-01-01", "2023-01-02", "2023-01-03"],
            "value": [10, 20, 30, 100, 200, 300],
        }
    )
    # Use an explicit ISO format instead of relying on format inference.
    return raw.with_columns(pl.col("time").str.strptime(pl.Date, "%Y-%m-%d"))
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_generate_lag_variable_basic(df_type, sample_pandas_df, sample_polars_df):
    """Check a one-step lag within each entity for both Pandas and Polars input.

    The new column must be named ``value_1lag``. The first row of each entity
    has no predecessor, so its value is missing.
    """
    is_pandas = df_type == "pandas"
    source_frame = sample_pandas_df if is_pandas else sample_polars_df

    result = generate_lag_variable_with_group_key(
        df=source_frame,
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
    )
    assert "value_1lag" in result.columns, "Lagged column not found!"

    lagged = result["value_1lag"].to_list()
    if not is_pandas:
        # Polars output is expected as integers, with None for the missing lags.
        assert lagged == [None, 10, 20, None, 100, 200]
        return
    # Pandas output is expected as floats, with NaN for the missing lags.
    assert np.array_equal(lagged, [np.nan, 10.0, 20.0, np.nan, 100.0, 200.0], equal_nan=True)
def test_generate_lag_variable_custom_column(sample_pandas_df):
    """Test that ``lagged_col_name`` names the output column and holds the lag values.

    The original test checked only that the column exists. This version also
    checks the column's contents, so a correctly named but wrongly filled
    column is caught.
    """
    result = generate_lag_variable_with_group_key(
        df=sample_pandas_df,
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
        lagged_col_name="custom_lag",
    )
    assert "custom_lag" in result.columns, "Custom lag column name not applied!"

    # These are the same inputs as the basic Pandas case, so the renamed column
    # must carry the same within-entity one-step lag values.
    expected_values = [np.nan, 10.0, 20.0, np.nan, 100.0, 200.0]
    assert np.array_equal(result["custom_lag"].to_list(), expected_values, equal_nan=True)
def test_generate_lag_variable_with_descending_order(sample_pandas_df):
    """Test that ``ascending=False`` makes each row take the following date's value."""
    lag_kwargs = {
        "target_column": "value",
        "sort_key": ["time"],
        "group_key": ["entity"],
        "lag_size": 1,
        "ascending": False,  # Reverse sorting order
    }
    result = generate_lag_variable_with_group_key(df=sample_pandas_df, **lag_kwargs)

    # With time sorted newest-first, a row's "previous" row is the next calendar
    # day. The latest date of each entity therefore gets no lag value (NaN).
    # NOTE(review): these expectations list rows in the fixture's original order
    # (oldest first within each entity); confirm the function returns rows in
    # that order.
    expected = [20.0, 30.0, np.nan, 200.0, 300.0, np.nan]
    assert np.array_equal(result["value_1lag"].to_list(), expected, equal_nan=True)
def test_generate_lag_variable_invalid_input():
    """Passing something other than a DataFrame as ``df`` must raise TypeError."""
    bad_call = dict(
        df="not_a_dataframe",  # Invalid type
        target_column="value",
        sort_key=["time"],
        group_key=["entity"],
        lag_size=1,
    )
    # Only the function call sits inside the context manager.
    with pytest.raises(TypeError):
        generate_lag_variable_with_group_key(**bad_call)