MMM
- class pymc_marketing.mmm.multidimensional.MMM(*, date_column=FieldInfo(annotation=NoneType, required=True, description='Column name of the date variable.'), channel_columns=FieldInfo(annotation=NoneType, required=True, description='Column names of the media channel variables.', metadata=[MinLen(min_length=1)]), target_column=FieldInfo(annotation=NoneType, required=False, default='y', description='The name of the target column.'), adstock=FieldInfo(annotation=NoneType, required=True, description='Type of adstock transformation to apply.'), saturation=FieldInfo(annotation=NoneType, required=True, description='The saturation transformation to apply to the channel data.'), time_varying_intercept=False, time_varying_media=False, dims=FieldInfo(annotation=NoneType, required=False, default=None, description='Additional dimensions for the model.'), scaling=FieldInfo(annotation=NoneType, required=False, default=None, description='Scaling configuration for the model.'), model_config=FieldInfo(annotation=NoneType, required=False, default=None, description='Configuration settings for the model.'), sampler_config=FieldInfo(annotation=NoneType, required=False, default=None, description='Configuration settings for the sampler.'), control_columns=None, yearly_seasonality=None, adstock_first=True, dag=FieldInfo(annotation=NoneType, required=False, default=None, description='Optional DAG provided as a string Dot format for causal identification.'), treatment_nodes=FieldInfo(annotation=NoneType, required=False, default=None, description='Column names of the variables of interest to identify causal effects on outcome.'), outcome_node=FieldInfo(annotation=NoneType, required=False, default=None, description='Name of the outcome variable.'), cost_per_unit=FieldInfo(annotation=NoneType, required=False, default=None, description='Cost per unit conversion factors for non-spend channels. 
Wide-format DataFrame where rows are (date, *custom_dims) combinations and columns are channel names containing cost values. Not all model channels need to appear; missing channels default to 1.0 (already in spend units).'))
Marketing Mix Model class for estimating the impact of marketing channels on a target variable.
Given a target variable \(y_{t}\) (e.g. sales or conversions), media variables \(x_{m, t}\) (e.g. impressions, clicks, or costs), and a set of control covariates \(z_{c, t}\) (e.g. holidays, pricing), we consider a Bayesian linear model of the form:
\[y_{t} = \alpha + \sum_{m=1}^{M}\beta_{m}\,f_{m}\!\bigl( \{x_{m,s}\}_{s \leq t}\bigr) + \sum_{c=1}^{C}\gamma_{c}\, z_{c, t} + \varepsilon_{t},\]
where \(\alpha\) is the intercept, \(\beta_{m}\) and \(\gamma_{c}\) are the channel and control coefficients, \(f_{m}\) is a media transformation function that maps the history of channel \(m\) up to time \(t\) to a scalar contribution, capturing adstock (carry-over) and saturation effects, and \(\varepsilon_{t} \sim \mathcal{N}(0, \sigma^{2})\) is the noise term.
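As a rough illustration of the equation above, the following pure-Python sketch evaluates the mean of \(y_{t}\) for a single channel and a single control. It assumes an unnormalized geometric adstock followed by a logistic saturation as the form of \(f_{m}\), with adstock applied first (the documented default, adstock_first=True). The function names and all parameter values are hypothetical and are not the library's implementation.

```python
import math


def geometric_adstock(x, alpha, l_max):
    """Carry-over: decayed sum of the current value and up to l_max - 1 past values."""
    out = []
    for t in range(len(x)):
        total = 0.0
        for lag in range(l_max):
            if t - lag >= 0:
                total += (alpha ** lag) * x[t - lag]
        out.append(total)
    return out


def logistic_saturation(x, lam):
    """Diminishing returns: maps [0, inf) into [0, 1)."""
    return [(1 - math.exp(-lam * v)) / (1 + math.exp(-lam * v)) for v in x]


def expected_target(x, z, intercept, beta, gamma, alpha, lam, l_max=4):
    """Mean of y_t for one channel x and one control z (noise term omitted)."""
    media = logistic_saturation(geometric_adstock(x, alpha, l_max), lam)
    return [intercept + beta * m + gamma * c for m, c in zip(media, z)]


# Hypothetical inputs: one spend series and one holiday indicator.
spend = [1.0, 0.0, 0.0, 2.0]
holiday = [0.0, 1.0, 0.0, 0.0]
mu = expected_target(
    spend, holiday, intercept=0.5, beta=2.0, gamma=0.3, alpha=0.5, lam=1.0
)
```

In the actual model these coefficients are random variables with priors rather than fixed numbers; the sketch only shows how the terms of the equation combine.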
The model supports \(K \geq 0\) additional panel dimensions (e.g. geography, brand) specified via the dims parameter. When \(K > 0\), every variable (the target, media inputs, controls) and all parameters (\(\alpha\), \(\beta_{m}\), \(\gamma_{c}\), \(\sigma\), and the parameters of \(f_{m}\)) are implicitly indexed over the Cartesian product of those dimensions. For example, with dims=("geo",) each parameter is geo-specific (\(y_{t,g}\), \(\alpha_{g}\), \(\beta_{m,g}\), etc.), but the parameters share hierarchical priors so that information is partially pooled across geographies. When dims=("geo", "brand"), every quantity is indexed by \((t, g, b)\). The equation above is written for a single slice of these dimensions; the full model is their product over all dimension combinations.
Attributes:
- date_column : str
  The name of the column representing the date in the dataset.
- channel_columns : list[str]
  A list of column names representing the marketing channels.
- target_column : str, optional
  The name of the column representing the target variable in the dataset. Defaults to "y".
- adstock : AdstockTransformation
  The adstock transformation to apply to the channel data.
- saturation : SaturationTransformation
  The saturation transformation to apply to the channel data.
- time_varying_intercept : bool or HSGPBase
  Whether to use a time-varying intercept in the model, or an HSGPBase instance specifying dims and priors.
- time_varying_media : bool or HSGPBase
  Whether to use time-varying effects for media channels, or an HSGPBase instance specifying dims and priors.
- dims : tuple[str, ...] or None
  Additional panel dimensions for the model (e.g. ("geo",)). One categorical column per dimension must be present in the dataset. Data must be rectangular across these dimensions (i.e. the same dates for every combination).
- scaling : Scaling or dict or None
  Scaling methods for the target variable and the marketing channels. Defaults to max-absolute scaling for both.
- model_config : dict or None
  Configuration settings for the model priors and likelihood.
- sampler_config : dict or None
  Configuration settings for the sampler.
- control_columns : list[str] or None
  Column names of control covariates to include in the model.
- yearly_seasonality : int or None
  Number of Fourier modes for yearly seasonality.
- adstock_first : bool
  Whether to apply adstock before saturation (default True).
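The dims entry requires the data to be rectangular, i.e. the same dates for every combination of dimension values. A minimal way to check this before building the model could look like the sketch below. It assumes the data is available as a list of dictionaries and reads "rectangular" as every date being present for every combination of observed dimension levels; the helper name missing_panel_rows is hypothetical and not part of the library.

```python
from itertools import product


def missing_panel_rows(rows, date_key="date", dims=("geo",)):
    """Return the (date, *dims) combinations that are absent from rows."""
    dates = sorted({r[date_key] for r in rows})
    levels = [sorted({r[d] for r in rows}) for d in dims]
    present = {(r[date_key], *(r[d] for d in dims)) for r in rows}
    expected = set(product(dates, *levels))
    return sorted(expected - present)


# Hypothetical panel with one missing observation for geo "B".
rows = [
    {"date": "2024-01-01", "geo": "A"},
    {"date": "2024-01-01", "geo": "B"},
    {"date": "2024-01-08", "geo": "A"},
]
gaps = missing_panel_rows(rows)
```

An empty result means the panel is rectangular under this reading; otherwise the returned tuples identify the rows to fill or drop.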
Notes
Before fitting, the target variable and media channels are scaled (by default using max-absolute scaling). Control variables are not scaled automatically — apply your own preprocessing if needed.
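To make the default scaling concrete, here is a small sketch of max-absolute scaling on a single series. It assumes the scale is simply the largest absolute value in the series; the function name max_abs_scale and the sample values are hypothetical, and the library's own scaler may differ in details such as per-dimension handling.

```python
def max_abs_scale(values):
    """Divide by the largest absolute value; return scaled values and the scale."""
    scale = max(abs(v) for v in values)
    if scale == 0:
        scale = 1.0  # avoid division by zero for an all-zero series
    return [v / scale for v in values], scale


# Hypothetical target series.
sales = [120.0, 80.0, 200.0]
scaled, scale = max_abs_scale(sales)

# Multiplying by the stored scale maps values back to the original units.
restored = [v * scale for v in scaled]
```

Because control variables are not scaled automatically, a similar transformation would have to be applied to them by hand if their magnitudes differ widely.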
Yearly seasonality can be added as Fourier modes via the yearly_seasonality parameter.
The model can be calibrated with:
- Custom priors for any parameter via model_config.
- Lift-test measurements added through add_lift_test_measurements().
For details on a vanilla implementation in PyMC see [2].
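The Fourier modes used for yearly seasonality, mentioned in the notes above, can be pictured as pairs of sine and cosine features. The sketch below assumes a period of 365.25 days and an interleaved sin/cos column order; both, along with the function name, are illustrative assumptions rather than the library's exact layout.

```python
import math


def yearly_fourier_features(day_of_year, n_modes, period=365.25):
    """Return [sin_1, cos_1, ..., sin_n, cos_n] for one day of the year."""
    features = []
    for k in range(1, n_modes + 1):
        angle = 2 * math.pi * k * day_of_year / period
        features.extend([math.sin(angle), math.cos(angle)])
    return features


# Two modes give four features per date.
row = yearly_fourier_features(day_of_year=0, n_modes=2)
```

Each additional mode adds a higher-frequency pair, so a larger yearly_seasonality value lets the seasonal component follow sharper within-year patterns at the cost of more parameters.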
References
[1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017).
Methods
MMM.__init__(*[, date_column, ...]): Define the constructor method.
MMM.add_cost_per_target_calibration(data, ...): Calibrate cost-per-target using constraints via pm.Potential.
MMM.add_events(df_events, prefix, effect): Add event effects to the model.
MMM.add_lift_test_measurements(df_lift_test): Add lift tests to the model.
Add a pm.Deterministic variable to the model that multiplies by the scaler.
MMM.approximate_fit(X[, y, progressbar, ...]): Fit a model using Variational Inference and return InferenceData.
MMM.attrs_to_init_kwargs(attrs): Convert the idata attributes to the model initialization kwargs.
MMM.build_from_idata(idata): Rebuild the model from an InferenceData object.
MMM.build_model(X, y, **kwargs): Build a probabilistic model using PyMC for marketing mix modeling.
MMM.create_fit_data(X, y): Create a fit dataset aligned on date and present dimensions.
Return the idata attributes for the model.
MMM.fit(X[, y, progressbar, random_seed]): Fit the model and inject cost_per_unit metadata if provided.
MMM.forward_pass(x, dims): Transform channel input into target contributions of each channel.
Return the saved scaling factors as xarray DataArrays.
MMM.graphviz(**kwargs): Get the graphviz representation of the model.
MMM.idata_to_init_kwargs(idata): Create the model configuration and sampler configuration from the InferenceData to keyword arguments.
MMM.load(fname[, check]): Create a ModelBuilder instance from a file.
MMM.load_from_idata(idata[, check]): Create a ModelBuilder instance from an InferenceData object.
Post-sample model transformation in order to store the HSGP state from fit.
MMM.predict([X, extend_idata]): Use a model to predict on unseen data and return point prediction of all the samples.
MMM.predict_posterior([X, extend_idata, ...]): Generate posterior predictive samples on unseen data.
MMM.predict_proba([X, extend_idata, combined]): Alias for predict_posterior, for consistency with scikit-learn probabilistic estimators.
MMM.sample_adstock_curve([amount, ...]): Sample adstock curves from posterior parameters.
MMM.sample_posterior_predictive([X, ...]): Sample from the model's posterior predictive distribution.
MMM.sample_prior_predictive([X, y, samples, ...]): Sample from the model's prior predictive distribution.
MMM.sample_saturation_curve([max_value, ...]): Sample saturation curves from posterior parameters.
MMM.save(fname, **kwargs): Save the model's inference data to a file.
MMM.set_cost_per_unit(cost_per_unit): Set or update cost_per_unit metadata for the fitted model.
MMM.set_idata_attrs([idata]): Set attributes on an InferenceData object.
MMM.table(**model_table_kwargs): Get the summary table of the model.
Attributes
data: Get data wrapper for InferenceData access and manipulation.
default_model_config: Define the default model configuration.
default_sampler_config: Default sampler configuration.
fit_result: Get the posterior fit_result.
id: Generate a unique hash value for the model.
incrementality: Access incrementality and counterfactual analysis functionality.
output_var
plot: Use the MMMPlotSuite to plot the results.
plot_interactive: Access interactive Plotly plotting functionality.
posterior
posterior_predictive
predictions
prior
prior_predictive
sensitivity: Access sensitivity analysis functionality.
summary: Access summary DataFrame generation functionality.
version
idata
sampler_config
model_config