Belief polarization

world building

What we learn from new information is mediated by the mental models we use to interpret the data.

In “Priors and explanation” we saw how prior beliefs can lead to different explanations of the same observation: two people with different priors might explain the same event in radically different ways.

In that model, the two observers updated their beliefs in accordance with Bayes’ rule. Because the observers had different prior beliefs, they inferred different explanations for the same data. Yet although their explanations differed, both observers updated their beliefs in the same direction, just to different degrees. As more relevant information accumulates, the influence of the priors should diminish. This idea underpins the optimism of rational discourse: shared evidence should lead to shared understanding.

But this doesn’t always seem to happen. Sometimes people will have more polarized beliefs after observing the same information. How is this possible? Is there a “rational” explanation?

This polarization can arise because we do not merely learn patterns of data; we learn how to interpret data. Each new observation can reshape the underlying causal models through which we interpret future information. Humans excel at integrating prior knowledge with observations, but “rational” does not necessitate “correct”.

import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.norm import pdf as normpdf

normpdfjit = jax.jit(normpdf)

PolicePerformance = jnp.linspace(0, 1, 10+1, endpoint=True)

Causal = jnp.linspace(-1, 1, 40+1, endpoint=True)

Arrests = jnp.linspace(0, 1, 10+1, endpoint=True)

@jax.jit
def arrests_pdf(arrests, performance, causal_link):
    """Gaussian likelihood of an arrest rate given performance and causal link.

    The mean arrest rate is 0.5, shifted by ``causal_link`` times the
    deviation of ``performance`` from 0.5. A positive link means better
    performance produces more arrests, and a negative link means fewer.
    The spread is fixed at 0.1.
    """
    centred_performance = performance - 0.5
    mean_arrests = 0.5 + causal_link * centred_performance
    return normpdf(arrests, mean_arrests, 0.1)

@jax.jit
def reported_cl_pdf(reported, real, bias=0.0):
    """Likelihood that a news outlet reports ``reported`` when the true link is ``real``.

    The report is Gaussian with standard deviation 2.0. It is centred on
    ``real`` scaled by ``(1 + bias)``, so the default ``bias=0.0`` gives an
    unbiased outlet.
    """
    biased_mean = real + real * bias
    return normpdf(reported, biased_mean, 2.0)

# Recursive Bayesian viewer model. It gives the viewer's belief about the police
# "causal link", i.e. the slope of arrests on performance (see arrests_pdf),
# after watching nobs + 1 newscasts that each report a causal link of
# reported_cl_observed.
#
# Type parameter:
#   _prior_expectation_cl -- a value from Causal; the returned table is
#     indexed by it. Despite the name, each entry is the viewer's posterior
#     probability that the causal link equals that value.
# Parameters:
#   reported_cl_observed -- the causal link the viewer hears reported each time.
#   nobs -- number of earlier newscasts. nobs == 0 starts from a uniform
#     prior, so the model always conditions on nobs + 1 reports in total.
# Returns:
#   Pr[police.causal_link == _prior_expectation_cl], from the viewer's view.
@memo
def viewerModel[
    _prior_expectation_cl: Causal,
](
    reported_cl_observed, 
    nobs,
):
    viewer: knows(
        _prior_expectation_cl,
    )
    viewer: thinks[
        # Prior over the causal link. It is this model's own posterior after
        # the previous nobs reports (recursion on nobs - 1), or uniform
        # (wpp=1) before any newscast.
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs - 1) 
            if nobs > 0 else 1)),
        # Performance is uniform. Arrests are Gaussian around a mean set by
        # performance and the causal link (see arrests_pdf). Neither is
        # observed in this model.
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        # The outlet knows the true causal link and reports a noisy version
        # of it (reported_cl_pdf: Gaussian, sd 2.0, default bias 0).
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    # Soft evidence: weight each hypothesised report by a Gaussian (sd 0.2)
    # centred on the report the viewer actually saw.
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))

    return viewer[ 
        Pr[
            police.causal_link == _prior_expectation_cl
        ] 
    ]
