

---

# Recursive Semantic Refinement Network (RSR-Net)

## Project Goal
The **Recursive Semantic Refinement Network (RSR-Net)** is an experimental approach to abstractive summarization. Instead of generating text word by word, RSR-Net iteratively refines a fixed-size semantic embedding of the summary. The aim is a final state that lies close to the embedding of the ground-truth summary. The design borrows from Deep Equilibrium Models (DEQ), which seek a fixed point of a repeatedly applied layer, and from Recurrent Neural Networks (RNNs), which reuse the same weights at every step. The output is an embedding; this notebook does not decode it back into text.

The project uses the standard **CNN/DailyMail** dataset (raw `.story` files) and **BART-base encoder embeddings** for all inputs and targets.

---

## 💡 Core Methodology: Iterative Embedding Correction

RSR-Net treats the summarization task as a **regression problem in a high-dimensional semantic space**.

1.  **Input Preparation:** The source article ($\mathbf{x}$) and the target summary ($\mathbf{y}_{\text{true}}$) are converted into fixed-size **BART-base encoder embeddings** (768 dimensions). Each embedding is obtained by mean-pooling the encoder's final hidden states.
2.  **Recursion:** The core model repeatedly takes the document context ($\mathbf{x}$) as input. At each step it updates both the current summary state ($\mathbf{y}$) and an internal latent state ($\mathbf{z}$).
3.  **Refinement:** After multiple recursive steps, the final state ($\mathbf{y}_{\text{hat}}$) is the network's prediction of the target summary embedding.
4.  **Loss:** Training minimizes the **mean squared error** between the predicted summary embedding ($\mathbf{y}_{\text{hat}}$) and the ground-truth summary embedding ($\mathbf{y}_{\text{true}}$, `y_true_emb` in the code).

---

## 🧠 Model Architecture: `RecursionModel`

At its core, `RecursionModel` is a small feed-forward network: two linear layers, each followed by batch normalization and ReLU. It is applied repeatedly with the same weights:

| State | Role | Dimension |
| :--- | :--- | :--- |
| **Input ($\mathbf{x}$)** | **Document context** (fixed for all steps) | 768 (BART `d_model`) |
| **State ($\mathbf{y}$)** | **Current summary embedding** (refined output) | 768 (BART `d_model`) |
| **Latent ($\mathbf{z}$)** | **Internal memory** (carried between steps) | 64 (`latent_dim`) |
| **Combined input** | `torch.cat([x, y, z], dim=-1)` | 768 + 768 + 64 = 1600 |
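
As a quick check of the table, the sketch below concatenates dummy tensors of these widths and counts the parameters of a first linear layer. The batch size of 8 and `hidden_dim=128` come from the notebook's later cells. The random `x` only stands in for a real BART embedding. The resulting 1,600-wide input matches `fc1` in the printed model.

```python
import torch
import torch.nn as nn

# Widths from the table; batch size 8 and hidden_dim=128 mirror later notebook cells
batch_size, d_model, latent_dim, hidden_dim = 8, 768, 64, 128

x = torch.randn(batch_size, d_model)      # stand-in document embedding (fixed across steps)
y = torch.zeros(batch_size, d_model)      # summary state, initialised to zeros
z = torch.zeros(batch_size, latent_dim)   # latent memory, initialised to zeros

combined = torch.cat([x, y, z], dim=-1)   # concatenate along the feature axis
fc1 = nn.Linear(combined.size(-1), hidden_dim)

print(combined.shape)                            # torch.Size([8, 1600])
print(sum(p.numel() for p in fc1.parameters()))  # 1600*128 weights + 128 biases = 204928
```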

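The network returns three tensors:

* a refined summary state ($\mathbf{y}_{\text{out}}$, 768-d);
* an auxiliary output ($\mathbf{y}_{\text{aux}}$, 768-d);
* a new latent state ($\mathbf{z}_{\text{new}}$, 64-d, passed through ReLU).

In the current code, $\mathbf{y}_{\text{aux}}$ is computed but not used by the recursion or the loss.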

## 🔄 Recursive Training Mechanism

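Training combines two nested loops (`latent_recursion` inside `deep_recursion`) with an outer **deep supervision** loop in the training function. That outer loop takes several optimizer steps on the same batch, each continuing from the detached state left by the previous step. This arrangement is meant to stabilize training, but it does not by itself guarantee convergence.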

### 1. `latent_recursion(x, y, z, net, n=4)`

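This is the **inner loop**: it applies the core network $n$ times, discarding the auxiliary output at each step. The intent is to move the states towards a stable point for the current input, although the code does not check for convergence: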
$$\mathbf{y}_{t+1}, \mathbf{z}_{t+1} = \text{net}(\mathbf{x}, \mathbf{y}_t, \mathbf{z}_t)$$
The output $\mathbf{y}$ and $\mathbf{z}$ from the final step are then passed to the outer loop.
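
In DEQ terms, the inner loop is heading for a fixed point $\mathbf{y}^\ast = f(\mathbf{y}^\ast)$. However, `latent_recursion` simply runs a fixed $n$ steps and never measures how close it gets. A cheap diagnostic is the per-step residual $\lVert f(\mathbf{y}_t) - \mathbf{y}_t \rVert$. The sketch below logs it for a hypothetical `toy_step`, an affine map scaled to be a contraction, so the residual provably at least halves every step. `RecursionModel` has no such guarantee. With the real network, the same quantity would be computed on `y` (and `z`) between calls.

```python
import torch

torch.manual_seed(0)
d = 16

# Hypothetical stand-in for `net`: an affine map whose matrix is rescaled to
# spectral norm 0.5, so it is a contraction with a unique fixed point.
A = torch.randn(d, d)
A = 0.5 * A / torch.linalg.matrix_norm(A, ord=2)
b = torch.randn(d)

def toy_step(y):
    return y @ A.T + b

def inner_loop_with_residuals(y, n=4):
    """Apply the step n times, recording the mean ||f(y) - y|| before each update."""
    residuals = []
    for _ in range(n):
        y_next = toy_step(y)
        residuals.append(torch.linalg.vector_norm(y_next - y, dim=-1).mean().item())
        y = y_next
    return y, residuals

y0 = torch.zeros(4, d)                      # batch of 4 states, like y initialised to zeros
y_n, residuals = inner_loop_with_residuals(y0, n=6)
print([round(r, 4) for r in residuals])     # each entry is at most half the previous one
```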

### 2. `deep_recursion(x, y, z, net, n=4, T=3)`

This is the **outer loop**. It runs `latent_recursion` $T$ times but records gradients only for the last run, which keeps the backward pass as short as a single inner loop:
* **Preconditioning ($T-1$ runs):** `y` and `z` are updated with `latent_recursion` inside `with torch.no_grad()`, so no computation graph is built. The aim is to give the final run a better starting state.
* **Final run (1 run):** `latent_recursion` runs once more **with gradients enabled**, producing the predicted embedding $\mathbf{y}_{\text{hat}}$. A confidence score $\mathbf{q}_{\text{hat}}$ is then computed as the sigmoid of the largest component of $\mathbf{y}_{\text{hat}}$; there is no separate confidence head.
* **State update:** The resulting `y` and `z` are **detached** before being returned to the training loop. Each deep supervision step therefore backpropagates only through its own segment of the refinement process.
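
The sketch below checks the gradient bookkeeping of `deep_recursion` with a hypothetical `TinyStep` module. The module counts how often it is called and how many of those calls happen with gradients enabled. `latent_recursion` and `deep_recursion` follow the notebook's control flow with the logging removed. With $n=4$ and $T=3$ the network runs 12 times, but only the last 4 calls are recorded for backpropagation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyStep(nn.Module):
    """Hypothetical stand-in for RecursionModel that also counts its calls."""
    def __init__(self, d=8, latent=4):
        super().__init__()
        self.d = d
        self.fc = nn.Linear(2 * d + latent, d + latent)
        self.calls_total = 0
        self.calls_with_grad = 0

    def forward(self, x, y, z):
        self.calls_total += 1
        if torch.is_grad_enabled():
            self.calls_with_grad += 1
        h = torch.tanh(self.fc(torch.cat([x, y, z], dim=-1)))
        return h[:, :self.d], None, h[:, self.d:]   # (y_out, y_aux, z_new)

def latent_recursion(x, y, z, net, n=4):
    for _ in range(n):
        y, _, z = net(x, y, z)
    return y, z

def deep_recursion(x, y, z, net, n=4, T=3):
    # Same control flow as the notebook's version, without the logging
    for _ in range(T - 1):
        with torch.no_grad():
            y, z = latent_recursion(x, y, z, net, n)
    y, z = latent_recursion(x, y, z, net, n)
    y_hat = y
    q_hat = torch.sigmoid(y_hat.max(dim=-1).values)
    return (y.detach(), z.detach()), y_hat, q_hat

net = TinyStep()
x, y0, z0 = torch.randn(2, 8), torch.zeros(2, 8), torch.zeros(2, 4)
(y_next, z_next), y_hat, q_hat = deep_recursion(x, y0, z0, net)

print(net.calls_total, net.calls_with_grad)       # 12 4
print(y_hat.requires_grad, y_next.requires_grad)  # True False

y_hat.pow(2).mean().backward()          # backward pass covers only the last n=4 calls
print(net.fc.weight.grad is not None)   # True
```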

### Loss Function

The combined loss function for semantic refinement is:
$$
\text{Loss} = \text{MSE}(\mathbf{y}_{\text{hat}}, \mathbf{y}_{\text{true}}) + 0.1 \times \text{BCE}(\mathbf{q}_{\text{hat}}, \mathbf{1})
$$

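* **MSE loss:** the mean squared error between the predicted summary embedding ($\mathbf{y}_{\text{hat}}$) and the target summary embedding ($\mathbf{y}_{\text{true}}$, `y_true_emb` in the code), averaged over the batch and all 768 dimensions.
* **Auxiliary loss:** a **binary cross-entropy (BCE)** term that pushes $\mathbf{q}_{\text{hat}}$ towards 1. Note that $\mathbf{q}_{\text{hat}} = \sigma\big(\max_j (\mathbf{y}_{\text{hat}})_j\big)$ is computed from the prediction itself. This term therefore rewards a larger maximum embedding component rather than measuring agreement with the target. It acts as a regularizer rather than a calibrated confidence.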
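
Since $\mathbf{q}_{\text{hat}}$ has no dedicated head, the auxiliary term can be inspected directly. The sketch below uses random stand-in tensors (not real embeddings) to check two properties of the loss as written:

* With a target of 1, the BCE term equals $-\log \mathbf{q}_{\text{hat}}$ averaged over the batch.
* Its gradient reaches only the largest component of each predicted embedding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y_hat = torch.randn(4, 768, requires_grad=True)   # stand-in prediction
y_true_emb = torch.randn(4, 768)                  # stand-in target

q_hat = torch.sigmoid(y_hat.max(dim=-1).values)   # same definition as deep_recursion
mse_loss = F.mse_loss(y_hat, y_true_emb)
aux_loss = F.binary_cross_entropy(q_hat, torch.ones_like(q_hat))
loss = mse_loss + 0.1 * aux_loss                  # the combined objective from the formula

# With a target of 1, BCE reduces to -log(q_hat), averaged over the batch
print(torch.allclose(aux_loss, -torch.log(q_hat).mean()))  # True

# The auxiliary term sends gradient only to the largest component of each row
aux_loss.backward()
print((y_hat.grad != 0).sum(dim=-1).tolist())  # [1, 1, 1, 1]
```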

---




In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RecursionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=64, latent_dim=32, num_classes=3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Core network
        self.fc1 = nn.Linear(input_dim + num_classes + latent_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc_latent = nn.Linear(hidden_dim, latent_dim)
        self.fc_out = nn.Linear(hidden_dim, num_classes)
        self.fc_aux = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, y=None, z=None):
        batch_size = x.size(0)
        device = x.device
        if y is None:
            y = torch.zeros(batch_size, self.num_classes, device=device)
        if z is None:
            z = torch.zeros(batch_size, self.latent_dim, device=device)

        concat_input = torch.cat([x, y, z], dim=-1)
        h = F.relu(self.bn1(self.fc1(concat_input)))
        h = F.relu(self.bn2(self.fc2(h)))

        z_new = F.relu(self.fc_latent(h))
        y_out = self.fc_out(h)
        y_aux = self.fc_aux(h)
        return y_out, y_aux, z_new

def latent_recursion(x, y, z, net, n=4):
    for _ in range(n):
        y, _, z = net(x, y, z)
    return y, z

def deep_recursion(x, y, z, net, n=4, T=3, verbose=True):
    """Run latent_recursion T times; only the last run records gradients.

    Returns ((y, z) detached for the next deep supervision step, y_hat, q_hat).
    """
    # Preconditioning runs (no gradient)
    for t in range(T-1):
        with torch.no_grad():
            y, z = latent_recursion(x, y, z, net, n)
            if verbose:  # log the mean of the intermediate states
                print(f"Preconditioning step {t+1}/{T-1}:")
                print("  Current y (detached):", y.mean().item())
                print("  Current z (detached):", z.mean().item())

    # Final run with gradient (or without, when called under torch.no_grad() for inference)
    y, z = latent_recursion(x, y, z, net, n)
    y_hat = y
    # "Confidence": sigmoid of the largest component of y_hat (not a separate head)
    q_hat = torch.sigmoid(y_hat.max(dim=-1).values)
    if verbose:
        print(f"Final step {T}/{T}:")
        print("  Final y_hat:", y_hat.mean().item())
        print("  Final q_hat:", q_hat.mean().item())
    return (y.detach(), z.detach()), y_hat, q_hat

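# Example training loop with deep supervision for the original classification setup
# (y_true holds class indices and the loss is cross-entropy); the embedding version used below is train_refinement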
def train_model(train_dataloader, net, optimizer, device, N_supervision=3):
    net.train()
    for x_input, y_true in train_dataloader:
        x_input, y_true = x_input.to(device), y_true.to(device)
        batch_size = x_input.size(0)

        y = torch.zeros(batch_size, net.num_classes, device=device)
        z = torch.zeros(batch_size, net.latent_dim, device=device)

        for _ in range(N_supervision):
            x_emb = x_input  # optional embedding here
            (y_detach, z_detach), y_hat, q_hat = deep_recursion(x_emb, y, z, net)

            # Losses
            ce_loss = F.cross_entropy(y_hat, y_true)
            aux_loss = F.binary_cross_entropy(q_hat, torch.ones_like(q_hat))
            loss = ce_loss + 0.1 * aux_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()

            # Update y and z for next deep supervision step
            y = y_detach
            z = z_detach

            # Optional dynamic early stopping
            if q_hat.mean() > 0.95:
                break

In [32]:
# In Colab terminal / code cell
!wget -c https://huggingface.co/datasets/ccdv/cnn_dailymail/resolve/main/cnn_stories.tgz
!wget -c https://huggingface.co/datasets/ccdv/cnn_dailymail/resolve/main/dailymail_stories.tgz

# Extract (without -v, so tar does not print every extracted file name)
!mkdir -p ./cnn_dailymail
!tar -xzf cnn_stories.tgz -C ./cnn_dailymail
!tar -xzf dailymail_stories.tgz -C ./cnn_dailymail


In [33]:
import torch
import torch.nn.functional as F
from transformers import BartTokenizer, BartModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the frozen BART-base tokenizer and model used to embed articles and summaries.
# The dataset itself is read from the extracted .story files in a later cell.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
bart_model = BartModel.from_pretrained('facebook/bart-base').to(device).eval()

# Embedding function
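def get_embeddings(texts, tokenizer, model, device, max_length=512):
    """Get mean-pooled BART encoder embeddings for a batch of texts.

    Padding positions are excluded from the mean via the attention mask, so a
    text's embedding does not depend on how much padding its batch needs.
    """
    if len(texts) == 0:
        return torch.zeros(0, model.config.d_model, device=device)

    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    ).to(device)
    with torch.no_grad():
        outputs = model.encoder(**inputs)
        hidden = outputs.last_hidden_state                         # (batch, seq, d_model)
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # masked mean
    return embeddings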

# Train the recursive model on one batch (a dict of lists produced by custom_collate_fn)
def train_refinement(train_data_batch, net, optimizer, device, N_supervision=3):
    """
    Train the recursive model on one batch using embedding-based supervision.
    train_data_batch is expected to be a dictionary like {'article': [...], 'highlights': [...]}.
    Returns the loss of the last deep supervision step.
    """
    if len(train_data_batch) == 0 or len(train_data_batch['article']) == 0:
        return 0.0  # skip empty batch

    articles = train_data_batch['article']
    references = train_data_batch['highlights']

    # Frozen BART embeddings for inputs and targets
    x_input = get_embeddings(articles, tokenizer, bart_model, device)       # document embeddings
    y_true_emb = get_embeddings(references, tokenizer, bart_model, device)  # target summary embeddings

    batch_size = x_input.size(0)

    # fc_out and the y slot of fc1's input must both match the embedding size.
    # Replacing fc_out here would not fix fc1 and would leave the new weights
    # outside the optimizer, so a mismatch is reported instead.
    embedding_dim = y_true_emb.size(-1)
    if net.fc_out.out_features != embedding_dim:
        raise ValueError(
            f"net.fc_out has {net.fc_out.out_features} outputs but embeddings have "
            f"{embedding_dim} dimensions; build RecursionModel with num_classes={embedding_dim}"
        )

    # Initial summary state y and latent state z
    y = torch.zeros(batch_size, embedding_dim, device=device)
    z = torch.zeros(batch_size, net.latent_dim, device=device)

    loss = torch.tensor(0.0)
    for _ in range(N_supervision):
        # One deep recursion: T-1 no-grad runs, then one run with gradients
        (y_detach, z_detach), y_hat, q_hat = deep_recursion(x_input, y, z, net)

        # Loss: MSE for embeddings + auxiliary confidence loss
        mse_loss = F.mse_loss(y_hat, y_true_emb)
        aux_loss = F.binary_cross_entropy(q_hat, torch.ones_like(q_hat))
        loss = mse_loss + 0.1 * aux_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        # Continue the next deep supervision step from the detached states
        y = y_detach
        z = z_detach

        # Early stopping if the (max-component) confidence is high
        if q_hat.mean() > 0.95:
            break

    return loss.item()
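
Taking a plain `mean(dim=1)` over `last_hidden_state` averages every position, including padding. With `padding=True`, a text's pooled embedding can therefore change with the length of the longest text in its batch. The toy example below uses made-up hidden states, with zeros at the padded positions (real encoder outputs would not be zero there). It compares a plain mean with a mean weighted by `attention_mask`, which excludes padded positions.

```python
import torch

# Made-up encoder output: batch of 2, sequence length 4, hidden size 3
hidden = torch.tensor([
    [[1., 1., 1.], [3., 3., 3.], [0., 0., 0.], [0., 0., 0.]],   # 2 real tokens + 2 pads
    [[2., 2., 2.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.]],   # 4 real tokens
])
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])

plain_mean = hidden.mean(dim=1)                       # counts padding positions too

mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq, 1)
masked_mean = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

print(plain_mean[0].tolist())   # [1.0, 1.0, 1.0]  (diluted by the two pads)
print(masked_mean[0].tolist())  # [2.0, 2.0, 2.0]  (mean of the real tokens only)
```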

In [34]:
input_dim = bart_model.config.d_model
hidden_dim = 128
latent_dim = 64
num_classes = input_dim  # size of y and of fc_out; must equal the embedding size (768 for bart-base)

net = RecursionModel(input_dim, hidden_dim, latent_dim, num_classes).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


In [35]:
print(net)

RecursionModel(
  (fc1): Linear(in_features=1600, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_latent): Linear(in_features=128, out_features=64, bias=True)
  (fc_out): Linear(in_features=128, out_features=768, bias=True)
  (fc_aux): Linear(in_features=128, out_features=768, bias=True)
)


In [36]:

# Define a custom collate function to handle the dictionary output of the dataset
def custom_collate_fn(batch):
    """
    Collate function for the CNNDailyMailDataset that handles batches of dictionaries.
    """
    # Assuming each item in the batch is a dictionary with the same keys,
    # e.g., {'article': '...', 'highlights': '...'}
    # We want to return a dictionary where each key maps to a list of values from the batch
    collated_batch = {}
    for key in batch[0].keys():
        collated_batch[key] = [item[key] for item in batch]
    return collated_batch

print("Custom collate_fn defined.")

Custom collate_fn defined.
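
For batches of dictionaries with string values, PyTorch's `default_collate` already returns a dictionary of lists. The custom function above therefore behaves like the default here. The check below illustrates this on two hand-written samples. Keeping the explicit function is still reasonable if you want the batching behaviour spelled out.

```python
from torch.utils.data import default_collate

def custom_collate_fn(batch):
    """Same logic as the notebook's collate function: a dict of lists."""
    return {key: [item[key] for item in batch] for key in batch[0]}

batch = [
    {"article": "first article", "highlights": "first summary"},
    {"article": "second article", "highlights": "second summary"},
]

print(default_collate(batch))
# {'article': ['first article', 'second article'], 'highlights': ['first summary', 'second summary']}
print(default_collate(batch) == custom_collate_fn(batch))  # True
```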


In [37]:
import os
from torch.utils.data import Dataset, DataLoader

class CNNDailyMailDataset(Dataset):
    """(article, highlights) pairs read from the raw CNN/DailyMail .story files.

    Each .story file contains one article followed by its own reference summary,
    given as lines that follow '@highlight' markers; the cnn/ and dailymail/
    folders are two sources of such pairs. max_samples limits the subset size.
    """
    def __init__(self, folder_path, max_samples=None):
        self.articles = []
        self.summaries = []

        for source in ("cnn", "dailymail"):
            story_dir = os.path.join(folder_path, source, "stories")
            for fname in sorted(os.listdir(story_dir)):
                if max_samples and len(self.articles) >= max_samples:
                    break  # stop reading once the subset is full
                fpath = os.path.join(story_dir, fname)
                if not os.path.isfile(fpath):
                    continue
                with open(fpath, 'r', encoding='utf-8') as f:
                    article, highlights = self._split_story(f.read())
                if article and highlights:  # skip stories missing either part
                    self.articles.append(article)
                    self.summaries.append(highlights)

        assert len(self.articles) == len(self.summaries), "Mismatch in number of articles and summaries"
        print(f"Loaded {len(self.articles)} samples from CNN/DailyMail dataset.")

    @staticmethod
    def _split_story(text):
        """Article = text before the first '@highlight'; summary = the highlights, one per line."""
        parts = text.split("@highlight")
        article = parts[0].strip()
        highlights = "\n".join(p.strip() for p in parts[1:] if p.strip())
        return article, highlights

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        return {
            "article": self.articles[idx],
            "highlights": self.summaries[idx]
        }

# Create dataset and dataloader
dataset = CNNDailyMailDataset("/content/cnn_dailymail", max_samples=64)  # small subset for testing
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn) # Use the custom collate_fn

Loaded 64 samples from CNN/DailyMail dataset.
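
In the raw corpus, every `.story` file holds one article followed by its own reference summary, written as lines that follow `@highlight` markers. The `cnn/` and `dailymail/` folders are two sources of such pairs, not a folder of articles and a folder of summaries. The sketch below splits a hand-written string in that layout. `split_story` is a hypothetical helper name, and joining highlights with newlines is an assumption about the desired summary format.

```python
def split_story(text):
    """Hypothetical helper: split a .story file's text into (article, highlights).

    Everything before the first '@highlight' marker is the article; each
    non-empty chunk after a marker is one highlight, joined with newlines.
    """
    parts = text.split("@highlight")
    article = parts[0].strip()
    highlights = [p.strip() for p in parts[1:] if p.strip()]
    return article, "\n".join(highlights)

# Hand-written example in the .story layout (not a real file from the corpus)
story = """(CNN) -- A short example article.

It has two paragraphs.

@highlight

First highlight

@highlight

Second highlight
"""

article, highlights = split_story(story)
print(repr(article))     # '(CNN) -- A short example article.\n\nIt has two paragraphs.'
print(repr(highlights))  # 'First highlight\nSecond highlight'
```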


In [38]:
# Run a short test training (20 epochs on the 64-sample subset)

if 'net' not in locals():
    print("Model 'net' not found. Please initialize the model.")
elif 'optimizer' not in locals():
    print("Optimizer 'optimizer' not found. Please initialize the optimizer.")
elif 'dataloader' not in locals():
    print("DataLoader 'dataloader' not found. Please load the dataset and create the dataloader.")
else:
    num_epochs = 20 # Set number of epochs for testing
    N_supervision = 3 # Number of deep supervision steps (adjust as needed)

    print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        net.train() # Set model to training mode
        total_loss = 0

        for batch_idx, train_data_batch in enumerate(dataloader):
            # train_data_batch is a dictionary of lists due to custom_collate_fn
            loss = train_refinement(train_data_batch, net, optimizer, device, N_supervision)
            total_loss += loss

            if (batch_idx + 1) % 10 == 0: # Print loss every 10 batches
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss:.4f}")


        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] finished, Average Loss: {avg_loss:.4f}")

    print("Training finished.")

Starting training for 20 epochs...
Preconditioning step 1/2:
  Current y (detached): -0.010231632739305496
  Current z (detached): 0.14476606249809265
Preconditioning step 2/2:
  Current y (detached): -0.009992295876145363
  Current z (detached): 0.15188957750797272
Final step 3/3:
  Final y_hat: -0.009397396817803383
  Final q_hat: 0.7716807126998901
Preconditioning step 1/2:
  Current y (detached): -0.008789755403995514
  Current z (detached): 0.14385852217674255
Preconditioning step 2/2:
  Current y (detached): -0.01055789552628994
  Current z (detached): 0.13248610496520996
Final step 3/3:
  Final y_hat: -0.008954918943345547
  Final q_hat: 0.8068772554397583
Preconditioning step 1/2:
  Current y (detached): -0.00841260701417923
  Current z (detached): 0.13443417847156525
Preconditioning step 2/2:
  Current y (detached): -0.0073801446706056595
  Current z (detached): 0.1361982822418213
Final step 3/3:
  Final y_hat: -0.008999553509056568
  Final q_hat: 0.7881201505661011
Preconditi
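
`bart_model` is frozen: it is in `eval()` mode and only called under `torch.no_grad()`. Each epoch above nevertheless re-embeds the same articles and summaries. One option, sketched here under the assumption that the embeddings fit in memory, is to embed the corpus once and train from cached tensors. `precompute_embeddings` and `fake_embed` are hypothetical helpers. `fake_embed` only stands in for `get_embeddings` so the example runs without downloading the model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def precompute_embeddings(texts, embed_fn, batch_size=32):
    """Hypothetical helper: embed `texts` once, in chunks, into one (N, d) CPU tensor.

    `embed_fn` is any callable mapping a list of strings to an (n, d) tensor.
    """
    chunks = []
    with torch.no_grad():  # frozen encoder: no graph is needed
        for start in range(0, len(texts), batch_size):
            chunks.append(embed_fn(texts[start:start + batch_size]).cpu())
    return torch.cat(chunks, dim=0)

# Hypothetical stand-in encoder: maps each string to a 4-dim vector of its length
def fake_embed(batch):
    return torch.tensor([[float(len(t))] * 4 for t in batch])

articles = ["a", "bb", "ccc", "dddd", "eeeee"]
summaries = ["x", "yy", "zzz", "w", "vv"]

x_all = precompute_embeddings(articles, fake_embed, batch_size=2)
y_all = precompute_embeddings(summaries, fake_embed, batch_size=2)

# Cached (x, y_true_emb) pairs feed a standard DataLoader; drop_last avoids a
# final batch of one example, which BatchNorm1d cannot normalize in training mode
emb_loader = DataLoader(TensorDataset(x_all, y_all), batch_size=2, shuffle=True, drop_last=True)
print(x_all.shape, y_all.shape)  # torch.Size([5, 4]) torch.Size([5, 4])
```

With the notebook's objects, `embed_fn` could be `lambda b: get_embeddings(b, tokenizer, bart_model, device)`, applied to `dataset.articles` and `dataset.summaries`. `train_refinement` would then need to take the cached `(x, y_true_emb)` tensors instead of raw text.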

In [48]:
# Example inference on one (shuffled) batch from the training DataLoader; this is not a held-out evaluation
batch = next(iter(dataloader))
articles = batch['article'] # Correctly access the list of articles from the batch dictionary
x_input = get_embeddings(articles, tokenizer, bart_model, device)

batch_size = x_input.size(0)
y = torch.zeros(batch_size, net.fc_out.out_features, device=device)
z = torch.zeros(batch_size, net.latent_dim, device=device)

# Set model to evaluation mode
net.eval()
with torch.no_grad(): # No gradient calculation during inference
    (_, _), y_hat, q_hat = deep_recursion(x_input, y, z, net)

print("Predicted embedding shape:", y_hat.shape)
print("Prediction confidence:", q_hat)

# Compare the prediction with the embeddings of this batch's reference summaries
references = batch['highlights']
print(references)
y_true_emb = get_embeddings(references, tokenizer, bart_model, device)
mse_loss = F.mse_loss(y_hat, y_true_emb)
print("MSE Loss with target embeddings:", mse_loss.item())

Preconditioning step 1/2:
  Current y (detached): 0.008832002989947796
  Current z (detached): 0.1603105664253235
Preconditioning step 2/2:
  Current y (detached): 0.008966365829110146
  Current z (detached): 0.18550479412078857
Final step 3/3:
  Final y_hat: 0.008964473381638527
  Final q_hat: 0.9251689910888672
Predicted embedding shape: torch.Size([8, 768])
Prediction confidence: tensor([0.9227, 0.9260, 0.9273, 0.9234, 0.9237, 0.9261, 0.9256, 0.9265],
       device='cuda:0')
MSE Loss with target embeddings: 0.009147725068032742
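
On its own, a single MSE value is hard to interpret. Mean-pooled embeddings of different texts can already lie close together, and the batch above comes from the training set. The hypothetical helper `embedding_metrics` sketches a few additional checks:

* a trivial baseline that predicts the mean target embedding;
* cosine similarity between each prediction and its target;
* whether each prediction is closest to its own reference within the batch.

The demo call feeds it random tensors so it runs standalone.

```python
import torch
import torch.nn.functional as F

def embedding_metrics(y_hat, y_true):
    """Hypothetical helper comparing predicted and target embeddings of shape (batch, d)."""
    mse = F.mse_loss(y_hat, y_true).item()

    # Trivial baseline: predict the batch-mean target embedding for every example.
    # It is built from the targets themselves, so it is a slightly optimistic baseline.
    baseline = y_true.mean(dim=0, keepdim=True).expand_as(y_true)
    baseline_mse = F.mse_loss(baseline, y_true).item()

    # Cosine similarity between each prediction and its own target
    cosine = F.cosine_similarity(y_hat, y_true, dim=-1).mean().item()

    # In-batch retrieval: is each prediction most similar to its own target?
    sims = F.normalize(y_hat, dim=-1) @ F.normalize(y_true, dim=-1).T   # (batch, batch)
    labels = torch.arange(y_true.size(0), device=y_true.device)
    retrieval_acc = (sims.argmax(dim=1) == labels).float().mean().item()

    return {"mse": mse, "baseline_mse": baseline_mse,
            "cosine": cosine, "retrieval_acc": retrieval_acc}

# Standalone check on random data: a perfect prediction gives mse 0 and retrieval 1
torch.manual_seed(0)
y_true_demo = torch.randn(8, 768)
metrics = embedding_metrics(y_true_demo.clone(), y_true_demo)
print(metrics["mse"], metrics["retrieval_acc"])  # 0.0 1.0

# In the notebook, after the inference cell: embedding_metrics(y_hat, y_true_emb)
```

If the model's MSE is not clearly below `baseline_mse`, or retrieval accuracy stays near chance (1/8 for a batch of 8), the prediction is probably not specific to its article.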
