Commit 6496262f authored by lishen's avatar lishen

[fix]

parent 0b240082
...@@ -4,9 +4,11 @@ from PIL import Image ...@@ -4,9 +4,11 @@ from PIL import Image
import torch.utils.data as data import torch.utils.data as data
from torchvision import datasets, transforms from torchvision import datasets, transforms
def My_loader(path): def My_loader(path):
return PIL.Image.open(path).convert('RGB') return PIL.Image.open(path).convert('RGB')
class MyDataset(torch.utils.data.Dataset): class MyDataset(torch.utils.data.Dataset):
def __init__(self, txt_dir, image_path, transform=None, target_transform=None, loader=My_loader): def __init__(self, txt_dir, image_path, transform=None, target_transform=None, loader=My_loader):
...@@ -15,8 +17,14 @@ class MyDataset(torch.utils.data.Dataset): ...@@ -15,8 +17,14 @@ class MyDataset(torch.utils.data.Dataset):
for line in data_txt: for line in data_txt:
line = line.strip() line = line.strip()
words = line.split(' ') words = line.split(' ')
print(words) p = ''
imgs.append((words[0], int(words[1].strip()))) for i, word in enumerate(words):
if i < len(words) - 1:
if i > 0:
p += ' '
p += word
imgs.append((p, int(words[-1].strip())))
self.imgs = imgs self.imgs = imgs
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -32,7 +40,7 @@ class MyDataset(torch.utils.data.Dataset): ...@@ -32,7 +40,7 @@ class MyDataset(torch.utils.data.Dataset):
# label = list(map(int, label)) # label = list(map(int, label))
# print label # print label
# print type(label) # print type(label)
#img = self.loader('/home/vipl/llh/food101_finetuning/food101_vgg/origal_data/images/'+img_name.replace("\\","/")) # img = self.loader('/home/vipl/llh/food101_finetuning/food101_vgg/origal_data/images/'+img_name.replace("\\","/"))
img = self.loader(self.image_path + img_name) img = self.loader(self.image_path + img_name)
# print img # print img
...@@ -46,6 +54,7 @@ class MyDataset(torch.utils.data.Dataset): ...@@ -46,6 +54,7 @@ class MyDataset(torch.utils.data.Dataset):
# if the label is the single-label it can be the int # if the label is the single-label it can be the int
# if the multilabel can be the list to torch.tensor # if the multilabel can be the list to torch.tensor
def load_data(image_path, train_dir, test_dir, batch_size): def load_data(image_path, train_dir, test_dir, batch_size):
normalize = transforms.Normalize(mean=[0.5457954, 0.44430383, 0.34424934], normalize = transforms.Normalize(mean=[0.5457954, 0.44430383, 0.34424934],
std=[0.23273608, 0.24383051, 0.24237761]) std=[0.23273608, 0.24383051, 0.24237761])
...@@ -68,6 +77,6 @@ def load_data(image_path, train_dir, test_dir, batch_size): ...@@ -68,6 +77,6 @@ def load_data(image_path, train_dir, test_dir, batch_size):
]) ])
train_dataset = MyDataset(txt_dir=train_dir, image_path=image_path, transform=train_transforms) train_dataset = MyDataset(txt_dir=train_dir, image_path=image_path, transform=train_transforms)
test_dataset = MyDataset(txt_dir=test_dir, image_path=image_path, transform=test_transforms) test_dataset = MyDataset(txt_dir=test_dir, image_path=image_path, transform=test_transforms)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size//2, shuffle=False, num_workers=0) test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size // 2, shuffle=False, num_workers=0)
return train_dataset, train_loader, test_dataset, test_loader return train_dataset, train_loader, test_dataset, test_loader
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment