How to Implement Multiple Linear Regression Using Statsmodels Library in Python?
Conditions for Implementing Multiple Linear Regression Using the Statsmodels Library
Description: Multiple Linear Regression (MLR) is a statistical technique that models the relationship between two or more features and a response variable. It is used to predict a continuous dependent variable based on multiple independent variables. The statsmodels library in Python provides comprehensive tools for fitting and evaluating linear models, including MLR, and generating statistical summaries.
In MLR, the goal is to fit a linear relationship of the form: y = β0 + β1·x1 + β2·x2 + ⋯ + βn·xn + ϵ
where:
y is the dependent variable.
x1, x2, …, xn are the independent variables (predictors).
β0, β1, …, βn are the model coefficients.
ϵ is the error term.
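To make the equation concrete, the following minimal sketch generates synthetic data from a known linear relationship and recovers the coefficients by least squares with NumPy. The data, the "true" coefficient values, and the variable names are assumptions chosen purely for illustration.

```python
import numpy as np

# Synthetic data: the "true" coefficients below are assumptions for illustration.
rng = np.random.default_rng(0)
n = 200
x1 = rng.normal(size=n)
x2 = rng.normal(size=n)
beta_true = np.array([1.0, 2.0, -3.0])   # β0 (intercept), β1, β2
eps = rng.normal(scale=0.1, size=n)      # error term ϵ

# Design matrix with a leading column of ones so that β0 acts as the intercept
X = np.column_stack([np.ones(n), x1, x2])
y = X @ beta_true + eps                  # y = β0 + β1·x1 + β2·x2 + ϵ

# Least-squares estimate of the coefficients
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
print(beta_hat)
```

The estimates should land close to the assumed values because the noise is small relative to the signal. The column of ones plays the same role as the constant that `sm.add_constant` adds in statsmodels.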
Why Should We Use Multiple Linear Regression?
Prediction of Continuous Variables: MLR helps to predict a dependent variable that has a linear relationship with multiple independent variables.
Interpretation: It allows for understanding the impact of each predictor on the dependent variable.
Feature Selection: You can use MLR to determine which features (independent variables) contribute significantly to the model.
Statistical Testing: Through p-values and confidence intervals, you can statistically evaluate the significance of each feature.
Step-by-Step Process
Import Libraries: Import the necessary libraries, such as pandas and statsmodels for modeling, matplotlib and seaborn for visualization, and scikit-learn for error metrics.
Data Preparation: Load and clean the dataset. Ensure there are no missing values, and scale or normalize numerical variables if needed.
Explore the Data: Conduct exploratory data analysis (EDA) to understand relationships between variables and detect multicollinearity.
Define the Model: Use the OLS (Ordinary Least Squares) class from statsmodels to define the linear regression model.
Fit the Model: Fit the model using the fit() method to obtain regression coefficients and performance metrics.
Model Evaluation: Check the model summary for p-values, R-squared values, and residuals to assess model performance and significance.
Interpret Results: Interpret the coefficients, significance levels, and predictions.
Sample Code
# Step 1: Import Libraries
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
# Step 2: Load Data (mtcars is downloaded from the Rdatasets repository, so this step needs internet access)
data = sm.datasets.get_rdataset("mtcars").data
# Step 3: Prepare the Data
# Define independent variables (predictors) and dependent variable (target)
X = data[['mpg', 'hp', 'wt']]  # Example features (miles per gallon, horsepower, weight)
y = data['qsec'] # Example target variable (quarter mile time)
# Add constant to the independent variables matrix (intercept)
X = sm.add_constant(X)
# Step 4: Define and Fit the Model
model = sm.OLS(y, X) # Ordinary Least Squares regression
results = model.fit() # Fit the model
# Step 5: Model Summary
print(results.summary())
# Step 6: Visualize the Residuals
plt.figure(figsize=(8, 6))
sns.residplot(x=results.fittedvalues, y=results.resid, lowess=True, line_kws={'color': 'red'})
plt.title('Residual Plot')
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.show()
# Step 7: Predictions
predictions = results.predict(X)
# Step 8: Prediction Metrics
mae = mean_absolute_error(y, predictions)
mse = mean_squared_error(y, predictions)
rmse = mse ** 0.5  # RMSE is the square root of MSE (the squared=False argument is not available in recent scikit-learn releases)
r2 = r2_score(y, predictions)
print(f"Mean Absolute Error (MAE): {mae}")
print(f"Mean Squared Error (MSE): {mse}")
print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"R-squared: {r2}")
# Step 9: Plot Actual vs Predicted Values
plt.figure(figsize=(8, 6))
plt.scatter(y, predictions, color='blue')
plt.plot([y.min(), y.max()], [y.min(), y.max()], color='red', linestyle='--') # Diagonal line (perfect prediction)
plt.title('Actual vs Predicted Values')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.show()
# Step 10: Plot Histogram of Residuals
plt.figure(figsize=(8, 6))
sns.histplot(results.resid, kde=True, color='blue', bins=10)
plt.title('Histogram of Residuals')
plt.xlabel('Residuals')
plt.ylabel('Frequency')
plt.show()