首頁 > 軟體

Python Pytorch學習之影象檢索實踐

2022-04-08 19:01:07

隨著電子商務和線上網站的出現,影象檢索在我們的日常生活中的應用一直在增加。

亞馬遜、阿里巴巴、Myntra等公司一直在大量利用影象檢索技術。當然,只有當通常的資訊檢索技術失敗時,影象檢索才會開始工作。

背景

影象檢索的基本本質是根據查詢影象的特徵從集合或資料庫中查詢影象。

大多數情況下,這種特徵是影象之間簡單的視覺相似性。在一個複雜的問題中,這種特徵可能是兩幅影象在風格上的相似性,甚至是互補性。

由於原始形式的影象不會在基於畫素的資料中反映這些特徵,因此我們需要將這些畫素資料轉換為一個潛空間,在該空間中,影象的表示將反映這些特徵。

一般來說,在潛空間中,任何兩個相似的影象都會相互靠近,而不同的影象則會相隔很遠。這是我們用來訓練我們的模型的基本管理規則。一旦我們這樣做,檢索部分只需搜尋潛在空間,在給定查詢影象表示的潛在空間中拾取最近的影象。大多數情況下,它是在最近鄰搜尋的幫助下完成的。

因此,我們可以將我們的方法分為兩部分:

  • 影象表現
  • 搜尋

我們將在Oxford 102 Flowers資料集上解決這兩個部分。

影象表現

我們將使用一種叫做暹羅模型的東西,它本身並不是一種全新的模型,而是一種訓練模型的技術。大多數情況下,這是與triplet loss一起使用的。這個技術的基本組成部分是三元組。

三元組是3個獨立的資料樣本,比如A(錨點),B(陽性)和C(陰性);其中A和B相似或具有相似的特徵(可能是同一類),而C與A和B都不相似。這三個樣本共同構成了訓練資料的一個單元——三元組。

注:任何影象檢索任務的90%都體現在暹羅網路、triplet loss和三元組的建立中。如果你成功地完成了這些,那麼整個努力的成功或多或少是有保證的。

首先,我們將建立管道的這個元件——資料。下面我們將在PyTorch中建立一個自定義資料集和資料載入器,它將從資料集中生成三元組。

class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
 
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
 
        
    def __getitem__(self, idx):
 
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
 
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
 
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
 
        im3 = random.choice(negatives)
 
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
 
        im1 = self.transforms(Image.open(im1))
 
        im2 = self.transforms(Image.open(im2))
 
        im3 = self.transforms(Image.open(im3))
 
        return [im1, im2, im3]
 
    
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

現在我們有了資料,讓我們轉到暹羅網路。

暹羅網路給人的印象是2個或3個模型,但是它本身是一個單一的模型。所有這些模型共用權重,即只有一個模型。

如前所述,將整個體系結構結合在一起的關鍵因素是triplet loss。triplet loss產生了一個目標函數,該函數迫使相似輸入對(錨點和正)之間的距離小於不同輸入對(錨點和負)之間的距離,並限定一定的閾值。

下面我們來看看triplet loss以及訓練管道實現。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    
    for data in tqdm(train_loader):
        
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))
 
    
    
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
 
# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
 
        optimizer.zero_grad()
        
        x1,x2,x3 = data
        
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))

到目前為止,我們的模型已經經過訓練,可以將影象轉換為一個嵌入空間。接下來,我們進入搜尋部分。

搜尋

我們可以很容易地使用Scikit Learn提供的最近鄰搜尋。我們將探索新的更好的東西,而不是走簡單的路線。

我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的影象,這種速度上的差異會變得更加明顯。

下面我們將演示如何在給定查詢影象時,在儲存的影象表示中搜尋最近的影象。

#!pip install faiss-gpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index
 
# storing the image representations
im_indices = []
 
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
 
        
# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        
        # query/test image
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        test_embed = model(im).cpu().numpy()
        
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

這涵蓋了基於現代深度學習的影象檢索,但不會使其變得太複雜。大多數檢索問題都可以通過這個基本管道解決。

以上就是Python Pytorch學習之影象檢索實踐的詳細內容,更多關於Python Pytorch影象檢索的資料請關注it145.com其它相關文章!


IT145.com E-mail:sddin#qq.com