Pytorch圖像檢索實踐
隨著電子商務和在線網(wǎng)站的出現(xiàn),圖像檢索在我們的日常生活中的應用一直在增加。
亞馬遜、阿里巴巴、Myntra等公司一直在大量利用圖像檢索技術。當然,只有當通常的信息檢索技術失敗時,圖像檢索才會開始工作。
背景
圖像檢索的基本本質是根據(jù)查詢圖像的特征從集合或數(shù)據(jù)庫中查找圖像。
大多數(shù)情況下,這種特征是圖像之間簡單的視覺相似性。在一個復雜的問題中,這種特征可能是兩幅圖像在風格上的相似性,甚至是互補性。
由于原始形式的圖像不會在基于像素的數(shù)據(jù)中反映這些特征,因此我們需要將這些像素數(shù)據(jù)轉換為一個潛空間,在該空間中,圖像的表示將反映這些特征。
一般來說,在潛空間中,任何兩個相似的圖像都會相互靠近,而不同的圖像則會相隔很遠。這是我們用來訓練我們的模型的基本管理規(guī)則。一旦我們這樣做,檢索部分只需搜索潛在空間,在給定查詢圖像表示的潛在空間中拾取最近的圖像。大多數(shù)情況下,它是在最近鄰搜索的幫助下完成的。
因此,我們可以將我們的方法分為兩部分:
1. 圖像表現(xiàn)
2. 搜索
我們將在Oxford 102 Flowers數(shù)據(jù)集上解決這兩個部分。
圖像表現(xiàn)
我們將使用一種叫做暹羅模型的東西,它本身并不是一種全新的模型,而是一種訓練模型的技術。大多數(shù)情況下,這是與triplet loss一起使用的。這個技術的基本組成部分是三元組。
三元組是3個獨立的數(shù)據(jù)樣本,比如A(錨點),B(陽性)和C(陰性);其中A和B相似或具有相似的特征(可能是同一類),而C與A和B都不相似。這三個樣本共同構成了訓練數(shù)據(jù)的一個單元——三元組。
注:任何圖像檢索任務的90%都體現(xiàn)在暹羅網(wǎng)絡、triplet loss和三元組的創(chuàng)建中。如果你成功地完成了這些,那么整個努力的成功或多或少是有保證的。
首先,我們將創(chuàng)建管道的這個組件——數(shù)據(jù)。下面我們將在PyTorch中創(chuàng)建一個自定義數(shù)據(jù)集和數(shù)據(jù)加載器,它將從數(shù)據(jù)集中生成三元組。
class TripletData(Dataset):
def __init__(self, path, transforms, split="train"):
self.path = path
self.split = split # train or valid
self.cats = 102 # number of categories
self.transforms = transforms
def __getitem__(self, idx):
# our positive class for the triplet
idx = str(idx%self.cats + 1)
# choosing our pair of positive images (im1, im2)
positives = os.listdir(os.path.join(self.path, idx))
im1, im2 = random.sample(positives, 2)
# choosing a negative class and negative image (im3)
negative_cats = [str(x+1) for x in range(self.cats)]
negative_cats.remove(idx)
negative_cat = str(random.choice(negative_cats))
negatives = os.listdir(os.path.join(self.path, negative_cat))
im3 = random.choice(negatives)
im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
im1 = self.transforms(Image.open(im1))
im2 = self.transforms(Image.open(im2))
im3 = self.transforms(Image.open(im3))
return [im1, im2, im3]
# we'll put some value that we want since there can be far too many triplets possible
# multiples of the number of images/ number of categories is a good choice
def __len__(self):
return self.cats*8
# Transforms
train_transforms = transforms.Compose([
transforms.Resize((224,224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)
現(xiàn)在我們有了數(shù)據(jù),讓我們轉到暹羅網(wǎng)絡。
暹羅網(wǎng)絡給人的印象是2個或3個模型,但是它本身是一個單一的模型。所有這些模型共享權重,即只有一個模型。
如前所述,將整個體系結構結合在一起的關鍵因素是triplet loss。triplet loss產生了一個目標函數(shù),該函數(shù)迫使相似輸入對(錨點和正)之間的距離小于不同輸入對(錨點和負)之間的距離,并限定一定的閾值。
下面我們來看看triplet loss以及訓練管道實現(xiàn)。
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
def calc_euclidean(self, x1, x2):
return (x1 - x2).pow(2).sum(1)
# Distances in embedding space is calculated in euclidean
def forward(self, anchor, positive, negative):
distance_positive = self.calc_euclidean(anchor, positive)
distance_negative = self.calc_euclidean(anchor, negative)
losses = torch.relu(distance_positive - distance_negative + self.margin)
return losses.mean()
device = 'cuda'
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
# Training
for epoch in range(epochs):
model.train()
epoch_loss = 0.0
for data in tqdm(train_loader):
optimizer.zero_grad()
x1,x2,x3 = data
e1 = model(x1.to(device))
e2 = model(x2.to(device))
e3 = model(x3.to(device))
loss = triplet_loss(e1,e2,e3)
epoch_loss += loss
loss.backward()
optimizer.step()
print("Train Loss: {}".format(epoch_loss.item()))
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
def calc_euclidean(self, x1, x2):
return (x1 - x2).pow(2).sum(1)
# Distances in embedding space is calculated in euclidean
def forward(self, anchor, positive, negative):
distance_positive = self.calc_euclidean(anchor, positive)
distance_negative = self.calc_euclidean(anchor, negative)
losses = torch.relu(distance_positive - distance_negative + self.margin)
return losses.mean()
device = 'cuda'
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
# Training
for epoch in range(epochs):
model.train()
epoch_loss = 0.0
for data in tqdm(train_loader):
optimizer.zero_grad()
x1,x2,x3 = data
e1 = model(x1.to(device))
e2 = model(x2.to(device))
e3 = model(x3.to(device))
loss = triplet_loss(e1,e2,e3)
epoch_loss += loss
loss.backward()
optimizer.step()
print("Train Loss: {}".format(epoch_loss.item()))
到目前為止,我們的模型已經(jīng)經(jīng)過訓練,可以將圖像轉換為一個嵌入空間。接下來,我們進入搜索部分。
搜索
我們可以很容易地使用Scikit Learn提供的最近鄰搜索。我們將探索新的更好的東西,而不是走簡單的路線。
我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的圖像,這種速度上的差異會變得更加明顯。
下面我們將演示如何在給定查詢圖像時,在存儲的圖像表示中搜索最近的圖像。
#!pip install faiss-gpu
import faiss
faiss_index = faiss.IndexFlatL2(1000) # build the index
# storing the image representations
im_indices = []
with torch.no_grad():
for f in glob.glob(os.path.join(PATH_TRAIN, '*')):
im = Image.open(f)
im = im.resize((224,224))
im = torch.tensor([val_transforms(im).numpy()]).cuda()
preds = model(im)
preds = np.array([preds[0].cpu().numpy()])
faiss_index.add(preds) #add the representation to index
im_indices.append(f) #store the image name to find it later on
# Retrieval with a query image
with torch.no_grad():
for f in os.listdir(PATH_TEST):
# query/test image
im = Image.open(os.path.join(PATH_TEST,f))
im = im.resize((224,224))
im = torch.tensor([val_transforms(im).numpy()]).cuda()
test_embed = model(im).cpu().numpy()
_, I = faiss_index.search(test_embed, 5)
print("Retrieved Image: {}".format(im_indices[I[0][0]]))
這涵蓋了基于現(xiàn)代深度學習的圖像檢索,但不會使其變得太復雜。大多數(shù)檢索問題都可以通過這個基本管道解決。
原文標題 : Pytorch圖像檢索實踐
請輸入評論內容...
請輸入評論/評論長度6~500個字
最新活動更多
-
即日-11.13立即報名>>> 【在線會議】多物理場仿真助跑新能源汽車
-
11月28日立即報名>>> 2024工程師系列—工業(yè)電子技術在線會議
-
12月19日立即報名>> 【線下會議】OFweek 2024(第九屆)物聯(lián)網(wǎng)產業(yè)大會
-
即日-12.26火熱報名中>> OFweek2024中國智造CIO在線峰會
-
即日-2025.8.1立即下載>> 《2024智能制造產業(yè)高端化、智能化、綠色化發(fā)展藍皮書》
-
精彩回顧立即查看>> 【限時免費下載】TE暖通空調系統(tǒng)高效可靠的組件解決方案
推薦專題
- 高級軟件工程師 廣東省/深圳市
- 自動化高級工程師 廣東省/深圳市
- 光器件研發(fā)工程師 福建省/福州市
- 銷售總監(jiān)(光器件) 北京市/海淀區(qū)
- 激光器高級銷售經(jīng)理 上海市/虹口區(qū)
- 光器件物理工程師 北京市/海淀區(qū)
- 激光研發(fā)工程師 北京市/昌平區(qū)
- 技術專家 廣東省/江門市
- 封裝工程師 北京市/海淀區(qū)
- 結構工程師 廣東省/深圳市