import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import gzip
class RoPEPositionalEncoding(nn.Module):
"""
Rotary Position Encoding (RoPE) module.
"""
def __init__(self, d_model, max_len=5000, base=10000):
super().__init__()
self.d_model = d_model
self.max_len = max_len
self.base = base
        # One inverse frequency per (cos, sin) pair, so the encoding has d_model
        # channels after interleaving; registered as a buffer so it moves with the
        # module across devices.
        self.register_buffer(
            "theta",
            torch.tensor([base ** (-2 * i / d_model) for i in range(d_model // 2)]),
        )
def forward(self, positions):
angles = positions.unsqueeze(-1) * self.theta
return torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)
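# Shape sketch (illustrative values): for d_model=8, `RoPEPositionalEncoding(8)` maps a
# position tensor of shape (seq_len,) to an encoding of shape (seq_len, 8) made of
# interleaved [cos, sin] pairs, e.g.
#   enc = RoPEPositionalEncoding(8)(torch.arange(16))  # enc.shape == (16, 8)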
def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
"""
Perform non-uniform interpolation on position embeddings.
Args:
pos_embed (torch.Tensor): Position embeddings.
extension_ratio (float): Extension ratio for context window.
lambda_factors (list): Lambda factors for interpolation.
n_hat (int): Threshold for applying interpolation.
Returns:
torch.Tensor: Interpolated position embeddings.
"""
d_model = pos_embed.shape[-1]
interpolated_pos = pos_embed.clone()
    # Positions below n_hat keep their original frequencies; positions at or beyond
    # n_hat are rescaled per frequency pair by 1 / lambda_factors[i].
    mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat
    for i in range(d_model // 2):
        scale = torch.where(
            mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i]
        )
        interpolated_pos[..., i * 2] *= scale
        interpolated_pos[..., i * 2 + 1] *= scale
return interpolated_pos
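# Illustrative call (hypothetical values): positions below n_hat are left untouched,
# while later positions are compressed by the per-pair lambda factors, e.g.
#   emb = RoPEPositionalEncoding(8)(torch.arange(32))                       # (32, 8)
#   out = non_uniform_interpolation(emb, 2.0, torch.full((4,), 2.0), n_hat=16)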
def search_lambda_factors(
    model,
    data,
    extension_ratio,
    population_size=64,
    num_mutations=16,
    num_crossovers=16,
    max_iterations=10,
    max_length=None,
):
"""
Search for optimal lambda factors using evolutionary search.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
extension_ratio (float): Extension ratio for context window.
population_size (int): Size of the population for evolutionary search.
num_mutations (int): Number of mutations per iteration.
num_crossovers (int): Number of crossovers per iteration.
max_iterations (int): Maximum number of iterations for evolutionary search.
Returns:
list: Optimal lambda factors found by the search.
"""
    population = initialize_population(population_size, extension_ratio)
    if max_length is not None:
        # Restrict evaluation to the first max_length tokens of each sequence (used
        # when re-searching factors for the original, shorter context window).
        data = [seq[:max_length] for seq in data]
    for _ in range(max_iterations):
        perplexities = evaluate_population(model, data, population)
        parents = select_topk(population, perplexities, k=population_size // 2)
        population = mutate(parents, num_mutations) + crossover(parents, num_crossovers)
    return min(population, key=lambda x: evaluate_individual(model, data, x))
def progressive_extension(model, data, base_length, target_length):
"""
Progressively extend the context window of the model.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
base_length (int): Base context window length.
target_length (int): Target context window length.
Returns:
tuple: (Extended model, lambda factors, base lambda factors)
"""
    curr_model = model
    curr_length = base_length
    lambda_factors = None
    # n_hat marks how many leading positions keep their original RoPE frequencies;
    # half the base window is used here, matching extend_context().
    n_hat = base_length // 2
    while curr_length < target_length:
        lambda_factors = search_lambda_factors(
            curr_model, data, curr_length / base_length
        )
        curr_model = fine_tune(curr_model, data, curr_length, lambda_factors, n_hat)
        curr_length *= 2
    # Re-search factors restricted to the base window so short-context quality is kept.
    lambda_factors_base = search_lambda_factors(
        curr_model, data, curr_length / base_length, max_length=base_length
    )
    return curr_model, lambda_factors, lambda_factors_base
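# Extension schedule sketch (values from the example at the bottom of this file): with
# base_length=4096 and target_length=2048*1024, the loop searches and fine-tunes at
# extension ratios 1, 2, 4, ... while curr_length doubles 4096 -> 8192 -> ... past the
# target, after which one final search with max_length=base_length recovers factors
# for the original short context.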
class LongRoPEModel(nn.Module):
"""
Long Range Rotary Position Encoding (LongRoPE) model.
This model extends the context window of transformer-based models beyond the
typical limit by using non-uniform interpolation of rotary position embeddings.
It enables the model to handle longer input sequences while maintaining the
ability to capture long-range dependencies.
Attributes:
d_model (int): Dimension of the model.
n_heads (int): Number of attention heads.
num_layers (int): Number of transformer layers.
max_len (int): Maximum sequence length.
rope (RoPEPositionalEncoding): Rotary Position Encoding (RoPE) module.
        transformers (nn.ModuleList): List of transformer encoder layers.
        embedding (nn.Embedding): Token embedding table used when integer token IDs
            are passed to forward().
lambda_factors (list): Lambda factors for non-uniform interpolation.
lambda_factors_base (list): Lambda factors for the base model.
extension_ratio (float): Extension ratio for the context window.
n_hat (int): Threshold for applying interpolation.
Methods:
forward(input_ids):
Perform forward pass on the input sequence.
Args:
input_ids (torch.Tensor): Input sequence tensor.
Returns:
torch.Tensor: Output embeddings from the model.
extend_context(data_path, target_length, max_sequence_length, tokenizer):
Extend the context window of the model.
Args:
data_path (str): Path to the input data file.
target_length (int): Target context window length.
max_sequence_length (int): Maximum sequence length for input data.
tokenizer: Tokenizer object for encoding input data.
Returns:
LongRoPEModel: Extended LongRoPE model.
"""
    def __init__(self, d_model, n_heads, num_layers, max_len=5000, vocab_size=50257):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.max_len = max_len
        self.rope = RoPEPositionalEncoding(d_model, max_len)
        # Token embedding so forward() can accept integer token IDs directly; the
        # default vocab_size matches GPT-2 and should be set to your tokenizer's size.
        self.embedding = nn.Embedding(vocab_size, d_model)
        # batch_first=True keeps tensors in (batch, seq_len, d_model) layout.
        self.transformers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=d_model, nhead=n_heads, batch_first=True
                )
                for _ in range(num_layers)
            ]
        )
        self.lambda_factors = None
        self.lambda_factors_base = None
        self.extension_ratio = None
        self.n_hat = max_len // 2
    def forward(self, input_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        pos_embeddings = self.rope(positions)
        if self.lambda_factors is not None:
            pos_embeddings = non_uniform_interpolation(
                pos_embeddings, self.extension_ratio, self.lambda_factors, self.n_hat
            )
        # Integer inputs are treated as token IDs and embedded; floating-point inputs
        # are assumed to already be (batch, seq_len, d_model) embeddings.
        if not torch.is_floating_point(input_ids):
            input_embeddings = self.embedding(input_ids) + pos_embeddings
        else:
            input_embeddings = input_ids + pos_embeddings
        for transformer in self.transformers:
            input_embeddings = transformer(input_embeddings)
        return input_embeddings
def extend_context(self, data_path, target_length, max_sequence_length, tokenizer):
"""
Extend the context window of the model.
Args:
data_path (str): Path to the input data file.
target_length (int): Target context window length.
max_sequence_length (int): Maximum sequence length for input data.
tokenizer: Tokenizer object for encoding input data.
Returns:
LongRoPEModel: Extended LongRoPE model.
"""
if tokenizer is None:
raise ValueError("Tokenizer is required for extending context.")
self.extension_ratio = target_length / self.rope.max_len
data = load_data(data_path, tokenizer, max_sequence_length)
model, lambda_factors, lambda_factors_base = progressive_extension(
self, data, self.rope.max_len, target_length
)
self.lambda_factors = lambda_factors
self.lambda_factors_base = lambda_factors_base
self.n_hat = self.rope.max_len // 2
return model
def load_data(data_path, tokenizer, max_sequence_length):
"""
Load and preprocess the input data.
Args:
data_path (str): Path to the input data file.
tokenizer: Tokenizer object for encoding input data.
max_sequence_length (int): Maximum sequence length for input data.
Returns:
list: List of preprocessed input sequences.
"""
if data_path is None or tokenizer is None:
raise ValueError("Data path and tokenizer are required for loading data.")
if data_path.endswith(".gz"):
with gzip.open(data_path, "rt", encoding="utf-8") as file:
text_data = file.read()
else:
with open(data_path, "r", encoding="utf-8") as file:
text_data = file.read()
tokenized_data = tokenizer.encode(text_data)
sequences = [
tokenized_data[i : i + max_sequence_length]
for i in range(0, len(tokenized_data), max_sequence_length)
]
tensor_data = [torch.tensor(seq, dtype=torch.long) for seq in sequences]
return tensor_data
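# Any tokenizer exposing an `encode(text) -> list[int]` method (for example a Hugging
# Face tokenizer) works with load_data; the token stream is split into non-overlapping
# chunks of at most max_sequence_length tokens.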
def initialize_population(population_size, extension_ratio):
"""
Initialize the population for evolutionary search.
Args:
population_size (int): Size of the population.
extension_ratio (float): Extension ratio for context window.
Returns:
list: Initialized population.
"""
population = []
population.append(torch.ones(512) * extension_ratio)
ntk_factors = torch.tensor([extension_ratio ** (2 * i / 512) for i in range(512)])
population.append(ntk_factors)
yarn_factors = torch.ones(512)
yarn_factors[:128] = 1.0
yarn_factors[128:256] = extension_ratio ** (1 / 3)
yarn_factors[256:] = extension_ratio
population.append(yarn_factors)
for _ in range(population_size - 3):
factors = torch.ones(512)
for i in range(512):
if random.random() < 0.1:
factors[i] = random.uniform(1, extension_ratio)
population.append(factors)
return population
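# The first three individuals seed the search with known rescaling baselines: uniform
# positional-interpolation factors, NTK-style per-dimension factors, and a YaRN-style
# three-band split. The remainder are sparse random perturbations around 1. Note that
# the factor vectors are hard-coded to length 512 (the example d_model); only the
# first d_model // 2 entries are consumed by non_uniform_interpolation.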
def evaluate_individual(model, data, individual):
"""
Evaluate an individual lambda factor configuration.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
individual (list): Lambda factor configuration.
Returns:
float: Perplexity score for the individual.
"""
    model.lambda_factors = individual
    model.eval()
    perplexities = []
    with torch.no_grad():
        for seq in data:
            input_ids = seq.unsqueeze(0)
            output = model(input_ids)
            # Placeholder metric: without an LM head and labels, the exponentiated mean
            # activation stands in for a true perplexity.
            perplexity = torch.exp(torch.mean(output))
            perplexities.append(perplexity.item())
    model.train()
    return float(np.mean(perplexities))
def evaluate_population(model, data, population):
"""
Evaluate the population of lambda factor configurations.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
population (list): Population of lambda factor configurations.
Returns:
list: Perplexity scores for each individual in the population.
"""
perplexities = []
for individual in population:
perplexity = evaluate_individual(model, data, individual)
perplexities.append(perplexity)
return perplexities
def select_topk(population, perplexities, k):
"""
Select the top-k individuals from the population based on perplexity scores.
Args:
population (list): Population of lambda factor configurations.
perplexities (list): Perplexity scores for each individual in the population.
k (int): Number of top individuals to select.
Returns:
list: Top-k individuals from the population.
"""
indices = np.argsort(perplexities)[:k]
return [population[i] for i in indices]
def mutate(parents, num_mutations):
"""
Perform mutation on the parent population.
Args:
parents (list): Parent population.
num_mutations (int): Number of mutations to perform.
Returns:
list: Mutated population.
"""
mutated_population = []
for _ in range(num_mutations):
parent = random.choice(parents)
child = parent.clone()
        for i in range(child.numel()):
if random.random() < 0.1:
child[i] *= random.uniform(0.8, 1.2)
mutated_population.append(child)
return mutated_population
def crossover(parents, num_crossovers):
"""
Perform crossover on the parent population.
Args:
parents (list): Parent population.
num_crossovers (int): Number of crossovers to perform.
Returns:
list: Crossover population.
"""
crossover_population = []
for _ in range(num_crossovers):
parent1, parent2 = random.sample(parents, 2)
child = parent1.clone()
        for i in range(child.numel()):
if random.random() < 0.5:
child[i] = parent2[i]
crossover_population.append(child)
return crossover_population
def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
"""
Fine-tune the LongRoPE model.
Args:
model (nn.Module): LongRoPE model.
data (list): List of input sequences.
target_length (int): Target context window length.
lambda_factors (list): Lambda factors for interpolation.
n_hat (int): Threshold for applying interpolation.
num_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.
Returns:
nn.Module: Fine-tuned LongRoPE model.
"""
    model.lambda_factors = lambda_factors
    model.n_hat = n_hat
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    for _ in range(num_epochs):
        for seq in data:
            optimizer.zero_grad()
            seq_len = seq.size(0)
            if seq_len <= target_length:
                input_ids = seq.unsqueeze(0)
            else:
                # Sample a random window of target_length tokens from longer sequences.
                start_idx = random.randint(0, seq_len - target_length)
                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
            output = model(input_ids)
            # Placeholder objective: a real setup would compute a language-modeling
            # loss against shifted labels here.
            loss = torch.mean(output)
            loss.backward()
            optimizer.step()
    return model
# Example usage (the data path and tokenizer below are placeholders; supply your own).
if __name__ == "__main__":
    data_path = "path/to/your/dataset"
    tokenizer = None  # replace with a tokenizer exposing `encode(text) -> list[int]`
    d_model = 512
    n_heads = 8
    num_layers = 6
    base_length = 4096
    target_length = 2048 * 1024

    model = LongRoPEModel(d_model, n_heads, num_layers, max_len=base_length)
    model = model.extend_context(
        data_path=data_path,
        target_length=target_length,
        max_sequence_length=base_length,
        tokenizer=tokenizer,
    )

    # Sanity check on a short batch of precomputed embeddings; a full 2M-token window
    # would require far more memory than a toy run.
    input_embeddings = torch.randn(2, 1024, d_model)
    output = model(input_embeddings)
    print(output.shape)  # Expected: torch.Size([2, 1024, 512])