Case Analysis of Python Pytorch Image Retrieval

2025-03-28 Update From: SLTechnology News&Howtos

Shulou(Shulou.com)06/01 Report--

This article presents an image retrieval case study in Python with PyTorch. Many people have questions about how image retrieval works in practice, so the steps below are collected into a simple, easy-to-follow method. Let's get started!

Background

At its core, image retrieval means finding images in a collection or database based on the characteristics of a query image.

In most cases, this characteristic is simple visual similarity between images. In harder problems, it may be stylistic similarity, or even complementarity, between two images.

Because raw pixel data does not directly reflect these characteristics, we need to map the pixels into a latent space in which an image's representation does reflect them.

Generally speaking, in the latent space any two similar images lie close to each other, while dissimilar images lie far apart. This is the basic principle we use to train our model. Once the model is trained, retrieval only requires searching the latent space for the images closest to the query image's representation. In most cases, this is done with nearest neighbor search.
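As a minimal sketch of that search step, assume the image embeddings are already stored as rows of a NumPy array. The names `gallery`, `query`, and `nearest_neighbors` are illustrative and do not come from the original. Brute-force nearest neighbor search simply ranks every stored vector by its Euclidean distance to the query:

```python
import numpy as np


def nearest_neighbors(gallery: np.ndarray, query: np.ndarray, k: int = 1) -> np.ndarray:
    """Return indices of the k gallery rows closest to `query` (Euclidean distance)."""
    # Squared L2 distance from the query to every stored embedding.
    dists = ((gallery - query) ** 2).sum(axis=1)
    # argsort ranks by distance; the first k entries are the nearest images.
    return np.argsort(dists)[:k]


# Toy 2-D "latent space": two similar images near (0, 0) and one far away.
gallery = np.array([[0.0, 0.1], [0.2, 0.0], [5.0, 5.0]])
query = np.array([0.1, 0.1])
print(nearest_neighbors(gallery, query, k=2))  # prints [0 1]
```

Squared distances are enough for ranking, because squaring preserves the order of non-negative distances.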

Therefore, we can divide our approach into two parts:

Image representation

Search

We will address these two parts on the Oxford 102 Flowers dataset.

Image representation

We will use what is called a Siamese network. It is not a new kind of model so much as a technique for training models. In most cases it is used together with triplet loss, and its basic building block is the triplet.

A triplet consists of three independent data samples: A (the anchor), B (the positive), and C (the negative). A and B are similar or share similar characteristics (possibly the same class), while C is dissimilar to both A and B. Together, these three samples form one unit of training data: a triplet.
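A minimal, framework-free sketch of drawing one triplet, assuming the data is organized as a mapping from class label to a list of items. The function name `sample_triplet` and the example file names are hypothetical:

```python
import random


def sample_triplet(items_by_class, rng=random):
    """Draw (anchor, positive, negative) from a {class: [items]} mapping."""
    # Anchor and positive come from the same class, which needs at least 2 items.
    pos_class = rng.choice([c for c, items in items_by_class.items() if len(items) >= 2])
    anchor, positive = rng.sample(items_by_class[pos_class], 2)
    # The negative comes from any other class.
    neg_class = rng.choice([c for c in items_by_class if c != pos_class])
    negative = rng.choice(items_by_class[neg_class])
    return anchor, positive, negative


data = {"rose": ["rose_1.jpg", "rose_2.jpg"], "tulip": ["tulip_1.jpg", "tulip_2.jpg"]}
print(sample_triplet(data, random.Random(0)))
```

Passing a seeded `random.Random` makes the draw reproducible. The default, the `random` module itself, gives a fresh triplet on every call.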

Note: about 90% of any image retrieval task lies in building the Siamese network, the triplet loss, and the triplets. If you get these right, the success of the whole effort is more or less guaranteed.

First, we will create the data component of the pipeline: a custom PyTorch dataset and data loader that generate triplets from the dataset.

import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# PATH_TRAIN and PATH_VALID must be defined beforehand. The code assumes each is laid out
# as <root>/<class_id>/<image>.jpg, with class folders named "1" to "102".


class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
        self.path = path
        self.split = split  # train or valid
        self.cats = 102  # number of categories
        self.transforms = transforms

    def __getitem__(self, idx):
        # our positive class for the triplet
        idx = str(idx % self.cats + 1)

        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)

        # choosing a negative class and negative image (im3)
        negative_cats = [str(x + 1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = random.choice(negative_cats)
        negatives = os.listdir(os.path.join(self.path, negative_cat))
        im3 = random.choice(negatives)

        im1 = os.path.join(self.path, idx, im1)
        im2 = os.path.join(self.path, idx, im2)
        im3 = os.path.join(self.path, negative_cat, im3)

        # convert("RGB") guards against grayscale or CMYK files
        im1 = self.transforms(Image.open(im1).convert("RGB"))
        im2 = self.transforms(Image.open(im2).convert("RGB"))
        im3 = self.transforms(Image.open(im3).convert("RGB"))

        return [im1, im2, im3]

    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images / number of categories is a good choice
    def __len__(self):
        return self.cats * 8


# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=32, shuffle=False, num_workers=4)
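`TripletData.__getitem__` uses the incoming index only to pick the positive class. The expression `idx % self.cats + 1` cycles through the folder names "1" to "102". Because `__len__` returns `self.cats * 8`, one pass over the loader should use each class as the anchor class exactly 8 times. A quick pure-Python check of that arithmetic follows. The helper name `anchor_class` is illustrative:

```python
from collections import Counter

CATS = 102  # number of flower categories, as in TripletData


def anchor_class(idx: int, cats: int = CATS) -> str:
    # Same expression TripletData.__getitem__ uses to choose the positive folder.
    return str(idx % cats + 1)


# One epoch visits len(dataset) == cats * 8 indices (shuffling changes only the order).
counts = Counter(anchor_class(i) for i in range(CATS * 8))
print(len(counts), set(counts.values()))  # prints 102 {8}
```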

Now that we have the data, let's go to the Siamese network.

A Siamese network looks like two or three models, but it is really a single model: all the branches share the same weights.
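To make the weight sharing concrete, here is a toy NumPy sketch. A single linear projection `W`, a hypothetical stand-in for the ResNet used later, embeds the anchor, the positive, and the negative alike. There are not three networks here, only three calls to one function with one set of parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))  # the only parameters: one 4-d -> 2-d projection


def embed(x: np.ndarray) -> np.ndarray:
    # Every "branch" of the Siamese network calls this same function with the same W.
    return x @ W


anchor, positive, negative = rng.normal(size=(3, 4))
e_a, e_p, e_n = embed(anchor), embed(positive), embed(negative)

# Embedding the three inputs in one batched call gives identical results,
# which is what "shared weights" means in practice.
batched = embed(np.stack([anchor, positive, negative]))
print(np.allclose(batched, np.stack([e_a, e_p, e_n])))  # prints True
```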

As mentioned earlier, the key factor that ties the whole architecture together is triplet loss. Triplet loss defines an objective that forces the distance between similar pairs (anchor and positive) to be smaller than the distance between dissimilar pairs (anchor and negative) by at least a fixed margin.
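With squared Euclidean distance d, the loss for one triplet is max(0, d(a, p) - d(a, n) + margin), averaged over the batch. The `TripletLoss` class in the training code computes this same quantity. Below is a NumPy version with numbers you can check by hand. The function name is illustrative:

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=1.0):
    # Squared Euclidean distances per row, as in TripletLoss.calc_euclidean.
    d_pos = ((anchor - positive) ** 2).sum(axis=1)
    d_neg = ((anchor - negative) ** 2).sum(axis=1)
    # Hinge: zero once the negative is farther than the positive by at least `margin`.
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()


a = np.array([[0.0, 0.0], [0.0, 0.0]])
p = np.array([[1.0, 0.0], [1.0, 0.0]])  # d_pos = 1 for both triplets
n = np.array([[3.0, 0.0], [0.5, 0.0]])  # d_neg = 9 (margin met) and 0.25 (violated)
print(triplet_loss(a, p, n))  # (0 + 1.75) / 2, prints 0.875
```

Only triplets that violate the margin contribute to the loss, so well-separated triplets produce no gradient.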

Let's take a look at triplet loss and the implementation of the training pipeline.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm


class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def calc_euclidean(self, x1, x2):
        # Distances in embedding space are calculated as squared Euclidean distance
        return (x1 - x2).pow(2).sum(1)

    def forward(self, anchor, positive, negative):
        distance_positive = self.calc_euclidean(anchor, positive)
        distance_negative = self.calc_euclidean(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()


device = 'cuda'

# Our base model: an untrained ResNet-18 whose 1000-d output serves as the embedding
model = models.resnet18().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()

epochs = 10  # not specified in the original; example value

# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        x1, x2, x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device))

        loss = triplet_loss(e1, e2, e3)
        epoch_loss += loss.item()  # .item() avoids keeping the autograd graph alive
        loss.backward()
        optimizer.step()
    print("Train Loss: {}".format(epoch_loss))
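The code builds `val_loader` but never uses it. One possible sanity check, and an addition of ours rather than part of the original pipeline, is triplet accuracy: the fraction of triplets whose negative lies farther from the anchor than the positive does. The sketch below assumes you have collected the embeddings `e1`, `e2`, `e3` for validation batches under `torch.no_grad()` with `model.eval()` and converted them with `.cpu().numpy()`. The function name is hypothetical:

```python
import numpy as np


def triplet_accuracy(e_anchor, e_pos, e_neg):
    """Fraction of triplets where the negative is farther from the anchor than the positive."""
    d_pos = ((e_anchor - e_pos) ** 2).sum(axis=1)
    d_neg = ((e_anchor - e_neg) ** 2).sum(axis=1)
    return float((d_pos < d_neg).mean())


a = np.zeros((2, 2))
p = np.array([[1.0, 0.0], [1.0, 0.0]])  # d_pos = 1, 1
n = np.array([[3.0, 0.0], [0.5, 0.0]])  # d_neg = 9, 0.25
print(triplet_accuracy(a, p, n))  # prints 0.5
```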

At this point, the model has been trained to map images into an embedding space. Next, let's move on to search.

Search

We could simply use the nearest neighbor search provided by scikit-learn, but instead of taking the simple route, we will use something better.

We will use Faiss (Facebook AI Similarity Search). It is much faster than scikit-learn's nearest neighbor search, and the speed difference becomes more noticeable as the number of images grows.

Next, we show how to take a query image and search the stored image representations for its nearest neighbors.

# !pip install faiss-gpu
import glob
import os

import faiss
import torch
from PIL import Image

# PATH_TEST (a folder of query images) must be defined, like PATH_TRAIN

faiss_index = faiss.IndexFlatL2(1000)  # build the index (ResNet-18 outputs 1000-d vectors)

model.eval()  # use inference-mode BatchNorm statistics

# storing the image representations
im_indices = []
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        im = Image.open(f).convert("RGB")
        im = val_transforms(im).unsqueeze(0).to(device)  # resize + normalize, add batch dim
        preds = model(im).cpu().numpy()  # float32 array of shape (1, 1000)
        faiss_index.add(preds)  # add the representation to index
        im_indices.append(f)  # store the image name to find it later on

# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        # query/test image
        im = Image.open(os.path.join(PATH_TEST, f)).convert("RGB")
        im = val_transforms(im).unsqueeze(0).to(device)
        test_embed = model(im).cpu().numpy()
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))
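`faiss_index.search(test_embed, 5)` returns two arrays of shape (number of queries, k): `D`, the squared L2 distances, and `I`, the positions of the matches in insertion order. The loop prints only the top hit, `I[0][0]`. The following is a small sketch for listing all k hits. It relies on Faiss padding `I` with -1 when the index holds fewer than k vectors. The helper name and the file names in the example are hypothetical, and `I` is a hand-written stand-in for real search output:

```python
import numpy as np


def top_k_files(I: np.ndarray, im_indices: list) -> list:
    """Map the first row of a Faiss index result to file names, skipping -1 padding."""
    return [im_indices[i] for i in I[0] if i != -1]


# Stand-in for one query with k=3 against an index holding only 2 images.
im_indices = ["train/1/a.jpg", "train/1/b.jpg"]
I = np.array([[1, 0, -1]])
print(top_k_files(I, im_indices))  # prints ['train/1/b.jpg', 'train/1/a.jpg']
```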

This covers image retrieval with modern deep learning without making things too complicated. Most retrieval problems can be solved with this basic pipeline.

That concludes this case study of image retrieval with Python and PyTorch. Combining theory with practice is the best way to learn, so go and try it yourself!

© 2024 shulou.com SLNews company. All rights reserved.