import os
import numpy as np
import sys
import torch
from gcn_reid.newt import get_newt_dataset, get_cropping_image_dataset
from wildlife_datasets import splits
import timm
import itertools
from torch.optim import SGD
from wildlife_tools.train import ArcFaceLoss, BasicTrainer
from wildlife_tools.train import set_seed
from torchvision import transforms as T
from wildlife_tools.features import DeepFeatures
from wildlife_tools.inference import TopkClassifier
from wildlife_tools.similarity import CosineSimilarity
from transformers import AutoModel
Finetuning
This notebook finetunes the MegaDescriptor and MiewID models on our newt dataset to improve their re-identification accuracy.
Download Newt Dataset
NewtDataset = get_newt_dataset()
dataset_path = "data/newt_dataset"
NewtDataset._download(dataset_name="mshahoyi/barhill-newts-segmented", download_path=dataset_path)
dataset = NewtDataset(dataset_path)
dataset.df.head()
Create Data Splits
Create Train/Test Split
def create_train_test_split(df, split_ratio=0.5):
    disjoint_splitter = splits.DisjointSetSplit(split_ratio)
    for idx_train, idx_test in disjoint_splitter.split(df):
        df_train, df_test = df.loc[idx_train], df.loc[idx_test]
    splits.analyze_split(df, idx_train, idx_test)
    return df_train, df_test
df_train, df_test = create_train_test_split(dataset.df, split_ratio=0.5)
df_test, df_val = create_train_test_split(df_test, split_ratio=0.5)
print(f"Train: {len(df_train)}, Test: {len(df_test)}, Validation: {len(df_val)}")
Closed Set Split (for database and query sets)
def create_database_query_split(df, split_ratio=0.9):
    splitter = splits.ClosedSetSplit(split_ratio)
    for idx_database, idx_query in splitter.split(df):
        df_database, df_query = df.loc[idx_database], df.loc[idx_query]
    splits.analyze_split(df, idx_database, idx_query)
    return df_database, df_query
df_test_database, df_test_query = create_database_query_split(df_test, split_ratio=0.9)
print(f"Test Database: {len(df_test_database)}, Test Query: {len(df_test_query)}\n\n\n")

df_val_database, df_val_query = create_database_query_split(df_val, split_ratio=0.9)
print(f"Validation Database: {len(df_val_database)}, Validation Query: {len(df_val_query)}\n\n\n")

df_train_database, df_train_query = create_database_query_split(df_train, split_ratio=0.9)
print(f"Train Database: {len(df_train_database)}, Train Query: {len(df_train_query)}\n\n\n")
Train MegaDescriptor
# Download MegaDescriptor-T backbone from HuggingFace Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True).to(device)
CroppingImageDataset = get_cropping_image_dataset()

transform = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

train_dataset = CroppingImageDataset(df_train, root=dataset_path, transform=transform, crop_out=True)
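One sample can be pulled to confirm the transform output; this is a sketch that assumes CroppingImageDataset returns (image, label) pairs like the wildlife_tools image datasets.

# Sketch: inspect a single training sample (assumes an (image, label) return)
img, label = train_dataset[0]
print(img.shape, label)  # expected: torch.Size([3, 224, 224]) plus a class index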
def evaluate(model, df_query, df_database, data_root, transform, crop_out, batch_size, num_workers, device):
    # Calculate retrieval results
    extractor = DeepFeatures(model,
                             device=device,
                             batch_size=batch_size,
                             num_workers=num_workers)

    print("Extracting features for query set")
    dataset_query = CroppingImageDataset(df_query, root=data_root, transform=transform, crop_out=crop_out)
    query = extractor(dataset_query)

    print("Extracting features for database set")
    dataset_database = CroppingImageDataset(df_database, root=data_root, transform=transform, crop_out=crop_out)
    database = extractor(dataset_database)

    # Rank database images by cosine similarity and keep the 5 nearest labels per query
    similarity_function = CosineSimilarity()
    similarity = similarity_function(query, database)
    top_5_classifier = TopkClassifier(k=5, database_labels=dataset_database.labels_string, return_all=True)
    predictions_top_5, scores_top_5, _ = top_5_classifier(similarity)

    # Top-1: the nearest neighbour matches; top-5: the true label appears anywhere in the 5 predictions
    accuracy_top_1 = np.mean(dataset_query.labels_string == predictions_top_5[:, 0])
    accuracy_top_5 = np.mean(np.any(predictions_top_5 == dataset_query.labels_string[:, np.newaxis], axis=1))

    return dict(accuracy_top_1=accuracy_top_1, accuracy_top_5=accuracy_top_5)
val_transform = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
evaluate(model=backbone,
         df_query=df_val_query,
         df_database=df_val_database,
         data_root=dataset_path,
         transform=val_transform,
         crop_out=True,
         batch_size=32, num_workers=4, device=device)
class EvaluationCb:
    def __init__(self, df_query, df_database, root, transform, crop_out):
        self.df_query = df_query
        self.df_database = df_database
        self.root = root
        self.transform = transform
        self.crop_out = crop_out
        self.history = []

    def __call__(self, trainer, epoch_data):
        eval_results = evaluate(model=trainer.model,
                                df_query=self.df_query,
                                df_database=self.df_database,
                                data_root=self.root,
                                transform=self.transform,
                                crop_out=self.crop_out,
                                batch_size=trainer.batch_size,
                                num_workers=trainer.num_workers,
                                device=trainer.device)

        epoch_data.update(eval_results)
        self.history.append(epoch_data)
        print(f"Accuracy top 1: {eval_results['accuracy_top_1']}, Accuracy top 5: {eval_results['accuracy_top_5']}\n\n")
        # Make sure the model is back on the training device before the next epoch
        trainer.model = trainer.model.to(trainer.device)
epoch_callback = EvaluationCb(df_val_query,
                              df_val_database,
                              dataset_path,
                              val_transform,
                              crop_out=True)
# Arcface loss - needs backbone output size and number of classes.
objective = ArcFaceLoss(
    num_classes=train_dataset.num_classes,
    embedding_size=768,
    margin=0.5,
    scale=64
)
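For intuition, ArcFace adds an angular margin to the target-class logit before scaling, which forces embeddings of the same identity to cluster more tightly. Below is a minimal numpy illustration of how margin=0.5 and scale=64 act on the cosine logits; it is not the wildlife_tools implementation.

# Minimal ArcFace illustration (not the library implementation)
rng = np.random.default_rng(0)
emb = rng.normal(size=768); emb /= np.linalg.norm(emb)                          # L2-normalised embedding
W = rng.normal(size=(5, 768)); W /= np.linalg.norm(W, axis=1, keepdims=True)    # class centres
cos_theta = W @ emb                                                             # cosine similarity per class
target, margin, scale = 2, 0.5, 64.0
theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
logits = cos_theta.copy()
logits[target] = np.cos(theta[target] + margin)                                 # penalise the target angle
logits *= scale                                                                 # sharpen before softmax / cross-entropy
print(logits)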
# Optimize parameters in backbone and in objective using single optimizer.
params = itertools.chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)
set_seed(0)

trainer = BasicTrainer(
    dataset=train_dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=5,
    device=device,
    num_workers=4,
    epoch_callback=epoch_callback
)

trainer.train()
epoch_callback.history
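The callback history can be summarised per epoch; each entry carries the accuracy keys added by EvaluationCb above.

# Per-epoch validation accuracy from the callback history
for i, entry in enumerate(epoch_callback.history):
    print(f"epoch {i + 1}: top-1 {entry['accuracy_top_1']:.3f}, top-5 {entry['accuracy_top_5']:.3f}")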
MegaDescriptor on Test Set
mega_descriptor_results = evaluate(model=backbone,
                                   df_query=df_test_query,
                                   df_database=df_test_database,
                                   data_root=dataset_path,
                                   transform=val_transform,
                                   crop_out=True,
                                   batch_size=32,
                                   num_workers=4,
                                   device=device)
print("MegaDescriptor Results:", mega_descriptor_results)
Train MiewID
miew_id_model = AutoModel.from_pretrained("conservationxlabs/miewid-msv2", trust_remote_code=True).to(device)
miew_id_results = evaluate(model=miew_id_model,
                           df_query=df_val_query,
                           df_database=df_val_database,
                           data_root=dataset_path,
                           transform=val_transform,
                           crop_out=True,
                           batch_size=32, num_workers=4, device=device)
print("MiewID Results before finetuning:", miew_id_results)
miew_id_epoch_callback = EvaluationCb(df_val_query,
                                      df_val_database,
                                      dataset_path,
                                      val_transform,
                                      crop_out=True)
# The MiewID head needs its own ArcFace objective sized to the MiewID embedding.
# Assumption: the model returns a (batch, embedding_dim) feature tensor (as relied on
# by the evaluate() call above), so the dimension is inferred from a dummy forward pass.
with torch.no_grad():
    miew_id_embedding_size = miew_id_model(torch.randn(2, 3, 224, 224).to(device)).shape[1]

miew_id_objective = ArcFaceLoss(
    num_classes=train_dataset.num_classes,
    embedding_size=miew_id_embedding_size,
    margin=0.5,
    scale=64
)

# Optimize parameters in the MiewID model and in the objective using single optimizer.
params = itertools.chain(miew_id_model.parameters(), miew_id_objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)

set_seed(0)

trainer = BasicTrainer(
    dataset=train_dataset,
    model=miew_id_model,
    objective=miew_id_objective,
    optimizer=optimizer,
    epochs=5,
    device=device,
    num_workers=4,
    epoch_callback=miew_id_epoch_callback
)

trainer.train()
miew_id_epoch_callback.history
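For parity with MegaDescriptor, the finetuned MiewID model can be evaluated on the held-out test split in the same way; a sketch mirroring the test-set cell above.

miew_id_test_results = evaluate(model=miew_id_model,
                                df_query=df_test_query,
                                df_database=df_test_database,
                                data_root=dataset_path,
                                transform=val_transform,
                                crop_out=True,
                                batch_size=32,
                                num_workers=4,
                                device=device)

print("MiewID Results after finetuning:", miew_id_test_results)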
import nbdev; nbdev.nbdev_export()