# Posterior mean causal link for a viewer repeatedly shown a report of
# Causal[4] (about -0.8), after 1 through 10 newscasts.
reported = Causal[4]
for n_prior in range(10):
    posterior = viewerModel(
        reported, n_prior, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    mean_cl = jnp.dot(posterior['_prior_expectation_cl'].values, posterior.values)
    print(f"nobs: {n_prior+1} E = {mean_cl}")
nobs: 1 E = -0.06499629467725754
nobs: 2 E = -0.12871909141540527
nobs: 3 E = -0.1903265118598938
nobs: 4 E = -0.2491239309310913
nobs: 5 E = -0.3045755922794342
nobs: 6 E = -0.35632815957069397
nobs: 7 E = -0.4041990637779236
nobs: 8 E = -0.4481450617313385
nobs: 9 E = -0.48824718594551086
nobs: 10 E = -0.5246806144714355
# Posterior mean causal link for a viewer repeatedly shown a report of
# Causal[-4] (about +0.8), after 1 through 10 newscasts.
reported = Causal[-4]
for n_prior in range(10):
    posterior = viewerModel(
        reported, n_prior, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    mean_cl = jnp.dot(posterior['_prior_expectation_cl'].values, posterior.values)
    print(f"nobs: {n_prior+1} E = {mean_cl}")
nobs: 1 E = 0.06774601340293884
nobs: 2 E = 0.13407555222511292
nobs: 3 E = 0.19804809987545013
nobs: 4 E = 0.2588895261287689
nobs: 5 E = 0.31602537631988525
nobs: 6 E = 0.3690882623195648
nobs: 7 E = 0.41790276765823364
nobs: 8 E = 0.46246659755706787
nobs: 9 E = 0.5029038190841675
nobs: 10 E = 0.5394369959831238
@memo
def performanceInferred[
    _arrests_observed: Arrests, 
](reported_cl_observed, nobs):
    viewer: knows(_arrests_observed)
    viewer: thinks[
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs) 
            if nobs > 0 else 1
        )),
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))
    viewer: observes [police.arrests] is _arrests_observed

    return viewer[ E[police.performance] ]

# Full table of E[police.performance] for every observed arrest level, for a
# viewer hearing reports of Causal[4] (about -0.8), with nobs_ = 4.
reported, nobs_ = Causal[4], 4
_ = performanceInferred(reported, nobs_, print_table=True)
+----------------------------+----------------------+
| _arrests_observed: Arrests | performanceInferred  |
+----------------------------+----------------------+
| 0.0                        | 0.8483713865280151   |
| 0.10000000149011612        | 0.7999858260154724   |
| 0.20000000298023224        | 0.7301038503646851   |
| 0.30000001192092896        | 0.645521342754364    |
| 0.4000000059604645         | 0.5646162629127502   |
| 0.5                        | 0.4999997913837433   |
| 0.6000000238418579         | 0.43538185954093933  |
| 0.699999988079071          | 0.35447952151298523  |
| 0.800000011920929          | 0.2698954939842224   |
| 0.9000000357627869         | 0.20001395046710968  |
| 1.0                        | 0.1516285240650177   |
+----------------------------+----------------------+
# How the police performance inferred from a high arrest rate (Arrests[-2],
# about 0.9) evolves as viewers of two outlets watch more newscasts.
# "Fox" reports Causal[4] (about -0.8) and "MSNBC" reports Causal[-4] (about +0.8).
fox_reported = Causal[4].item()
msnbc_reported = Causal[-4].item()
arrests_observed = Arrests[-2].item()

nobs_list = list(range(10))
# nobs_ counts newscasts before the current one, so nobs_ = n means n + 1
# newscasts viewed. This matches the "nobs_+1" labels used elsewhere.
newscasts_viewed = [n + 1 for n in nobs_list]

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()
for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    # Nearest-neighbour lookup. The coordinate comes from a float32 grid, so
    # exact float matching with .loc is fragile.
    foxviewer_policeperformance__arrests.append(
        xardata_fox.sel(_arrests_observed=arrests_observed, method="nearest").item()
    )

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(
        xardata_msnbc.sel(_arrests_observed=arrests_observed, method="nearest").item()
    )

