Коммит 94478dbc создал по автору Yaroslav's avatar Yaroslav Зафиксировано автором GitHub
Просмотр файлов

Update pytorch_nn.py

владелец 144b1aff
......@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def simple_gradient():
# print the gradient of 2x^2 + 5x
......@@ -18,14 +18,14 @@ def simple_gradient():
def create_nn(batch_size=200, learning_rate=0.01, epochs=10,
log_interval=10):
train_loader = torch.utils.data.DataLoader(
train_loader = DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
test_loader = DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
......@@ -92,4 +92,4 @@ if __name__ == "__main__":
if run_opt == 1:
simple_gradient()
elif run_opt == 2:
create_nn()
\ Нет новой строки в конце файла
create_nn()
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать