import numpy as np
import plotly.graph_objects as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
# Compare the ordinary regression (Y on X) with the inverse regression
# (X on Y) for two bivariate-normal samples with correlation 0.4 and 0.95.
# Both fitted lines are drawn in the same (x, y) plane, so the gap between
# them shows how much the choice of regressand matters at each correlation.

N = 500  # samples per panel; the original script used N without defining it
# NOTE: no random seed is set, so every run draws a different sample.
X1, Y1 = np.random.multivariate_normal([0, 0], [[1, 0.4], [0.4, 1]], N).T
X2, Y2 = np.random.multivariate_normal([0, 0], [[1, 0.95], [0.95, 1]], N).T

# Common x and y ranges: used both as the axis limits and as the two end
# points at which each (straight) regression line is evaluated.
x_range = np.array([-3, 3])
y_range = np.array([-3, 3])

# Regression in both directions for each sample.
result_1_yx = sm.OLS(Y1, sm.add_constant(X1)).fit()  # Y1 = a + b * X1
result_1_xy = sm.OLS(X1, sm.add_constant(Y1)).fit()  # X1 = a + b * Y1
result_2_yx = sm.OLS(Y2, sm.add_constant(X2)).fit()  # Y2 = a + b * X2
result_2_xy = sm.OLS(X2, sm.add_constant(Y2)).fit()  # X2 = a + b * Y2


def _fitted_line(result, grid):
    """Return ``intercept + slope * grid`` for a one-regressor OLS fit.

    Gives the same values as ``result.predict(sm.add_constant(grid))`` but
    does not rely on ``add_constant`` deciding that ``grid`` is non-constant.
    """
    intercept, slope = result.params
    return intercept + slope * grid


def _add_panel(fig, x, y, result_yx, result_xy, legend_name, col):
    """Add one subplot: the sample scatter plus both regression lines.

    ``result_yx`` (y regressed on x) is drawn over the module-level
    ``x_range``; ``result_xy`` (x regressed on y) is drawn over ``y_range``
    with its output on the x axis, so both lines share the same plane.
    The hover name of each line shows its fitted slope.
    """
    # mode="markers" is required: for 20+ points plotly's default Scatter
    # mode is "lines", which would join the random samples into a zigzag.
    fig.add_trace(
        go.Scatter(x=x, y=y, mode="markers", name=legend_name), row=1, col=col
    )
    fig.add_trace(
        go.Scatter(
            x=x_range,
            y=_fitted_line(result_yx, x_range),
            mode="lines",
            line=dict(color="gray"),
            showlegend=False,
            name=f"coef: {result_yx.params[1]:.2f}",
        ),
        row=1,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=_fitted_line(result_xy, y_range),
            y=y_range,
            mode="lines",
            line=dict(color="gray"),
            showlegend=False,
            name=f"coef: {result_xy.params[1]:.2f}",
        ),
        row=1,
        col=col,
    )


fig = make_subplots(rows=1, cols=2)
_add_panel(fig, X1, Y1, result_1_yx, result_1_xy, "相関係数:0.4", col=1)
_add_panel(fig, X2, Y2, result_2_yx, result_2_xy, "相関係数:0.95", col=2)
fig.update_layout(title='相関係数水準に応じた回帰係数と逆回帰係数の比較')

# Lock a 1:1 aspect ratio in each panel so the slopes can be compared by eye.
# make_subplots names column 1's axes "x"/"y" and column 2's "x2"/"y2".
for col, suffix in ((1, ""), (2, "2")):
    fig.update_xaxes(
        range=x_range.tolist(), scaleanchor=f"y{suffix}", scaleratio=1, row=1, col=col
    )
    fig.update_yaxes(
        range=y_range.tolist(), scaleanchor=f"x{suffix}", scaleratio=1, row=1, col=col
    )
fig.show()