fig, ax = plt.subplots()
ax.plot(newscasts_viewed, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(newscasts_viewed, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
ax.legend()

# How the police performance inferred from a low arrest rate (Arrests[2],
# about 0.2) evolves as viewers of two outlets watch more newscasts.
# "Fox" reports Causal[4] (about -0.8) and "MSNBC" reports Causal[-4] (about +0.8).
fox_reported = Causal[4].item()
msnbc_reported = Causal[-4].item()
arrests_observed = Arrests[2].item()

nobs_list = list(range(10))
# nobs_ counts newscasts before the current one, so nobs_ = n means n + 1
# newscasts viewed. This matches the "nobs_+1" labels used elsewhere.
newscasts_viewed = [n + 1 for n in nobs_list]

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()
for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    # Nearest-neighbour lookup. The coordinate comes from a float32 grid, so
    # exact float matching with .loc is fragile.
    foxviewer_policeperformance__arrests.append(
        xardata_fox.sel(_arrests_observed=arrests_observed, method="nearest").item()
    )

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(
        xardata_msnbc.sel(_arrests_observed=arrests_observed, method="nearest").item()
    )

fig, ax = plt.subplots()
ax.plot(newscasts_viewed, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(newscasts_viewed, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
ax.legend()


# Inferred performance as a function of observed arrests, for a viewer hearing
# reports of Causal[4] (about -0.8), at several values of nobs_.
fig, ax = plt.subplots()
reported = Causal[4].item()
for nobs_ in (0, 1, 2, 4, 6, 8):
    xrd = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()

# Inferred performance as a function of observed arrests, for a viewer hearing
# reports of Causal[-4] (about +0.8), at several values of nobs_.
fig, ax = plt.subplots()
reported = Causal[-4].item()
for nobs_ in (0, 1, 2, 4, 6, 8):
    xrd = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()

# Inferred performance vs. observed arrests at a fixed nobs_ = 3, sweeping the
# reported causal link over 7 evenly spaced values in [-1, 1].
fig, ax = plt.subplots()
nobs_ = 3
for reported in jnp.linspace(-1, 1, 7, endpoint=True):
    xrd = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()


%reset -f
import sys
import platform
import importlib.metadata

print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.14.2 (main, Dec  5 2025, 21:11:58) [Clang 21.1.4 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64

Packages:
annotated-types==0.7.0
anyio==4.12.0
appnope==0.1.4
argon2-cffi==25.1.0
argon2-cffi-bindings==25.1.0
arrow==1.4.0
astroid==4.0.2
asttokens==3.0.1
async-lru==2.0.5
attrs==25.4.0
babel==2.17.0
beautifulsoup4==4.14.3
bleach==6.3.0
certifi==2025.11.12
cffi==2.0.0
cfgv==3.5.0
charset-normalizer==3.4.4
click==8.3.1
comm==0.2.3
contourpy==1.3.3
cycler==0.12.1
debugpy==1.8.19
decorator==5.2.1
defusedxml==0.7.1
dill==0.4.0
distlib==0.4.0
executing==2.2.1
fastjsonschema==2.21.2
filelock==3.20.1
fonttools==4.61.1
fqdn==1.5.1
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
identify==2.6.15
idna==3.11
importlib_metadata==8.7.0
ipykernel==7.1.0
ipython==9.8.0
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.8
isoduration==20.11.0
isort==7.0.0
jax==0.8.1
jaxlib==0.8.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.5.3
json5==0.12.1
jsonpointer==3.0.0
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.3.0
jupyter_client==8.7.0
jupyter_core==5.9.1
jupyter_server==2.17.0
jupyter_server_terminals==0.5.3
jupyterlab==4.5.1
jupyterlab_pygments==0.3.0
jupyterlab_server==2.28.0
jupyterlab_widgets==3.0.16
kiwisolver==1.4.9
lark==1.3.1
MarkupSafe==3.0.3
matplotlib==3.10.8
matplotlib-inline==0.2.1
mccabe==0.7.0
memo-lang==1.2.7
mistune==3.1.4
ml_dtypes==0.5.4
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.6.1
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.3.5
numpy-typing-compat==20251206.2.3
opt_einsum==3.4.0
optype==0.15.0
packaging==25.0
pandas==2.3.3
pandas-stubs==2.3.3.251201
pandocfilters==1.5.1
parso==0.8.5
pexpect==4.9.0
pillow==12.0.0
platformdirs==4.5.1
plotly==5.24.1
pre_commit==4.5.1
prometheus_client==0.23.1
prompt_toolkit==3.0.52
psutil==7.1.3
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.23
pydantic==2.12.5
pydantic_core==2.41.5
Pygments==2.19.2
pygraphviz==1.14
pylint==4.0.4
pyparsing==3.2.5
python-dateutil==2.9.0.post0
python-dotenv==1.2.1
python-json-logger==4.0.0
pytz==2025.2
PyYAML==6.0.3
pyzmq==27.1.0
referencing==0.37.0
requests==2.32.5
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rfc3987-syntax==1.1.0
rpds-py==0.30.0
ruff==0.14.9
scikit-learn==1.8.0
scipy==1.16.3
scipy-stubs==1.16.3.3
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==80.9.0
six==1.17.0
soupsieve==2.8
SQLAlchemy==2.0.45
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.1.2
terminado==0.18.1
threadpoolctl==3.6.0
tinycss2==1.4.0
toml==0.10.2
tomlkit==0.13.3
tornado==6.5.4
tqdm==4.67.1
traitlets==5.14.3
types-pytz==2025.2.0.20251108
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.3
uri-template==1.3.0
urllib3==2.6.2
virtualenv==20.35.4
wcwidth==0.2.14
webcolors==25.10.0
webencodings==0.5.1
websocket-client==1.9.0
widgetsnbextension==4.0.15
xarray==2025.12.0
zipp==3.23.0