import os
import sys
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import timm
from pathlib import Path
import kaggle
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from wildlife_datasets import loader, datasets, splits
from wildlife_tools.data import ImageDataset
from wildlife_tools.features import DeepFeatures
from wildlife_tools.similarity import CosineSimilarity
from wildlife_tools.inference import KnnClassifier
from wildlife_tools.train import ArcFaceLoss, set_seed
from tqdm import tqdm
import random
from gcn_reid.segmentation import decode_rle_mask, visualize_segmentation, visualize_segmentation_from_metadata
from gcn_reid.attribution import my_occlusion_sensitivity
import itertools
# Environment report: confirm the imports resolved and show the torch runtime.
for status_line in (
    "All imports successful!",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    print(status_line)
# BSL Experiment
# Finetuning MegaDescriptor with Background Suppression Loss (BSL)
# Download and verify dataset
def download_newt_dataset():
    """Download the segmented newt dataset from Kaggle unless it is cached.

    The dataset is unpacked into ``data/newt_dataset`` relative to the
    current working directory.

    Returns:
        The download directory as a string path.
    """
    dataset_name = "mshahoyi/barhill-newts-segmented"
    download_path = "data/newt_dataset"

    # An existing but empty directory is what an interrupted download leaves
    # behind, so it must not count as a cached dataset.
    needs_download = not os.path.isdir(download_path) or not os.listdir(download_path)
    if needs_download:
        os.makedirs(download_path, exist_ok=True)
        kaggle.api.dataset_download_files(dataset_name, path=download_path, unzip=True)
        print(f"Dataset downloaded to {download_path}")
    else:
        print(f"Dataset already exists at {download_path}")
    return download_path
# Fetch the dataset (or reuse the cached copy) and remember where it lives.
dataset_path = download_newt_dataset()

# Verify dataset structure
print(f"\nDataset path: {dataset_path}")
print("Dataset contents:")
for item in os.listdir(dataset_path):
    print(f" {item}")
# Load and examine metadata
metadata_path = os.path.join(dataset_path, "metadata.csv")
df = pd.read_csv(metadata_path)

# Summary statistics; the code below relies on the 'newt_id' column.
print(f"Dataset contains {len(df)} images")
print(f"Number of unique newts: {df['newt_id'].nunique()}")
print(f"Dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
print("\nFirst few rows:")
print(df.head())
print("\nNewt ID distribution:")
print(df['newt_id'].value_counts().head(10))
# Test RLE decoding with a sample
sample_row = df.iloc[0]
print(f"Testing RLE decoding with sample:")
print(f"Image path: {sample_row['image_path']}")
print(f"Newt ID: {sample_row['newt_id']}")

# Load sample image to get dimensions
sample_img_path = Path(dataset_path) / sample_row['image_path']
sample_img = Image.open(sample_img_path)
print(f"Image size: {sample_img.size}")

# Decode mask
# PIL reports size as (width, height); h and w are kept for reference only,
# decode_rle_mask is called with the RLE payload alone.
h, w = sample_img.size[1], sample_img.size[0]
mask = decode_rle_mask(sample_row['segmentation_mask_rle'])
print(f"Mask shape: {mask.shape}")
print(f"Mask unique values: {np.unique(mask)}")
print(f"Mask coverage: {mask.sum() / mask.size:.2%}")
# Visualize sample: original image, decoded mask, and the mask overlaid on the image.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(sample_img)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title('Segmentation Mask')
axes[1].axis('off')

axes[2].imshow(sample_img)
axes[2].imshow(mask, alpha=0.5, cmap='Reds')
axes[2].set_title('Image + Mask Overlay')
axes[2].axis('off')

plt.tight_layout()
plt.show()
# Create and test dataset splits
def create_newt_splits(df, train_ratio=0.8, random_state=42):
    """Create disjoint splits ensuring each newt appears in only one split.

    The split is made over unique ``newt_id`` values, not over rows, so all
    images of one individual land on the same side.

    Args:
        df: Metadata frame with a ``newt_id`` column.
        train_ratio: Fraction of identities assigned to the training split.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(df_train, df_test)`` as independent copies of the matching rows.
    """
    unique_newts = df['newt_id'].unique()

    train_newts, test_newts = train_test_split(
        unique_newts,
        train_size=train_ratio,
        random_state=random_state,
        stratify=None,
    )

    df_train = df[df['newt_id'].isin(train_newts)].copy()
    df_test = df[df['newt_id'].isin(test_newts)].copy()

    print(f"Train split: {len(df_train)} images from {len(train_newts)} newts")
    print(f"Test split: {len(df_test)} images from {len(test_newts)} newts")
    return df_train, df_test
df_train, df_test = create_newt_splits(df)

# Verify no overlap between train and test
train_newts = set(df_train['newt_id'].unique())
test_newts = set(df_test['newt_id'].unique())
overlap = train_newts.intersection(test_newts)
print(f"Overlap between train and test newts: {len(overlap)} (should be 0)")

print("\nTrain newt distribution (top 10):")
print(df_train['newt_id'].value_counts().head(10))
print("\nTest newt distribution (top 10):")
print(df_test['newt_id'].value_counts().head(10))
# Create and test custom dataset class
class NewtDataset(Dataset):
    """Image / identity-label / segmentation-mask dataset for newt re-ID.

    Each item is ``(image, label, mask)`` when a mask is available and
    requested, otherwise ``(image, label)``. Labels are contiguous integer
    indices assigned to the sorted unique ``newt_id`` values of the frame
    passed in, so two datasets built from different frames do not share a
    label space.

    Args:
        dataframe: Metadata with ``newt_id`` and ``image_path`` columns and,
            optionally, ``segmentation_mask_rle``.
        root_path: Directory that ``image_path`` values are relative to.
        transform: Callable applied to the PIL image (typically a
            ``torchvision.transforms.Compose``).
        return_mask: Whether to decode and return the segmentation mask.
    """

    def __init__(self, dataframe, root_path, transform=None, return_mask=True):
        self.df = dataframe.reset_index(drop=True)
        self.root_path = Path(root_path)
        self.transform = transform
        self.return_mask = return_mask
        self.labels_string = self.df['newt_id'].astype(str).tolist()

        # Create label mapping
        unique_labels = sorted(self.df['newt_id'].unique())
        self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
        self.labels = [self.label_to_idx[label] for label in self.df['newt_id']]

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _geometric_transform_types():
        """Return the torchvision transforms that move pixels around.

        Resolved lazily so the class can be defined before ``T`` is used.
        """
        return (
            T.Resize, T.CenterCrop, T.Pad, T.RandomCrop, T.RandomResizedCrop,
            T.RandomHorizontalFlip, T.RandomVerticalFlip, T.RandomRotation,
            T.RandomAffine, T.RandomPerspective,
        )

    def _load_mask(self, row, image, idx):
        """Decode the row's RLE mask as an 'L' image, or ``None`` if not wanted.

        Falls back to an all-foreground mask when decoding fails or yields
        ``None`` so that one bad annotation does not abort an epoch.
        """
        if not (self.return_mask and 'segmentation_mask_rle' in row):
            return None

        h, w = image.size[1], image.size[0]  # PIL size is (width, height)
        try:
            decoded_mask = decode_rle_mask(row['segmentation_mask_rle'])
            if decoded_mask is not None:
                return Image.fromarray(decoded_mask * 255).convert('L')
        except Exception as e:
            # Deliberately broad: any decoding problem degrades to a full mask.
            print(f"Warning: Error decoding mask for image {idx}: {e}, using default full mask")
        return Image.fromarray(np.ones((h, w), dtype=np.uint8) * 255).convert('L')

    def _transform_pair(self, image, mask):
        """Apply ``self.transform`` to the image and mirror its geometry on the mask.

        For a Compose-like pipeline each geometric step is replayed on the
        mask with the RNG state the image saw, so random flips/rotations stay
        aligned. Photometric steps (colour jitter, normalisation) only touch
        the image. Assumes the mask has the same size as the image when
        size-dependent crops are used.
        """
        steps = getattr(self.transform, 'transforms', None)
        if steps is None:
            # Opaque callable: its geometry cannot be replayed on the mask.
            image = self.transform(image)
        else:
            geometric = self._geometric_transform_types()
            for step in steps:
                if isinstance(step, geometric):
                    py_state = random.getstate()
                    torch_state = torch.get_rng_state()
                    image = step(image)
                    random.setstate(py_state)
                    torch.set_rng_state(torch_state)
                    mask = step(mask)
                else:
                    image = step(image)

        if not torch.is_tensor(mask):
            mask = T.ToTensor()(mask)
        # Final safety net: the mask must match the image's spatial size.
        if torch.is_tensor(image) and mask.shape[-2:] != image.shape[-2:]:
            mask = T.Resize(image.shape[-2:])(mask)
        return image, mask

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load image
        img_path = self.root_path / row['image_path']
        image = Image.open(img_path).convert('RGB')

        # Get label
        label = self.labels[idx]

        # Decode segmentation mask
        mask = self._load_mask(row, image, idx)

        # Apply transforms
        if self.transform:
            if mask is not None:
                image, mask = self._transform_pair(image, mask)
            else:
                image = self.transform(image)

        if mask is not None:
            return image, label, mask.squeeze(0)
        return image, label
# Test dataset creation
# Deterministic evaluation pipeline: resize, tensor conversion, ImageNet normalisation.
transform_test = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset_small = NewtDataset(df_train.head(10), dataset_path, transform=transform_test, return_mask=True)

print(f"Test dataset size: {len(test_dataset_small)}")
print(f"Number of classes in test: {len(test_dataset_small.label_to_idx)}")
print(f"Label mapping: {test_dataset_small.label_to_idx}")

# Test loading a sample
sample_data = test_dataset_small[0]
print(f"Sample data shapes:")
print(f" Image: {sample_data[0].shape}")
print(f" Label: {sample_data[1]}")
print(f" Mask: {sample_data[2].shape}")
# Create full datasets with sanity checks
# Training pipeline: geometric augmentation first, then colour jitter, then
# tensor conversion and ImageNet normalisation.
transform_train = T.Compose([
    T.Resize([224, 224]),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=180),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = NewtDataset(df_train, dataset_path, transform=transform_train, return_mask=True)
test_dataset = NewtDataset(df_test, dataset_path, transform=transform_test, return_mask=True)

print(f"Training dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")
print(f"Number of classes: {len(train_dataset.label_to_idx)}")

# Verify datasets
train_sample = train_dataset[0]
test_sample = test_dataset[0]

print(f"\nTrain sample shapes: image={train_sample[0].shape}, label={train_sample[1]}, mask={train_sample[2].shape}")
print(f"Test sample shapes: image={test_sample[0].shape}, label={test_sample[1]}, mask={test_sample[2].shape}")

# Check label consistency
# NOTE: each dataset indexes its own identities from 0, so train and test
# label values are not comparable with each other.
print(f"Train labels range: {min(train_dataset.labels)} to {max(train_dataset.labels)}")
print(f"Test labels range: {min(test_dataset.labels)} to {max(test_dataset.labels)}")
# Test model loading and feature extraction
def test_megadescriptor_loading():
    """Load the pretrained MegaDescriptor backbone and probe its output size.

    Returns:
        ``(backbone, feature_dim)`` where ``feature_dim`` is the second
        dimension of the backbone output for a 224x224 input.
    """
    print("Testing MegaDescriptor loading...")

    # Test loading the model
    model_name = 'hf-hub:BVRA/MegaDescriptor-T-224'
    # num_classes=0 asks timm for the headless feature extractor.
    backbone = timm.create_model(model_name, pretrained=True, num_classes=0)

    # Test forward pass
    with torch.no_grad():
        dummy_input = torch.randn(2, 3, 224, 224)
        features = backbone(dummy_input)

    print(f"Model loaded successfully!")
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output features shape: {features.shape}")
    print(f"Feature dimension: {features.shape[1]}")
    return backbone, features.shape[1]
# Module-level backbone and embedding size reused by the smoke tests below.
backbone, embedding_size = test_megadescriptor_loading()
# Create and test ArcFace loss
def test_arcface_loss():
    """Build an ArcFace loss for the training classes and run it on random data.

    Reads the module-level ``train_dataset`` and ``embedding_size``.

    Returns:
        The constructed ``ArcFaceLoss`` instance.
    """
    print("Testing ArcFace loss...")
    num_classes = len(train_dataset.label_to_idx)

    # Create ArcFace loss
    arcface_loss = ArcFaceLoss(
        num_classes=num_classes,
        embedding_size=embedding_size,
        margin=0.5,
        scale=64,
    )
    print(f"ArcFace loss created for {num_classes} classes, embedding size {embedding_size}")

    # Test forward pass
    with torch.no_grad():
        dummy_embeddings = torch.randn(4, embedding_size)
        dummy_labels = torch.randint(0, num_classes, (4,))

        loss = arcface_loss(dummy_embeddings, dummy_labels)
        print(f"Test loss: {loss.item():.4f}")

    print("ArcFace loss working correctly!")
    return arcface_loss
arcface_loss = test_arcface_loss()
# Define Background Suppression ArcFace Loss
class BackgroundSuppressionArcFaceLoss(nn.Module):
    """
    Custom loss that combines ArcFace loss with background suppression.

    Uses segmentation masks to focus learning on the newt regions: the L2
    norm of intermediate patch features is penalised wherever the mask marks
    background.

    Args:
        num_classes: Number of identities for the ArcFace classifier.
        embedding_size: Dimensionality of the embeddings.
        margin: ArcFace angular margin.
        scale: ArcFace logit scale.
        alpha: Weight for the ArcFace term.
        beta: Weight for the background suppression term.
    """

    def __init__(self, num_classes, embedding_size, margin=0.5, scale=64, alpha=1.0, beta=0.5):
        super().__init__()
        self.arcface_loss = ArcFaceLoss(
            num_classes=num_classes,
            embedding_size=embedding_size,
            margin=margin,
            scale=scale,
        )
        self.alpha = alpha  # Weight for ArcFace loss
        self.beta = beta  # Weight for background suppression loss

    def forward(self, embeddings, labels, masks, patch_features=None):
        """
        Args:
            embeddings: Output embeddings from the backbone [B, embedding_size]
            labels: Ground truth labels [B]
            masks: Segmentation masks (1 for newt, 0 for background),
                [B, H, W] or [B, 1, H, W]
            patch_features: Intermediate feature maps for background
                suppression [B, C, Hf, Wf]; the background term is 0 when
                this or ``masks`` is ``None``.

        Returns:
            ``(total_loss, arcface_loss, background_penalty)`` with
            ``total_loss = alpha * arcface_loss + beta * background_penalty``.
        """
        # ArcFace loss on embeddings
        arcface_loss = self.arcface_loss(embeddings, labels)

        # Background suppression loss
        background_penalty = torch.tensor(0.0, device=embeddings.device)

        if patch_features is not None and masks is not None:
            Hf, Wf = patch_features.shape[-2:]

            # Resize masks to match feature map size
            if masks.dim() == 3:
                masks = masks.unsqueeze(1)
            masks_resized = F.interpolate(
                masks.float(),
                size=(Hf, Wf),
                mode='nearest',
            ).squeeze(1)

            # Background mask (1 for background, 0 for foreground)
            background_mask = 1.0 - masks_resized

            # L2 norm over channels -> [B, Hf, Wf]. vector_norm has a zero
            # (sub)gradient at the origin, whereas pow(2).sum().sqrt() yields
            # NaN gradients for all-zero patch vectors.
            patch_norm = torch.linalg.vector_norm(patch_features, ord=2, dim=1)

            # Background suppression: penalize high activations in background regions
            background_penalty = (patch_norm * background_mask).mean()

        total_loss = self.alpha * arcface_loss + self.beta * background_penalty

        return total_loss, arcface_loss, background_penalty
# Test BSL Loss
def test_bsl_loss(model=None):
    """Smoke-test the background suppression loss on one real training batch.

    Reads the module-level ``train_dataset``, ``embedding_size`` and (when no
    model is given) ``backbone``.

    Args:
        model: Embedding model to run. Defaults to the already-loaded
            ``backbone``; the original code read a global ``model`` that is
            not defined yet when this function is first called. If the model
            exposes a ``patch_features`` attribute it is used for the
            background term, otherwise that term is skipped.

    Returns:
        The constructed ``BackgroundSuppressionArcFaceLoss``.
    """
    print("Testing Background Suppression ArcFace Loss...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = len(train_dataset.label_to_idx)

    bsl_loss = BackgroundSuppressionArcFaceLoss(
        num_classes=num_classes,
        embedding_size=embedding_size,
        margin=0.5,
        scale=64,
        alpha=1.0,
        beta=0.5,
    ).to(device)

    if model is None:
        model = backbone
    model = model.to(device)

    with torch.no_grad():
        # Get real data samples from training dataset
        sample_batch = next(iter(DataLoader(train_dataset, batch_size=2, shuffle=True)))
        images, labels, masks = sample_batch
        images, labels, masks = images.to(device), labels.to(device), masks.to(device)

        # Get embeddings and patch features from model
        model.eval()
        embeddings = model(images)
        patch_features = getattr(model, 'patch_features', None)

        total_loss, arcface_loss, bg_loss = bsl_loss(
            embeddings, labels, masks, patch_features
        )

    print(f"Total loss: {total_loss.item():.4f}")
    print(f"ArcFace loss: {arcface_loss.item():.4f}")
    print(f"Background loss: {bg_loss.item():.4f}")
    return bsl_loss
bsl_loss = test_bsl_loss()

# Plain headless MegaDescriptor; later replaced by the hooked wrapper model.
model = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', pretrained=True, num_classes=0)
# Create model with feature extraction hooks
class MegaDescriptorWithBSL(nn.Module):
    """MegaDescriptor backbone that also exposes an intermediate feature map.

    A forward hook on a late backbone stage stores that stage's output as
    ``patch_features`` in ``[B, C, H, W]`` layout on every forward pass, for
    use by the background suppression loss.

    Args:
        num_classes: Accepted for interface symmetry with the loss; not used
            by this module (the backbone is created headless).
        model_name: timm model identifier to load with pretrained weights.
    """

    def __init__(self, num_classes, model_name='hf-hub:BVRA/MegaDescriptor-T-224'):
        super().__init__()
        # Load pretrained MegaDescriptor backbone
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)

        # Get feature dimension by probing with a 224x224 dummy input
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.backbone(dummy_input)
        self.embedding_size = features.shape[1]

        # Store intermediate features for BSL
        self.patch_features = None

        # Register hook to capture intermediate features
        self._register_hooks()

    def _register_hooks(self):
        """Register a forward hook that captures an intermediate feature map.

        Prefers ``backbone.layers[2]``; otherwise searches the module tree
        for a stage named ``layers.2`` and only then ``layers.1``.
        """
        import math

        def hook_fn(module, input, output):
            # 4-D outputs are assumed to be channels-last [B, H, W, C], as the
            # Swin stages here emit; a channels-first layer would be mis-permuted.
            if output.dim() == 4:
                # Convert to [B, C, H, W] format for compatibility
                self.patch_features = output.permute(0, 3, 1, 2)
            elif output.dim() == 3:
                # Token layout [B, N, C]: only usable when N is a square grid.
                B, N, C = output.shape
                side = math.isqrt(N)
                if side * side == N:
                    # reshape (not view) so non-contiguous outputs also work
                    self.patch_features = output.reshape(B, side, side, C).permute(0, 3, 1, 2)

        # Hook into one of the later Swin Transformer stages
        # Stage 2 has 384 channels and good spatial resolution
        # Stage 3 has 768 channels (final) but lower spatial resolution
        if hasattr(self.backbone, 'layers') and len(self.backbone.layers) > 2:
            target_stage = self.backbone.layers[2]  # Stage 2
            print(f"Hooking to Swin stage 2 with {384} channels")
            target_stage.register_forward_hook(hook_fn)
            return

        # Fallback: look the stage up by name. Stage 2 is searched for over
        # the whole tree before stage 1 is considered; a single pass would hit
        # 'layers.1' first because modules are listed in definition order.
        named = list(self.backbone.named_modules())
        for stage_name in ('layers.2', 'layers.1'):
            for name, module in named:
                if name == stage_name or name.endswith('.' + stage_name):
                    print(f"Hooking to layer: {name}")
                    module.register_forward_hook(hook_fn)
                    return

        print("Warning: Could not find suitable Swin Transformer layer to hook")

    def forward(self, x):
        # Reset patch features so stale maps from a previous batch never leak
        self.patch_features = None

        # Forward through backbone (the hook fills self.patch_features)
        embeddings = self.backbone(x)

        return embeddings

    def get_patch_features(self):
        """Get the stored patch features for background suppression"""
        return self.patch_features
# Test model creation
def test_model_creation():
    """Instantiate the hooked model and run a dummy forward pass.

    Reads the module-level ``train_dataset`` for the class count.

    Returns:
        The created ``MegaDescriptorWithBSL`` on the selected device.
    """
    print("Testing model creation...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = len(train_dataset.label_to_idx)

    model = MegaDescriptorWithBSL(num_classes).to(device)

    print(f"Model created with {model.embedding_size} embedding size")

    # Test forward pass
    with torch.no_grad():
        dummy_input = torch.randn(2, 3, 224, 224).to(device)
        embeddings = model(dummy_input)
        patch_features = model.get_patch_features()

    print(f"Input shape: {dummy_input.shape}")
    print(f"Embeddings shape: {embeddings.shape}")
    if patch_features is not None:
        print(f"Patch features shape: {patch_features.shape}")
    else:
        print("Warning: No patch features captured")
    return model
model = test_model_creation()
# Test data loading with actual data
def test_data_loading():
    """Load one augmented batch from a 20-row subset and visualise a sample.

    Reads the module-level ``df_train``, ``dataset_path`` and
    ``transform_train``.

    Returns:
        The small training ``DataLoader`` that was created.
    """
    print("Testing data loading...")

    # Create small data loaders for testing
    small_train_dataset = NewtDataset(df_train.head(20), dataset_path, transform=transform_train, return_mask=True)
    train_loader = DataLoader(small_train_dataset, batch_size=4, shuffle=True, num_workers=0)

    # Test loading one batch
    for batch_idx, batch in enumerate(train_loader):
        if len(batch) == 3:
            images, labels, masks = batch
            print(f"Batch {batch_idx}:")
            print(f" Images shape: {images.shape}")
            print(f" Labels: {labels}")
            print(f" Masks shape: {masks.shape}")
            print(f" Mask value ranges: {masks.min().item():.3f} to {masks.max().item():.3f}")

            # Visualize one sample from batch
            img = images[0]
            mask = masks[0]

            # Denormalize image for visualization (inverse of the ImageNet Normalize step)
            img_denorm = img * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            img_denorm = torch.clamp(img_denorm, 0, 1)

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            axes[0].imshow(img_denorm.permute(1, 2, 0))
            axes[0].set_title(f'Image (Label: {labels[0].item()})')
            axes[0].axis('off')

            axes[1].imshow(mask, cmap='gray')
            axes[1].set_title('Mask')
            axes[1].axis('off')

            plt.tight_layout()
            plt.show()
        # Only the first batch is inspected.
        break

    return train_loader
test_loader = test_data_loading()
# Test full training setup
def test_training_setup():
    """Build model, loss and optimizer and run a single optimisation step.

    Reads the module-level ``train_dataset``, ``df_train``, ``dataset_path``
    and ``transform_train``.

    Returns:
        ``(model, bsl_loss, optimizer)`` after one training step.
    """
    print("Testing full training setup...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Model
    num_classes = len(train_dataset.label_to_idx)
    model = MegaDescriptorWithBSL(num_classes).to(device)

    # Loss function
    bsl_loss = BackgroundSuppressionArcFaceLoss(
        num_classes=num_classes,
        embedding_size=model.embedding_size,
        margin=0.5,
        scale=64,
        alpha=1.0,
        beta=0.5,
    ).to(device)

    # Optimizer: the ArcFace class weights live in the loss module, so its
    # parameters are optimised together with the backbone.
    params = itertools.chain(model.parameters(), bsl_loss.parameters())
    optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-4)

    # Test one training step
    model.train()
    bsl_loss.train()

    # Get a small batch
    small_dataset = NewtDataset(df_train.head(8), dataset_path, transform=transform_train, return_mask=True)
    loader = DataLoader(small_dataset, batch_size=4, shuffle=True, num_workers=0)

    for batch in loader:
        images, labels, masks = batch
        images = images.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        print(f"Batch shapes - Images: {images.shape}, Labels: {labels.shape}, Masks: {masks.shape}")
        optimizer.zero_grad()

        # Forward pass
        embeddings = model(images)
        patch_features = model.get_patch_features()

        print(f"Embeddings shape: {embeddings.shape}")
        if patch_features is not None:
            print(f"Patch features shape: {patch_features.shape}")

        # Compute loss
        loss, arcface_loss, bg_loss = bsl_loss(embeddings, labels, masks, patch_features)

        print(f"Losses - Total: {loss.item():.4f}, ArcFace: {arcface_loss.item():.4f}, BG: {bg_loss.item():.4f}")

        # Backward pass
        loss.backward()
        optimizer.step()
        print("Training step completed successfully!")
        break

    return model, bsl_loss, optimizer
model, bsl_loss, optimizer = test_training_setup()

# Now we can proceed with the actual training
print("Setup complete! Ready for full training...")
print(f"Total training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")
print(f"Number of classes: {len(train_dataset.label_to_idx)}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
# Test and define training epoch function with sanity checks
def train_epoch(model, train_loader, bsl_loss, optimizer, device, epoch):
    """Train for one epoch and report averaged losses and training accuracy.

    Args:
        model: Embedding model exposing ``get_patch_features()``.
        train_loader: Yields ``(images, labels, masks)`` or ``(images, labels)``.
        bsl_loss: Loss module returning ``(total, arcface, background)``.
        optimizer: Optimizer over model and loss parameters.
        device: Device to run on.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        ``(avg_loss, avg_arcface_loss, avg_bg_loss, accuracy)``; losses are
        per-batch means, accuracy is a percentage. All zeros if the loader
        yields no batches.
    """
    model.train()
    bsl_loss.train()

    total_loss = 0.0
    total_arcface_loss = 0.0
    total_bg_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch in pbar:
        if len(batch) == 3:  # With masks
            images, labels, masks = batch
            masks = masks.to(device)
        else:  # Without masks: treat every pixel as foreground
            images, labels = batch
            masks = torch.ones(images.shape[0], images.shape[2], images.shape[3]).to(device)

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        embeddings = model(images)
        patch_features = model.get_patch_features()

        # Compute BSL loss with ArcFace
        loss, arcface_loss, bg_loss = bsl_loss(embeddings, labels, masks, patch_features)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        num_batches += 1
        total_loss += loss.item()
        total_arcface_loss += arcface_loss.item()
        total_bg_loss += bg_loss.item()

        # For accuracy calculation, get predictions from ArcFace weights
        with torch.no_grad():
            # Classifier weights of the wrapped ArcFace loss; the indexing
            # below treats W as [embedding_size, num_classes].
            W = bsl_loss.arcface_loss.loss.W
            # Normalize embeddings and weights for cosine similarity
            embeddings_norm = F.normalize(embeddings, p=2, dim=1)
            W_norm = F.normalize(W, p=2, dim=0)
            # Compute logits as cosine similarity * scale
            logits = F.linear(embeddings_norm, W_norm.T) * bsl_loss.arcface_loss.loss.scale

            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Arc': f'{arcface_loss.item():.4f}',
            'BG': f'{bg_loss.item():.4f}',
            'Acc': f'{100.*correct/total:.2f}%' if total else 'n/a',
        })

    # An empty loader would otherwise divide by zero.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0, 0.0, 0.0

    avg_loss = total_loss / num_batches
    avg_arcface_loss = total_arcface_loss / num_batches
    avg_bg_loss = total_bg_loss / num_batches
    accuracy = 100. * correct / total

    return avg_loss, avg_arcface_loss, avg_bg_loss, accuracy
# Test the training epoch function with a tiny dataset
print("Testing training epoch function...")

# Create a tiny test dataset
# NOTE: this subset re-indexes its identities from 0, so its labels do not
# match the full training label space; acceptable for a smoke test only.
tiny_dataset = NewtDataset(df_train.head(16), dataset_path, transform=transform_train, return_mask=True)
tiny_loader = DataLoader(tiny_dataset, batch_size=4, shuffle=True, num_workers=0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Test training epoch
test_loss, test_arcface, test_bg, test_acc = train_epoch(
    model, tiny_loader, bsl_loss, optimizer, device, epoch=0
)

print(f"✅ Training epoch test passed!")
print(f" Loss: {test_loss:.4f} (ArcFace: {test_arcface:.4f}, BG: {test_bg:.4f})")
print(f" Accuracy: {test_acc:.2f}%")
# Test and define evaluation function
def evaluate(model, test_loader, bsl_loss, device):
    """Return top-1 retrieval accuracy (percent) of the embeddings.

    The test identities are disjoint from the training identities and are
    indexed independently, so classifying test images with the ArcFace
    class weights cannot produce a meaningful score. Instead every test
    image is matched to its nearest other test image by cosine similarity
    and counted as correct when both share a label. Images whose identity
    has no second image are excluded from the queries because they cannot
    be matched.

    Args:
        model: Embedding model.
        test_loader: Yields ``(images, labels, masks)`` or ``(images, labels)``.
        bsl_loss: Kept for interface compatibility; only switched to eval mode.
        device: Device to run the model on.

    Returns:
        Accuracy in percent, or ``0.0`` when there is nothing to match.
    """
    model.eval()
    bsl_loss.eval()

    embedding_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            # Masks, when present, are not needed for evaluation.
            images, labels = batch[0], batch[1]
            embeddings = model(images.to(device))
            embedding_chunks.append(F.normalize(embeddings.float(), p=2, dim=1).cpu())
            label_chunks.append(torch.as_tensor(labels).cpu())

    if not embedding_chunks:
        return 0.0

    all_embeddings = torch.cat(embedding_chunks)
    all_labels = torch.cat(label_chunks)
    if all_labels.numel() < 2:
        return 0.0

    # Cosine similarity of unit vectors; the diagonal is masked so an image
    # can never retrieve itself.
    similarity = all_embeddings @ all_embeddings.T
    similarity.fill_diagonal_(float('-inf'))
    nearest = similarity.argmax(dim=1)

    same_identity = all_labels.unsqueeze(0) == all_labels.unsqueeze(1)
    has_match = same_identity.sum(dim=1) > 1
    num_queries = int(has_match.sum().item())
    if num_queries == 0:
        return 0.0

    hits = (all_labels[nearest] == all_labels) & has_match
    accuracy = 100. * hits.sum().item() / num_queries
    return accuracy
# Test evaluation function
print("Testing evaluation function...")
tiny_test_dataset = NewtDataset(df_test.head(16), dataset_path, transform=transform_test, return_mask=True)
tiny_test_loader = DataLoader(tiny_test_dataset, batch_size=4, shuffle=False, num_workers=0)

eval_acc = evaluate(model, tiny_test_loader, bsl_loss, device)
print(f"✅ Evaluation test passed!")
print(f" Test accuracy: {eval_acc:.2f}%")
# Test occlusion sensitivity function with detailed checks
def run_occlusion_sensitivity_test(model, bsl_loss, dataset, device, epoch, save_dir):
    """Run occlusion sensitivity on pairs of different newts to test similarity.

    Two random pairs of images from different identities are compared; for
    each pair a 2x2 figure (both images, overlay, raw map) is written to
    ``save_dir/epoch_XXX/``. Failures on one pair are reported and do not
    stop the other pair.

    Args:
        model: Embedding model.
        bsl_loss: Loss module (only its train/eval mode is toggled).
        dataset: Dataset exposing ``labels_string`` and returning
            ``(image, label[, mask])`` items with normalised image tensors.
        device: Device to run on.
        epoch: Epoch number used in the output directory name.
        save_dir: Root directory for the saved figures.

    Returns:
        None. ``model`` and ``bsl_loss`` are left in training mode on every
        exit path, including the early return.
    """
    print(f"Starting occlusion sensitivity test for epoch {epoch}")
    model.eval()
    bsl_loss.eval()

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    def to_displayable(image):
        """Undo the ImageNet normalisation and return an HxWx3 array in [0, 1]."""
        array = image.cpu().numpy().transpose(1, 2, 0)
        return np.clip(array * imagenet_std + imagenet_mean, 0, 1)

    def similarity_model(image1, image2):
        """
        Compute cosine similarity between two images using the trained model
        Args:
            image1: First image tensor [1, C, H, W]
            image2: Second image tensor [1, C, H, W]
        Returns:
            Cosine similarity score as tensor
        """
        with torch.no_grad():
            emb1 = model(image1)
            emb2 = model(image2)
            return F.cosine_similarity(emb1, emb2, dim=1)

    try:
        # Create save directory
        epoch_dir = Path(save_dir) / f"epoch_{epoch:03d}"
        epoch_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created directory: {epoch_dir}")

        # Group sample indices by identity
        newt_indices_by_id = {}
        for idx, newt_id in enumerate(dataset.labels_string):
            newt_indices_by_id.setdefault(newt_id, []).append(idx)

        newt_ids = list(newt_indices_by_id)
        if len(newt_ids) < 2:
            print("Not enough different newts for similarity testing")
            return

        # Test 2 pairs
        for pair_idx in range(2):
            fig = None
            try:
                # Select two different newts
                newt_id1, newt_id2 = random.sample(newt_ids, 2)

                # Get one image from each newt
                idx1 = random.choice(newt_indices_by_id[newt_id1])
                idx2 = random.choice(newt_indices_by_id[newt_id2])

                print(f" Processing pair {pair_idx+1}: Newt {newt_id1} (idx {idx1}) vs Newt {newt_id2} (idx {idx2})")

                # Load each sample once; items are (image, label[, mask]).
                sample1 = dataset[idx1]
                sample2 = dataset[idx2]
                image1, label1 = sample1[0], sample1[1]
                image2, label2 = sample2[0], sample2[1]

                print(f" Image1: {image1.shape}, Label1: {label1}")
                print(f" Image2: {image2.shape}, Label2: {label2}")

                # Add batch dimension for similarity computation
                image1_tensor = image1.unsqueeze(0).to(device)  # [1, C, H, W]
                image2_tensor = image2.unsqueeze(0).to(device)  # [1, C, H, W]

                # Test baseline similarity
                baseline_similarity = similarity_model(image1_tensor, image2_tensor).item()
                print(f" Baseline similarity: {baseline_similarity:.4f}")

                # Run occlusion sensitivity
                print(f" Running occlusion sensitivity...")
                occlusion_map = my_occlusion_sensitivity(
                    similarity_model,
                    image1_tensor,
                    image2_tensor,
                    patch_size=16,
                    stride=8,
                    occlusion_value=0.5,
                    device=device,
                )
                print(f" Occlusion map shape: {occlusion_map.shape}, range [{occlusion_map.min():.3f}, {occlusion_map.max():.3f}]")

                # Convert images to numpy for visualization only
                img1_np = to_displayable(image1)
                img2_np = to_displayable(image2)

                # Save visualization
                save_path = epoch_dir / f"similarity_pair_{pair_idx+1}_newt_{newt_id1}_vs_{newt_id2}.png"
                fig, axes = plt.subplots(2, 2, figsize=(12, 10))

                # Original images
                axes[0, 0].imshow(img1_np)
                axes[0, 0].set_title(f'Image 1: Newt {newt_id1}\n(Index {idx1})')
                axes[0, 0].axis('off')

                axes[0, 1].imshow(img2_np)
                axes[0, 1].set_title(f'Image 2: Newt {newt_id2}\n(Index {idx2})')
                axes[0, 1].axis('off')

                # Occlusion sensitivity overlay
                axes[1, 0].imshow(img1_np)
                axes[1, 0].imshow(occlusion_map, cmap='hot', alpha=0.6)
                axes[1, 0].set_title(f'Occlusion Sensitivity on Image 1\n(Similarity: {baseline_similarity:.3f})')
                axes[1, 0].axis('off')

                # Pure occlusion map
                im = axes[1, 1].imshow(occlusion_map, cmap='hot')
                axes[1, 1].set_title('Occlusion Sensitivity Map')
                axes[1, 1].axis('off')
                plt.colorbar(im, ax=axes[1, 1])

                plt.tight_layout()
                plt.savefig(save_path, dpi=150, bbox_inches='tight')

                print(f" ✅ Saved similarity occlusion test to {save_path}")

            except Exception as e:
                # Best effort: report and continue with the next pair.
                print(f" ❌ Error in similarity occlusion test for pair {pair_idx+1}: {e}")
                import traceback
                traceback.print_exc()
            finally:
                # Close the figure even when plotting failed half-way.
                if fig is not None:
                    plt.close(fig)

        print(f"Completed occlusion sensitivity test for epoch {epoch}")
    finally:
        model.train()
        bsl_loss.train()
# Test occlusion sensitivity function
print("Testing occlusion sensitivity function...")
test_occlusion_dir = Path("data/test_occlusion")
test_occlusion_dir.mkdir(parents=True, exist_ok=True)

# epoch=-1 marks this as a pre-training dry run in the output directory name.
run_occlusion_sensitivity_test(
    model, bsl_loss, tiny_test_dataset, device, epoch=-1, save_dir=test_occlusion_dir
)
print("✅ Occlusion sensitivity test completed!")
# Set up full training configuration with verification
def setup_full_training():
    """Create everything needed for the full training run and verify it.

    Builds the model, loss, optimizer, scheduler and data loaders from the
    module-level ``train_dataset`` / ``test_dataset``, creates the output
    directories, and runs one gradient-free forward pass on a real batch.

    Returns:
        ``(model, bsl_loss, optimizer, scheduler, train_loader, test_loader,
        device, occlusion_dir)``.
    """
    print("Setting up full training configuration...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set seed for reproducibility
    set_seed(42)
    print("✅ Seed set for reproducibility")

    # Model
    num_classes = len(train_dataset.label_to_idx)
    model = MegaDescriptorWithBSL(num_classes).to(device)
    print(f"✅ Model created with {model.embedding_size} embedding size for {num_classes} classes")

    # Loss function with ArcFace + Background suppression
    bsl_loss = BackgroundSuppressionArcFaceLoss(
        num_classes=num_classes,
        embedding_size=model.embedding_size,
        margin=0.5,
        scale=64,
        alpha=1.0,  # ArcFace weight
        beta=0.5,  # Background suppression weight
    ).to(device)
    print(f"✅ BSL loss created (alpha={1.0}, beta={0.5})")

    # Optimizer for both model and ArcFace parameters
    params = itertools.chain(model.parameters(), bsl_loss.parameters())
    optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-4)
    print(f"✅ Optimizer created with lr=1e-4")

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    print("✅ Cosine annealing scheduler created")

    # Data loaders; pinned memory only helps host-to-GPU copies.
    use_pinned_memory = device.type == 'cuda'
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=use_pinned_memory,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned_memory,
    )
    print(f"✅ Data loaders created: {len(train_loader)} train batches, {len(test_loader)} test batches")

    # Create directories for saving
    os.makedirs("data", exist_ok=True)
    occlusion_dir = Path("data/occlusion_maps")
    occlusion_dir.mkdir(parents=True, exist_ok=True)
    print(f"✅ Directories created: data/, {occlusion_dir}")

    # Test one forward pass with real data
    print("Testing forward pass with real data...")
    with torch.no_grad():
        for batch in train_loader:
            images, labels, masks = batch
            images, labels, masks = images.to(device), labels.to(device), masks.to(device)

            embeddings = model(images)
            patch_features = model.get_patch_features()
            loss, arcface_loss, bg_loss = bsl_loss(embeddings, labels, masks, patch_features)

            print(f"✅ Forward pass successful:")
            print(f" Batch shape: {images.shape}")
            print(f" Embeddings: {embeddings.shape}")
            print(f" Patch features: {patch_features.shape if patch_features is not None else None}")
            print(f" Losses: total={loss.item():.4f}, arcface={arcface_loss.item():.4f}, bg={bg_loss.item():.4f}")
            break

    return model, bsl_loss, optimizer, scheduler, train_loader, test_loader, device, occlusion_dir
# Setup training with verification
model, bsl_loss, optimizer, scheduler, train_loader, test_loader, device, occlusion_dir = setup_full_training()

# Test one complete epoch to verify everything works
print("=" * 60)
print("TESTING ONE COMPLETE EPOCH")
print("=" * 60)

# Initialize best accuracy for testing
best_acc = 0

# Create a subset for testing
test_train_dataset = NewtDataset(df_train.head(64), dataset_path, transform=transform_train, return_mask=True)
test_train_loader = DataLoader(test_train_dataset, batch_size=8, shuffle=True, num_workers=0)

test_test_dataset = NewtDataset(df_test.head(32), dataset_path, transform=transform_test, return_mask=True)
test_test_loader = DataLoader(test_test_dataset, batch_size=8, shuffle=False, num_workers=0)

print("Running test training epoch...")
train_loss, arcface_loss, bg_loss, train_acc = train_epoch(
    model, test_train_loader, bsl_loss, optimizer, device, epoch=0
)

print("Running test evaluation...")
test_acc = evaluate(model, test_test_loader, bsl_loss, device)

print("Testing scheduler step...")
old_lr = optimizer.param_groups[0]['lr']
scheduler.step()
new_lr = optimizer.param_groups[0]['lr']

# Record the new best before reporting it; previously best_acc stayed at 0.
is_best = test_acc > best_acc
best_acc = max(best_acc, test_acc)

print(f"✅ Complete epoch test passed!")
print(f" Train Loss: {train_loss:.4f} (ArcFace: {arcface_loss:.4f}, BG: {bg_loss:.4f})")
print(f" Train Acc: {train_acc:.2f}%")
print(f" Test Acc: {test_acc:.2f}% {'🌟 BEST!' if is_best else ''}")
print(f" Best Acc: {best_acc:.2f}%")
print(f" LR: {old_lr:.2e} → {new_lr:.2e}")
# Test training visualization with actual results
def test_training_visualization_with_actual_results(train_loss, arcface_loss, bg_loss, train_acc, test_acc, lr,
                                                    save_path='data/actual_training_test_results.png'):
    """Plot the metrics of a single trial epoch on a 2x2 grid and save the figure.

    The four panels show: all three losses, train/test accuracy, the two loss
    components (ArcFace and background), and the learning rate. Every series
    holds exactly one point at epoch 0, so the x-axis is clamped to
    [-0.1, 0.1] and the numeric values are additionally written into the loss
    and accuracy panels as text boxes.

    Args:
        train_loss: Total training loss (number).
        arcface_loss: ArcFace component of the loss (number).
        bg_loss: Background component of the loss (number).
        train_acc: Training accuracy, printed with a percent sign.
        test_acc: Test accuracy, printed with a percent sign.
        lr: Learning rate to plot and report.
        save_path: Where the PNG is written; parent directories are created
            if missing. Defaults to the previously hard-coded location.

    Returns:
        None. The figure is saved, shown, and then closed.
    """
    print("Testing training visualization with actual results...")

    epochs = [0]  # Just one epoch for testing
    marker_size = 8

    def _finish_axis(ax, title, ylabel, with_legend=True):
        # Shared styling for all four panels (same call order as before).
        ax.set_title(title)
        if with_legend:
            ax.legend()
        ax.grid(True)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_xlim(-0.1, 0.1)

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    loss_ax, acc_ax = axes[0, 0], axes[0, 1]
    parts_ax, lr_ax = axes[1, 0], axes[1, 1]

    # Loss curves
    loss_ax.plot(epochs, [train_loss], label='Total Loss', marker='o', markersize=marker_size)
    loss_ax.plot(epochs, [arcface_loss], label='ArcFace Loss', marker='s', markersize=marker_size)
    loss_ax.plot(epochs, [bg_loss], label='Background Loss', marker='^', markersize=marker_size)
    _finish_axis(loss_ax, 'Training Losses (Actual Results)', 'Loss')

    # Accuracy curves
    acc_ax.plot(epochs, [train_acc], label='Train Acc', marker='o', markersize=marker_size)
    acc_ax.plot(epochs, [test_acc], label='Test Acc', marker='s', markersize=marker_size)
    _finish_axis(acc_ax, 'Accuracy (Actual Results)', 'Accuracy (%)')

    # Loss breakdown
    parts_ax.plot(epochs, [arcface_loss], label='ArcFace', marker='o', markersize=marker_size)
    parts_ax.plot(epochs, [bg_loss], label='Background Suppression', marker='s', markersize=marker_size)
    _finish_axis(parts_ax, 'Loss Components (Actual Results)', 'Loss')

    # Learning rate (single unlabeled series, so no legend)
    lr_ax.plot(epochs, [lr], marker='o', markersize=marker_size)
    _finish_axis(lr_ax, 'Learning Rate (Actual)', 'Learning Rate', with_legend=False)

    # Add actual values as text, anchored at the highest plotted point of
    # the respective panel.
    text_box = dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    max_loss = max(train_loss, arcface_loss, bg_loss)
    max_acc = max(train_acc, test_acc)
    loss_ax.text(0, max_loss, f'Total: {train_loss:.4f}\nArcFace: {arcface_loss:.4f}\nBG: {bg_loss:.4f}',
                 bbox=text_box)
    acc_ax.text(0, max_acc, f'Train: {train_acc:.2f}%\nTest: {test_acc:.2f}%',
                bbox=text_box)

    plt.tight_layout()
    # Make sure the target directory exists so savefig cannot fail on it.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # pyplot keeps figures alive until closed; release it explicitly.
    plt.close(fig)

    print("✅ Training visualization with actual results completed!")
    print(f" Results saved to: {save_path}")
    print(f" Train Loss: {train_loss:.4f} (ArcFace: {arcface_loss:.4f}, BG: {bg_loss:.4f})")
    print(f" Accuracies: Train {train_acc:.2f}%, Test {test_acc:.2f}%")
    print(f" Learning Rate: {lr:.2e}")
# Feed the metrics collected during the single trial epoch into the
# visualization helper.
print("\n" + "="*60)
print("TESTING VISUALIZATION WITH ACTUAL RESULTS")
print("="*60)

viz_inputs = dict(
    train_loss=train_loss,
    arcface_loss=arcface_loss,
    bg_loss=bg_loss,
    train_acc=train_acc,
    test_acc=test_acc,
    lr=new_lr,  # learning rate as it stands after the scheduler step
)
test_training_visualization_with_actual_results(**viz_inputs)
def _plot_training_history(train_history, epoch):
    """Draw the 2x2 training-progress figure and save it for the given epoch.

    Args:
        train_history: List of per-epoch dicts with the keys 'epoch',
            'train_loss', 'arcface_loss', 'bg_loss', 'train_acc', 'test_acc'
            and 'lr'.
        epoch: Epoch index used in the output file name.
    """
    def column(key):
        return [h[key] for h in train_history]

    epochs = column('epoch')
    arcface_losses = column('arcface_loss')
    bg_losses = column('bg_loss')

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Loss curves
    axes[0, 0].plot(epochs, column('train_loss'), label='Total Loss', marker='o')
    axes[0, 0].plot(epochs, arcface_losses, label='ArcFace Loss', marker='s')
    axes[0, 0].plot(epochs, bg_losses, label='Background Loss', marker='^')
    axes[0, 0].set_title('Training Losses')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Accuracy curves
    axes[0, 1].plot(epochs, column('train_acc'), label='Train Acc', marker='o')
    axes[0, 1].plot(epochs, column('test_acc'), label='Test Acc', marker='s')
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Loss breakdown
    axes[1, 0].plot(epochs, arcface_losses, label='ArcFace', marker='o')
    axes[1, 0].plot(epochs, bg_losses, label='Background Suppression', marker='s')
    axes[1, 0].set_title('Loss Components')
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Learning rate
    axes[1, 1].plot(epochs, column('lr'), marker='o')
    axes[1, 1].set_title('Learning Rate')
    axes[1, 1].set_yscale('log')
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(f'data/training_progress_epoch_{epoch}.png', dpi=150, bbox_inches='tight')
    plt.show()
    # pyplot keeps figures alive until closed; release it explicitly so
    # repeated plotting during long runs does not accumulate figures.
    plt.close(fig)


# Main training function with comprehensive testing
def train_newt_reid_with_bsl(num_epochs=5):
    """Run the full training loop and keep the best checkpoint.

    Operates on the module-level objects created earlier in the file
    (model, bsl_loss, optimizer, scheduler, train_loader, test_loader,
    train_dataset, test_dataset, device, occlusion_dir) and on the helpers
    train_epoch, evaluate and run_occlusion_sensitivity_test.

    Per epoch it trains, evaluates, steps the scheduler, records the metrics,
    runs the occlusion sensitivity test when ``epoch % 5 == 0``, and saves a
    checkpoint to 'data/best_newt_reid_bsl_model.pth' whenever test accuracy
    exceeds the best so far. Progress plots are written on every fifth epoch
    (excluding epoch 0) and always on the final epoch, so short runs such as
    the default five epochs still produce a plot.

    Args:
        num_epochs: Number of epochs to train. Defaults to 5, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(model, bsl_loss, train_history)`` where ``train_history`` is
        a list with one metrics dict per completed epoch.
    """
    print("🚀 STARTING COMPREHENSIVE TRAINING")
    print("=" * 70)
    best_acc = 0
    train_history = []

    # Initial sanity check: one forward pass on the first training batch
    print("Performing initial sanity checks...")
    with torch.no_grad():
        batch = next(iter(train_loader), None)
        if batch is not None:
            images, labels, masks = batch
            images, labels, masks = images.to(device), labels.to(device), masks.to(device)
            embeddings = model(images)
            patch_features = model.get_patch_features()
            loss, arcface_loss, bg_loss = bsl_loss(embeddings, labels, masks, patch_features)
            print(f"✅ Initial forward pass: Loss={loss.item():.4f}")

    print("Starting training loop...")
    for epoch in range(num_epochs):
        print(f"\n{'='*20} EPOCH {epoch+1}/{num_epochs} {'='*20}")

        # Training
        print("Training...")
        train_loss, arcface_loss, bg_loss, train_acc = train_epoch(
            model, train_loader, bsl_loss, optimizer, device, epoch
        )

        # Evaluation
        print("Evaluating...")
        test_acc = evaluate(model, test_loader, bsl_loss, device)

        # Learning rate scheduling
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Save training history
        epoch_data = {
            'epoch': epoch,
            'train_loss': train_loss,
            'arcface_loss': arcface_loss,
            'bg_loss': bg_loss,
            'train_acc': train_acc,
            'test_acc': test_acc,
            'lr': current_lr,
        }
        train_history.append(epoch_data)

        # Occlusion sensitivity testing every 5 epochs
        if epoch % 5 == 0:
            print(f"Running occlusion sensitivity test...")
            run_occlusion_sensitivity_test(model, bsl_loss, test_dataset, device, epoch, occlusion_dir)

        # Save best model
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss_state_dict': bsl_loss.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': best_acc,
                'embedding_size': model.embedding_size,
                'num_classes': len(train_dataset.label_to_idx),
                'train_history': train_history,
            }
            torch.save(checkpoint, 'data/best_newt_reid_bsl_model.pth')
            print(f"💾 NEW BEST MODEL SAVED! Accuracy: {best_acc:.2f}%")

        # Print epoch summary
        print(f"\n📊 EPOCH {epoch} SUMMARY:")
        print(f" Train Loss: {train_loss:.4f} (ArcFace: {arcface_loss:.4f}, BG: {bg_loss:.4f})")
        print(f" Train Acc: {train_acc:.2f}%")
        print(f" Test Acc: {test_acc:.2f}% {'🌟 BEST!' if is_best else ''}")
        print(f" Best Acc: {best_acc:.2f}%")
        print(f" LR: {old_lr:.2e} → {current_lr:.2e}")

        # Visualization every 5 epochs, plus the final epoch. Without the
        # final-epoch case the periodic condition is never met when
        # num_epochs <= 5 and no plot would ever be written.
        is_last_epoch = epoch == num_epochs - 1
        if (epoch > 0 and epoch % 5 == 0) or is_last_epoch:
            print("Creating training visualizations...")
            _plot_training_history(train_history, epoch)

    print(f"\n🎉 TRAINING COMPLETED!")
    print(f"📈 Best test accuracy: {best_acc:.2f}%")
    print(f"💾 Model saved to: data/best_newt_reid_bsl_model.pth")
    print(f"🔍 Occlusion maps saved to: {occlusion_dir}")
    return model, bsl_loss, train_history
print("✅ All tests passed! Ready to start main training...")

# Final pre-training verification
print("FINAL PRE-TRAINING VERIFICATION")
print("=" * 50)

# Check GPU memory if available
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # mem_get_info reports the device's actually free memory; memory_reserved
    # (used previously) is what the caching allocator holds, not what is free.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    print(f"GPU Memory Free: {free_bytes / 1e9:.1f} GB")

# Check dataset sizes
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {len(train_dataset.label_to_idx)}")
print(f"Batches per epoch: {len(train_loader)}")

# Final forward pass test on a single training batch, without gradients
print("Final forward pass test...")
with torch.no_grad():
    test_batch = next(iter(train_loader))
    images, labels, masks = test_batch
    images, labels, masks = images.to(device), labels.to(device), masks.to(device)

    embeddings = model(images)
    patch_features = model.get_patch_features()
    loss, arcface_loss, bg_loss = bsl_loss(embeddings, labels, masks, patch_features)

print(f"✅ Final test successful!")
print(f" Batch size: {images.shape[0]}")
print(f" Total loss: {loss.item():.4f}")
print(f" Memory usage: {torch.cuda.memory_allocated(0) / 1e6:.1f} MB" if torch.cuda.is_available() else "CPU mode")

print("\n🚀 READY TO START TRAINING!")

# Start the actual training!
trained_model, trained_bsl_loss, history = train_newt_reid_with_bsl()