您现在的位置是:主页 > news > 网站建设完整教程视频教程/移动端排名优化软件
网站建设完整教程视频教程/移动端排名优化软件
admin2025/6/23 11:49:34【news】
简介网站建设完整教程视频教程,移动端排名优化软件,家居网站建设方案,餐饮服务怎么做网络推广以下对接口或者函数的描述从简,建议参照官方文档(记得左上角选择对应版本!) PyTorch的torchvision下内置了一些常用数据集的接口,比如MNIST、CIFAR、COCO等,可用以下方式调用: # MNIST为例 from torchvision import…
以下对接口或者函数的描述从简,建议参照官方文档(记得左上角选择对应版本!)
PyTorch的torchvision
下内置了一些常用数据集的接口,比如MNIST
、CIFAR
、COCO
等,可用以下方式调用:
# MNIST为例
from torchvision import datasets
dataset = datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
root
:数据集根路径
train
:用于区分训练或者测试
transform
:预处理操作(数据增强、归一化等)
然后使用torch.utils.data.DataLoader
对dataset
进行封装(封装为Tensor),就可以作为模型训练的输入了。
但是,对于很多公共数据集以及私人数据集,是没有官方接口的,所以就需要用户去实现数据的读取。可以使用torchvision.datasets.ImageFolder
接口实现数据导入,但是要注意,这个接口假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名即为类名。例如,可以将猫狗大战的数据集中,可以将cat和dog的图像分别存储到两个文件夹下面,才可以使用这个接口,不过这也很容易操作:
# 在命令行里面运行
# 首先要进入train文件夹,该文件夹下有25000的训练数据,Windows下面可以:
mkdir dog cat
move dog.* dog/
move cat.* cat/
ImageFolder
接口使用与MNIST
等接口类似,可以查阅官方文档:
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None)
又但是,很多数据集(例如猫狗原本维护方式就是train
文件夹下面猫狗数据集都包含)也不像上面这也是分类别进行维护的,那就需要用户自定义接口来完成数据读取。PyTorch中自定义的接口需要继承自Dataset类,同时主要需要实现两个类的方法:
__getitem__
:返回一个样本__len__
:返回样本总数
同样以猫狗数据集为例,简要接口可以实现如下,至于不简要的,建议看源码依样画葫芦也很容易实现 >_<
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as Tclass DogCat(data.Dataset):def __init__(self, root, transforms=None, train=True):imgs = [os.path.join(root, img) for img in os.listdir(root)]imgs_num = len(imgs)# 划分训练、验证集,训练:验证 = 4:1if train:self.imgs = imgs[:int(0.8 * imgs_num)] # 训练集else:self.imgs = imgs[int(0.8 * imgs_num):] # 验证集if transforms is None:if train:self.transforms = T.Compose([T.Resize(224),T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])else:self.transforms = T.Compose([T.Resize(256),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])def __getitem__(self, index):img_path = self.imgs[index] # './catdog/train/cat.0.jpg'label = 1 if 'dog' in img_path.split('/')[-1] else 0 data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):return len(self.imgs)
有了数据读取接口之后,使用torch.utils.data.DataLoader
对dataset
进行封装就可以作为模型输入了:
# 训练集
train_data = DogCat(data_dir, train=True)
train_dataloader = DataLoader(train_data, batch_size, shuffle=True, num_workers=4)# 验证集
val_data = DogCat(data_dir, train=False)
val_dataloader = DataLoader(val_data, batch_size, shuffle=False, num_workers=4)
以上train_dataloader
和val_dataloader
就可以输入模型进行迭代了,建议使用迁移学习(偷懒)。但是预训练的模型拿过来是不能直接使用的,需要修改最后的全连接层输出类别数为猫狗例子中的classes = 2
,然后选择是对模型进行微调还是直接将模型作为特征提取器(设置相应层是否需要参数更新),代码如下:
# 获取resnet18预训练模型
model = models.resnet18(pretrained=True)def init_model(model, num_classes, feature_extract):# 如果不进行微调,则设置相应层的梯度设置if feature_extract:for param in model.parameters():param.requires_grad = False# 修改全连接层输出类别为2model.fc = torch.nn.Linear(model.fc.in_features, num_classes) return modelmodel = init_model(model, num_classes, feature_extract)
现在有了模型,和模型输入数据,还需要损失函数和优化器:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9)
接下来定义训练函数就可以跑起来了:
def train(model, device, train_dataloader, val_dataloader, criterion, optimizer, num_epochs):model.to(device)# trainfor epoch in range(num_epochs):model.train() for idx, (data,label) in enumerate(train_dataloader):data, label = data.to(device), label.to(device)pred = model(data)loss = criterion(pred, label)optimizer.zero_grad() # 梯度清零loss.backward()optimizer.step() # 参数更新if idx % 100 == 0:print("Train Epoch: {}, iteration: {}, Loss: {}".format(epoch, idx, loss.item()))model.eval() total_loss = []correct = 0 with torch.no_grad():for idx, (data, target) in enumerate(val_dataloader):data, target = data.to(device), target.to(device)output = model(data)total_loss.append(criterion(output, target).item())pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum()total_loss = np.mean(total_loss)acc = correct.item() / len(val_dataloader.dataset) * 100 print("Val loss: {}, Accuracy: {}".format(total_loss, acc))
总之遇事不决、遇事不懂,看源码(‾◡◝)
本文猫狗简单代码